使用自推測解碼加速大模型推理
推測解碼是一種新穎的文本生成方法,它結合了推測解碼 (Speculative Decoding) 的優勢和大語言模型 (LLM) 的提前退出 (Early Exit) 機制該方法出自論文 LayerSkip: Enabling Early-Exit Inference and Self-Speculative Decoding它通過使用 同一個模型 的早期層來生成候選詞元 (token),并使用后期層進行驗證,從而實現高效生成
這項技術不僅加快了文本生成速度,還顯著節省了內存并降低了計算延遲為了實現端到端的加速,早期層的輸出需要與最終層的輸出足夠接近正如論文中所述,這可以通過一種訓練方法來實現,該方法可以在預訓練期間應用,也可以在特定領域進行微調時應用自推測解碼對于實際應用特別高效,它可以在較小的 GPU 上部署,并降低 大規模推理 所需的整體硬件資源-
在本博客中,我們將探討自推測解碼的概念、其實現方式以及在 ?? transformers 庫中的實際應用-您將了解到其技術原理,包括 提前退出層 (Early-Exit Layers) 、 反嵌入 (Unembedding) 和 訓練修改 (Training Modifications)為了將這些概念付諸實踐,我們提供了代碼示例、與傳統推測解碼的基準比較,以及對性能權衡的見解-
您還可以直接查看以下 Hugging Face 資源,了解更多關于該方法的信息并親自嘗試:
Hugging Face 論文討論論壇
LayerSkip 模型集合
展示自推測解碼深入工作原理的 Colab 筆記本
推測解碼與自推測解碼
LayerSkip 演示 GIF
在 facebook/layerskip-llama2-7B 上的 LayerSkip 推理演示 (使用 LayerSkip 方法持續預訓練的 Llama2 7B)
傳統的推測解碼 使用 兩個 模型: 一個較小的模型 (草稿模型) 用于生成一系列候選詞元,一個較大的模型 (驗證模型) 用于驗證草稿的準確性較小的模型執行大部分生成工作,而較大的模型則負責改進結果這提高了文本生成速度,因為較大的模型一次性驗證完整序列,而不是逐個生成詞元-
在自推測解碼中,作者在此概念的基礎上,使用大模型的早期層來生成草稿詞元,然后由模型的更深層進行驗證這種推測解碼的“自洽”特性需要特定的訓練,使模型能夠同時執行草稿生成和驗證這反過來又比傳統的推測解碼提高了速度并降低了計算成本
在 transformers 中的使用
為了在 ?? transformers 庫中啟用提前退出自推測解碼,我們只需在 generate() 函數中添加 assistant_early_exit 參數
以下是一個簡單的代碼片段,展示了該功能:
pip install transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
early_exit_layer = 4
prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama2-7B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
outputs = model.generate(**inputs, assistant_early_exit=early_exit_layer)
注意: 雖然 assistant_early_exit 參數可以為任何僅解碼器的 transformer 啟用提前退出自推測解碼,但除非模型經過專門訓練,否則無法反嵌入 (通過 LM 頭進行解碼的過程,在博客文章后面有描述) 中間層的 logits只有對檢查點進行這樣的訓練,以提高早期層的準確性,您才能獲得加速LayerSkip 論文提出了一種訓練方法來實現這一點 (即應用提前退出損失,并逐步增加層丟棄率)這里 提供了使用 LayerSkip 訓練方法持續預訓練的 Llama2、Llama3 和 Code Llama 檢查點的集合
基準測試
我們進行了一系列廣泛的基準測試,以衡量 LayerSkip 的自推測解碼相對于自回歸解碼在各種模型上的加速情況我們還將自推測解碼 (基于提前退出) 與標準推測解碼技術進行了比較要復現這些結果,您可以在 這里 找到代碼,并在 此電子表格 中找到運行每個實驗的命令所有實驗均在單個 80GB A100 GPU 上運行,除了 Llama2 70B 實驗在 8 個 A100 GPU 的節點上運行
Llama3.2 1B
Model Variant (模型變體) Layers (層數) Assistant Model (輔助模型) Assistant Layers (輔助層數) Task (任務) Total Layers (總層數) FLOPs/Input (G) (輸入 FLOPs) Time/Input (s) (輸入時間) FLOPs/Output (G) (輸出 FLOPs) Time/Output (s) (輸出時間) Efficiency (效率)
facebook/layerskip-llama3.2-1B 1 Early Exit @ Layer 4 summarization 1 1195.28 9.96 2147.7 17.9 1.80
Llama3 8B
Model Variant (模型變體) Layers (層數) Assistant Model (輔助模型) Assistant Layers (輔助層數) Task (任務) Total Layers (總層數) FLOPs/Input (G) (輸入 FLOPs) Time/Input (s) (輸入時間) FLOPs/Output (G) (輸出 FLOPs) Time/Output (s) (輸出時間) Efficiency (效率)
meta-llama/Meta-Llama-3-8B 8 meta-llama/Llama-3.2-1B 1 summarization 9 1872.46 19.04 2859.35 29.08 1.53
meta-llama/Meta-Llama-3-8B 8 meta-llama/Llama-3.2-3B 3 summarization 11 2814.82 28.63 2825.36 28.73 1.00
facebook/layerskip-llama3-8B 8 Early Exit @ Layer 4 summarization 8 1949.02 15.75 3571.81 28.87 1.83
Llama2 70B
Model Variant (模型變體) Layers (層數) Assistant Model (輔助模型) Assistant Layers (輔助層數) Task (任務) Total Layers (總層數) FLOPs/Input (G) (輸入 FLOPs) Time/Input (s) (輸入時間) FLOPs/Output (G) (輸出 FLOPs) Time/Output (s) (輸出時間) Efficiency (效率)
meta-llama/Llama-2-70b-hf 70 meta-llama/Llama-2-13b-hf 13 summarization 83 5036.54 46.3 12289.01 112.97 2.44
meta-llama/Llama-2-70b-hf 70 meta-llama/Llama-2-7b-hf 7 summarization 77 4357.55 40.06 12324.19 113.3 2.83
meta-llama/Llama-2-70b-hf 70 TinyLlama/TinyLlama_v1.1 1 summarization 71 4356.21 40.05 12363.22 113.66 2.84
facebook/layerskip-llama2-70B 70 Early Exit @ Layer 10 summarization 70 6012.04 54.96 1283.34 113.2 2.06
Llama2 13B
Model Variant (模型變體) Layers (層數) Assistant Model (輔助模型) Assistant Layers (輔助層數) Task (任務) Total Layers (總層數) FLOPs/Input (G) (輸入 FLOPs) Time/Input (s) (輸入時間) FLOPs/Output (G) (輸出 FLOPs) Time/Output (s) (輸出時間) Efficiency (效率)
meta-llama/Llama-2-13b-hf 13 meta-llama/Llama-2-7b-hf 7 summarization 20 3557.07 27.79 4088.48 31.94 1.15
meta-llama/Llama-2-13b-hf 13 TinyLlama/TinyLlama_v1.1 1 summarization 14 2901.92 22.67 4190.42 32.74 1.44
meta-llama/Llama-2-13b-hf 13 apple/OpenELM-270M 0.27 summarization 13.27 2883.33 22.53 4521.12 35.32 1.57
meta-llama/Llama-2-13b-hf 13 apple/OpenELM-450M 0.45 summarization 13.45 3267.69 25.53 4321.75 33.76 1.32
facebook/layerskip-llama2-13B 13 Early Exit @ Layer 4 summarization 13 4238.45 33.11 4217.78 32.95 0.995
facebook/layerskip-llama2-13B 13 Early Exit @ Layer 8 summarization 13 2459.61 19.22 4294.98 33.55 1.746
Llama2 7B
Model Variant (模型變體) Layers (層數) Assistant Model (輔助模型) Assistant Layers (輔助層數) Task (任務) Total Layers (總層數) FLOPs/Input (G) (輸入 FLOPs) Time/Input (s) (輸入時間) FLOPs/Output (G) (輸出 FLOPs) Time/Output (s) (輸出時間) Efficiency (效率)
meta-llama/Llama-2-7b-hf 7 TinyLlama/TinyLlama_v1.1 1 summarization 8 2771.54 21.65 3368.48 26.32 1.22
meta-llama/Llama-2-7b-hf 7 apple/OpenELM-270M 0.27 summarization 7.27 2607.82 20.37 4221.14 32.98 1.62
meta-llama/Llama-2-7b-hf 7 apple/OpenELM-450M 0.45 summarization 7.45 3324.68 25.97 4178.66 32.65 1.26
facebook/layerskip-llama2-7B 7 Early Exit @ Layer 4 summarization 7 2548.4 19.91 3306.73 25.83 1.297
我們可以觀察到以下幾點:
從“ 總參數數量”列可以看出,自推測解碼消耗的內存更少,因為它不需要單獨的草稿模型,并且草稿階段層的權重被重用
對于除 Llama2 70B 之外的所有模型大小和生成,提前退出自推測解碼比常規的兩模型推測解碼更快
與其它模型相比,Llama2 70B 的自推測解碼速度提升相對有限,可能有不同的原因,例如,Llama2 70B 的 LayerSkip 檢查點持續預訓練的 token 較少 (Llama2 70B 為 328M token,而 Llama2 7B 為 52B token)但這是未來研究需要改進的一個方面盡管如此,70B 的自推測解碼明顯快于自回歸解碼
自生成和自驗證
自推測解碼過程從自生成開始,其中詞元是通過從某個中間層提前退出來生成的推測詞元的數量定義了在此階段生成多少草稿詞元,而我們退出的層定義了草稿階段的規模和準確性-這兩個參數都可以在推理時根據草稿階段的速度和準確性之間的權衡來指定
下一步是自驗證,其中使用完整模型來驗證草稿詞元驗證模型重用草稿模型中的緩存部分如果草稿詞元與驗證的詞元一致,則將它們添加到最終輸出中,從而更好地利用我們系統中的內存帶寬,因為使用完整模型生成一系列詞元比驗證草稿要昂貴得多,只要有幾個詞元匹配即可
在自驗證階段,只有剩余的層才會被計算以進行驗證,因為早期層的結果在草稿階段已被緩存-
提前退出和反嵌入
自推測解碼中的一項關鍵技術是提前退出,即生成過程可以在預先指定的層停止為了實現這一點,我們通過將這些層的 logits 投影到語言模型 (LM) 頭上來反嵌入它們,以預測下一個詞元這允許模型跳過后續層并提高推理時間
可以在任何 transformer 層執行反嵌入,將提前退出轉變為一種高效的詞元預測機制一個自然而然的問題出現了: 當 LM 頭最初被訓練為僅與最終層一起工作時,如何使其適應反嵌入較早層的 logits?這就是訓練修改發揮作用的地方
訓練修改
在訓練階段,我們引入了層丟棄,它允許模型在訓練期間跳過某些層丟棄率在較深的層中逐漸增加,使模型不太依賴其后面的層,并增強模型的泛化能力并加快訓練速度
除了層丟棄之外,還應用了提前退出損失,以確保 LM 頭學習反嵌入不同的層使用每個出口 (中間層) 的歸一化損失的總和來給出使用提前出口訓練模型的總損失函數這種技術通過在所有層之間分配學習任務來實現高效訓練
優化: 共享權重、共享 KV 緩存和共享計算
自推測解碼顯著受益于緩存重用,特別是 KV 緩存,它存儲在草稿階段計算的鍵值對此緩存允許模型跳過冗余計算,因為草稿和驗證階段都使用相同的早期層此外,退出查詢緩存存儲來自退出層的查詢向量,允許驗證從草稿階段無縫繼續
與傳統的雙模型推測解碼相比,提前退出自推測解碼可以從以下節省中受益:
共享權重: 為草稿和驗證重用前 E 層 的權重
共享 KV 緩存: 為草稿和驗證重用前 E 層的鍵值對
共享計算: 通過使用僅保存退出層 E-1 的查詢向量的退出查詢緩存來重用前 E 層的計算,以便驗證過程無需計算層 0 到 E-1
KV 和退出查詢緩存的組合稱為 KVQ 緩存,可減少內存開銷并提高推理延遲
到目前為止,?? transformers 庫已在此 pull request 中實現了第一個優化 (共享權重)隨著使用此方法的模型數量增加,我們將考慮其他優化如果您有興趣,請隨時提出 PR!
提前退出層的選擇策略
草稿階段的提前退出層是一個超參數,我們可以在推理期間調整或修改:
我們越早退出,生成草稿詞元的速度就越快,但它們的準確性就越低
我們越晚退出,生成的草稿詞元就越準確,但它們的速度就越慢
我們編寫了一個腳本來遍歷不同的提前退出層并測量 A100 GPU 上的每秒詞元數在下面的表格中,我們繪制了針對不同 Llama 模型的 LayerSkip 和基線檢查點的每秒詞元數與提前退出層的關系圖 (您可以在 此處 查看完整日志)
Llama3.2 1B
Normal (常規模型) LayerSkip (LayerSkip 模型)
llama 3.2 1b layer skip llama 3.2 1b
Llama3 8B
Normal (常規模型) LayerSkip (LayerSkip 模型)
llama 3 8b layer skip llama 3 8b
Code Llama3 34B
Normal (常規模型) LayerSkip (LayerSkip 模型)
code llama 3 34b code layer skip llama 3 34b
Code Llama3 7B
Normal (常規模型) LayerSkip (LayerSkip 模型)
code llama 3 7b code layer skip llama 3 7b
Llama2 70B
Normal (常規模型) LayerSkip (LayerSkip 模型)
llama 2 70b layer skip llama 2 70b
Llama2 13B
Normal (常規模型) LayerSkip (LayerSkip 模型)
llama 2 13b layer skip llama 2 13b
Llama2 7B
Normal (常規模型) LayerSkip (LayerSkip 模型)
llama 2 7b layer skip llama 2 7b
我們可以觀察到以下幾點:
對于沒有使用 LayerSkip 訓練方法進行預訓練或持續預訓練的基線檢查點,提前退出自推測解碼比自回歸解碼更慢這是因為在大多數 LLM 的訓練過程中,早期層并沒有被激勵去學習預測輸出,因此使用早期層生成詞元的接受率會非常低
另一方面,對于使用 LayerSkip 訓練方法持續預訓練的 Llama 檢查點,提前退出自推測解碼在至少一部分層中比自回歸解碼具有更高的加速比
對于大多數模型 (除了 Llama3.2 1B),當我們遍歷各層時,我們注意到一個規律模式: 加速比在前幾層較低,逐漸增加到一個最佳點,然后再次下降
提前退出層的最佳點是在預測的高準確性和生成詞元的低開銷之間達到最佳權衡時-這個最佳點取決于每個模型,也可能取決于提示或提示的領域
這些觀察為進一步的實驗和探索提供了有趣的機會-我們鼓勵讀者在這些想法的基礎上進行構建,測試變體,并進行自己的研究這些努力可以帶來有價值的見解,并為該領域做出有意義的貢獻
結論
LayerSkip 利用提前退出、層丟棄和緩存重用之間的協同作用,創建了一個快速高效的文本生成流程通過訓練模型從不同層反嵌入輸出,并使用緩存優化驗證過程,這種方法在速度和準確性之間取得了平衡因此,它顯著改善了大語言模型的推理時間,同時保持了高質量的輸出由于使用單個模型作為草稿和驗證模型,它還比傳統的推測解碼技術減少了內存使用
自推測是一個令人興奮的領域,同一個 LLM 可以創建草稿詞元并自我修正其他自推測方法包括:
Draft & Verify: 其中草稿階段涉及跳過預定的注意力和前饋層
MagicDec: 其中草稿階段使用 KV 緩存的子集,這對長上下文輸入很有用
Jacobi Decoding 和 Lookahead Decoding: 其中草稿階段是一系列“猜測詞元”,可以是隨機的或從 n-gram 查找表中獲得的
浙公網安備 33010602011771號