[TIL][240603] Modal + axolotlでFine-tuning

kzinmr
·
公開:2024/6/4

効率的なFine-tuningワークロードをこなす目的で、 Modal + axolotl のPJを触ってみた: https://github.com/modal-labs/llm-finetuning

  • Modal: サーバレスのGPU計算PF、オンデマンドでGPUが使用できる(クラウド破産防止によい)。バッチ計算にも推論サービングにも対応。サンプルでは以下との統合を使用:

    • HuggingFace: 基盤モデル提供PF、モデルアクセス権等を管理

      • ※ mistralai/Mistral-7B-v0.1 は利用申請が必要で、`HF_TOKEN`の権限設定でモデルアクセス設定が必要

    • Weights & Biases: モデル学習loggingのPF

  • axolotl: LLMファインチューニングに特化したFW(huggingface PEFT等のラッパー)。学習設定をYAMLで宣言的に定義し、学習・推論を透過的かつE2Eに実行できる。

学習の動作テストのために少数サンプルを暗記(memorization)

mistral-memorize.yml では、過学習させる設定が書かれている:

  • サブサンプルした少数データ(small batch)に対して

  • 正則化をoffにし (lora_dropout: 0.0)

  • 十分多いエポックの学習を回す (epoch=50)

これは、プログラムに現れない学習時のエラー(勾配計算や数値計算のミス、ラベルのエラー等)を検出する目的で行われる。過学習した内容に対して、モデル推論やデプロイまで検証することももちろん可能。

このプラクティスについての解説記事: https://fullstackdeeplearning.com/course/2022/lecture-3-troubleshooting-and-testing/#23-use-memorization-testing-on-training

プロンプトや学習データの検証

特にtokenizationで罠にハマりやすいという教訓から、学習処理に進む前に前処理だけ行ってデータのdecoding結果が正しいかどうかをチェックするプラクティスが推奨されている。

特に学習時と推論時で、プロンプト構築の方法がずれていたりすると、それに引きずられてtokenization結果がズレてモデルが壊れたりしやすいことを解説した記事: https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html

このステップについて、Modalの都合上、以下の手順で実行されている:

  1. RUN_TAGを指定してmodal volumeからdownload

  2. huggingfaceのdatasetsを用いて、download結果をロード

  3. huggingfaceのtokenizerを用いて、ロード結果をdecode

※ notebookはそのままでは正しく動かず、axolotlのtokensの追加処理を参考に処理追加したらdecodingが正しく動いた: https://github.com/OpenAccess-AI-Collective/axolotl/blob/05b0bd08d229ee28cd3f11098d5b178f2ce441b6/src/axolotl/utils/models.py

※ スクリプトでワンステップ化したいところだが、CLIとpythonを行き来したり、途中結果の誤りも知りたい都合もあるためnotebookバッチがいいのかな? またこの検証を行うためにModalに渡すHF_TOKENは結局メモしておく必要があった。

LoRAとモデルマージ

LoRAでファインチューニングを行っており、Modalのお陰でシームレスに複数GPU(デフォルトnum_processes=2)で学習される。学習結果は勝手に基盤モデルにモデルマージされ、そのまま推論に使える形で保存される。

LoRAについても軽く調べておさらいした。

LoRAの基本ステップ: https://discuss.huggingface.co/t/help-with-merging-lora-weights-back-into-base-model/40968/3

  1. load the base model

  2. train the base model

  3. save the LoRA adapter

  4. reload the base model at half/full precision

  5. merge the LoRA weights with the base model

  6. save

※ 5は実はオプショナルという話も: https://www.reddit.com/r/LocalLLaMA/comments/17b5dgw/hotswapping_loras_during_inference/https://github.com/predibase/lorax

Q1. そもそもなぜモデルマージするのか?: https://speakerdeck.com/iwiwi/17-nlpkorokiumu?slide=19

  • モデルアンサンブル:関心あるタスクが共通の場合

    • 余談Q. モデルマージの結果がLLM Leaderboardで高スコア→過学習 ?

      • Leaderboardスコアを見てマージするため

  • モデル能力拡張:関心あるタスクが直交する場合

Q2. そもそものLoRA "Adapter" という言葉について: https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora

  • LoRA提唱の時期には Adapter tuning手法とは区別されていたが: https://arxiv.org/abs/1902.00751v2

  • 今どきはLoRAはAdapter-based methodsという方法のひとつと位置づけられている

    • > Adapterとは、事前学習済みの凍結したLLMに少数の学習可能なパラメータを追加する、パラメータ効率の高い微調整手法を指す

  • LoRAの2つの更新行列が「Adapter」に相当

  • LoRAは逐次でなく並列処理なアダプターであることも利点

vLLMによる推論

高速なLLM推論バックエンドとして知られるvLLM。今回はあまり深堀りできていないが、サンプルは以下サイトで解説されている: https://modal.com/docs/examples/vllm_inference#fast-inference-with-vllm-mistral-7b

雑感

  • inference.py の先頭トークンが出力されないバグ

    • [SQL]から始まることを期待するが: `🤖: SELECT name, born_state, age FROM head ORDER BY name [/SQL]`

    • mistral-memorize.ymlでもmistral.ymlでも再現

    • 推論側 (vLLM) が怪しそう…?

      • vLLMの `CompletionOutput.token_ids=` の値は間違ってる

      • `skip_special_tokens=False` の実行結果でも `<s>` が記録されない

  • Modalの学習実行で、detach可能なオプションが知りたい(screen/tmux的なのを期待)

  • axolotl (DeepSpeed?) の複数GPUはどういった並列性で使われている?データ並列?モデル並列?

    • => FSDP: モデルの層を複数のGPUに分けて学習、メモリ制約を克服

  • `PAD: 2 / </s>` で特殊トークンが使い回されているがそういうものか

Next Step