効率的なFine-tuningワークロードをこなす目的で、 Modal + axolotl のPJを触ってみた: https://github.com/modal-labs/llm-finetuning
Modal: サーバレスのGPU計算PF、オンデマンドでGPUが使用できる(クラウド破産防止によい)。バッチ計算にも推論サービングにも対応。サンプルでは以下との統合を使用:
HuggingFace: 基盤モデル提供PF、モデルアクセス権等を管理
※ mistralai/Mistral-7B-v0.1 は利用申請が必要で、`HF_TOKEN`の権限設定でモデルアクセス設定が必要
Weights & Biases: モデル学習loggingのPF
axolotl: LLMファインチューニングに特化したFW(huggingface PEFT等のラッパー)。学習設定をYAMLで宣言的に定義し、学習・推論を透過的かつE2Eに実行できる。
学習の動作テストのために少数サンプルを暗記(memorization)
mistral-memorize.yml では、過学習させる設定が書かれている:
サブサンプルした少数データ(small batch)に対して
正則化をoffにし (lora_dropout: 0.0)
十分多いエポックの学習を回す (epoch=50)
これは、プログラムに現れない学習時のエラー(勾配計算や数値計算のミス、ラベルのエラー等)を検出する目的で行われる。過学習した内容に対して、モデル推論やデプロイまで検証することももちろん可能。
このプラクティスについての解説記事: https://fullstackdeeplearning.com/course/2022/lecture-3-troubleshooting-and-testing/#23-use-memorization-testing-on-training
プロンプトや学習データの検証
特にtokenizationで罠にハマりやすいという教訓から、学習処理に進む前に前処理だけ行ってデータのdecoding結果が正しいかどうかをチェックするプラクティスが推奨されている。
特に学習時と推論時で、プロンプト構築の方法がずれていたりすると、それに引きずられてtokenization結果がズレてモデルが壊れたりしやすいことを解説した記事: https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html
このステップについて、Modalの都合上、以下の手順で実行されている:
RUN_TAGを指定してmodal volumeからdownload
huggingfaceのdatasetsを用いて、download結果をロード
huggingfaceのtokenizerを用いて、ロード結果をdecode
※ notebookはそのままでは正しく動かず、axolotlのtokensの追加処理を参考に処理追加したらdecodingが正しく動いた: https://github.com/OpenAccess-AI-Collective/axolotl/blob/05b0bd08d229ee28cd3f11098d5b178f2ce441b6/src/axolotl/utils/models.py
※ スクリプトでワンステップ化したいところだが、CLIとpythonを行き来したり、途中結果の誤りも知りたい都合もあるためnotebookバッチがいいのかな? またこの検証を行うためにModalに渡すHF_TOKENは結局メモしておく必要があった。
LoRAとモデルマージ
LoRAでファインチューニングを行っており、Modalのお陰でシームレスに複数GPU(デフォルトnum_processes=2)で学習される。学習結果は勝手に基盤モデルにモデルマージされ、そのまま推論に使える形で保存される。
LoRAについても軽く調べておさらいした。
LoRAの基本ステップ: https://discuss.huggingface.co/t/help-with-merging-lora-weights-back-into-base-model/40968/3
load the base model
train the base model
save the LoRA adapter
reload the base model at half/full precision
merge the LoRA weights with the base model
save
※ 5は実はオプショナルという話も: https://www.reddit.com/r/LocalLLaMA/comments/17b5dgw/hotswapping_loras_during_inference/ や https://github.com/predibase/lorax
Q1. そもそもなぜモデルマージするのか?: https://speakerdeck.com/iwiwi/17-nlpkorokiumu?slide=19
モデルアンサンブル:関心あるタスクが共通の場合
余談Q. モデルマージの結果がLLM Leaderboardで高スコア→過学習 ?
Leaderboardスコアを見てマージするため
モデル能力拡張:関心あるタスクが直交する場合
直交するタスクなら、モデルマージが有効であることの数理的解釈:
Q2. そもそものLoRA "Adapter" という言葉について: https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora
LoRA提唱の時期には Adapter tuning手法とは区別されていたが: https://arxiv.org/abs/1902.00751v2
今どきはLoRAはAdapter-based methodsという方法のひとつと位置づけられている
> Adapterとは、事前学習済みの凍結したLLMに少数の学習可能なパラメータを追加する、パラメータ効率の高い微調整手法を指す
LoRAの2つの更新行列が「Adapter」に相当
LoRAは逐次でなく並列処理なアダプターであることも利点
vLLMによる推論
高速なLLM推論バックエンドとして知られるvLLM。今回はあまり深堀りできていないが、サンプルは以下サイトで解説されている: https://modal.com/docs/examples/vllm_inference#fast-inference-with-vllm-mistral-7b
雑感
inference.py の先頭トークンが出力されないバグ
[SQL]から始まることを期待するが: `🤖: SELECT name, born_state, age FROM head ORDER BY name [/SQL]`
mistral-memorize.ymlでもmistral.ymlでも再現
推論側 (vLLM) が怪しそう…?
vLLMの `CompletionOutput.token_ids=` の値は間違ってる
`skip_special_tokens=False` の実行結果でも `<s>` が記録されない
Modalの学習実行で、detach可能なオプションが知りたい(screen/tmux的なのを期待)
axolotl (DeepSpeed?) の複数GPUはどういった並列性で使われている?データ並列?モデル並列?
=> FSDP: モデルの層を複数のGPUに分けて学習、メモリ制約を克服
最初の数層を最初のGPUで実行し、次の数層を2番目のGPUで実行
例; 70b(140GB) のモデルを24GBのGPU8枚に分散し、それぞれに17.5GBを使うことができる
from: https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html
`PAD: 2 / </s>` で特殊トークンが使い回されているがそういうものか
Next Step
別のモデルでのファインチューニング
別のデータセット形式でのファインチューニング: https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/inst_tune.html
axolotlのdebugging: https://openaccess-ai-collective.github.io/axolotl/docs/debugging.html#general-tips
vLLMをいじる