探秘Transformer系列之(32)--- Lookahead Decoding
探秘Transformer系列之(32)--- Lookahead Decoding
0x00 概述
投機采樣的范式是predict+verify,而另外還有一種思路是基于Jacobi迭代構建的Jacobi decoding及其演化分支。
Jacobi 迭代把自回歸的N次迭代轉換為N個方程,然后聯合求解。而 Jacobi Decoding 將每次迭代上一次的輸出整體作為下一次的輸入,其實就是把每一個 token 上的輸出視作一個 2-gram,并以此作為Draft Model。假設\(\mathbf{y}_i\)是長度為m的待預測序列,Jacobi Decoding 從隨機預測\(\mathbf{y}_0\)開始,不停地自回歸迭代,最多迭代m次能全部命中。論文“Break the Sequential Dependency of LLM Inference Using Lookahead Decoding”的作者想到,如果可以記錄下更多的歷史信息,就可以制造一個 N-gram 作為 Draft Model,這樣就能提高 Speculative Decoding 的準確率。這就是Lookahead Decoding。簡要來說,Lookahead=N-gram+Jacobi iteration+parallel verification,其利用 jacobi 迭代法同時提取和驗證 n-grams,打破自回歸解碼的順序依賴性,從而降低解碼次數,實現推理加速。相比之前的并行解碼,Lookahead Decoding即不需要草稿模型,也不需要像Medusa那樣微調head。論文作者將 Jacobi Decoding 視為Lookahead Decoding在 2-gram 情況下的特例。

0x01 Jacobi decoding
Jacobi decoding算法最早的工作來自于論文Accelerating Transformer Inference for Translation via Parallel Decoding。Jacobi Decoding類似RNN的迭代,從初始序列\(\mathbf{y}_0\)開始,最多經過m次迭代,預測出長度為m的序列\(\mathbf{y}_i\)。
1.1 動機
常規(guī)的自回歸解碼過程如下圖所示,解碼過程相當于每次都將上一式解出之,后帶入下一式。

Jacobi decoding則是基于jacobi 迭代式,將自回歸這N次迭代轉換為N個有關輸入輸出(x,y)的方程,這樣出發(fā)點就變成是把自回歸解碼的過程看作是聯立以下方程來求方程組解的問題。

1.2 思路
Jacobi decoding直接使用自行迭代的方法尋找方程組的解。即首先隨機指定一組初始解y,然后根據自回歸方程和初始解進行計算來更新y,重復以上過程直至達到迭代停止條件。m元方程組至多m次可求得精確解。具體原理如下圖所示,LLM 輸出(在 Greedy Decoding 下)是一個不動點,通過 LLM 不斷的自我迭代能用更少的次數找到尋找到方程組的解(不動點/fixed point)。

我們把自回歸解碼和Jacobi decoding對比如下圖所示。左邊的自回歸解碼是串行,右面的Jacobi算法變?yōu)橐淮涡越獯am個token,但是允許多步迭代。與自回歸解碼相比,每個Jacobi decoding步驟在所需的計算量要大,因為它需要對 >1 個token同時調用大模型進行前向計算,但由于 GPU 的并行處理特性,這通常不會導致速度變慢。

Jacobi decoding的一大缺陷是只適用于 Greedy Decoding。貪婪解碼可以保證每次迭代至少能獲得一個穩(wěn)定的 token ,因而所需步驟一定不大于自回歸解碼所需步驟。
1.3 算法
論文采用如下的算法進行迭代,首先隨機初始化一個長度為m的輸出token,隨后隨著x的輸入,不斷更新這個輸出token(選擇概率最大的輸出),直到前后兩次迭代輸出token一致。其中循環(huán)迭代的每一步類似prefill操作,速度類似解碼單個token的速度。

雖然Jacobi decoding實際加速效果并不十分明顯,但給后續(xù)其它工作帶來了很大的啟發(fā)。
我們接下來進行Lookahead Decoding的學習。
0x02 原理
2.1 思路
Jacobi decoding和之前的模型存在如下問題:
- 自回歸解碼的耗時與解碼步數成正比,并且無法很好地利用加速器的并行能力。
- 投機解碼及其變體算法無法保證draft token的接受率,通過訓練draft model來提升 draft token 接受率則成本高且通用性差。
- 雖然Jacobi decoding可以通過許多步驟來解碼多個token,但因為其隨機生成初始解,導致迭代過程中接受率很低,因此加速效果較差。
- 有時,Jacobi decoding 預測的token序列片段是正確的,但是這些序列出現的位置不正確,需要花費好幾個周期進行修正,把序列移動到正確的位置上,walltime加速效果受到影響。
Lookahead Decoding 便嘗試解決上述問題,其本質上是利用n-gram pool的記憶性來建模子序列片段,作為候選子序列送入驗證階段,同時擴展到了N步依賴,這樣可以一次性生成多個token。

2.1.1 出發(fā)點
Lookahead Decoding 算法的核心是基于 Jacobi Decoding 過程中產生的 Jacobi Trajectory 來生成 N-gram。
對于長度為m的目標序列\(\mathbf{y}_i\),最壞情況需要經歷\(\mathbf{y}_1\)到\(\mathbf{y}_m\)的完整軌跡才能最終全部生成,此時就退化成自回歸解碼,甚至可能略慢一點,因為每一步同時推理m個token。雖然Jacobi decoding隨機初始化的tokens可能都不會被接受,但是序列每個位置的每個新token在解碼過程中都會形成一組Jacobi Trajectory(雅可比迭代軌跡)。我們可以利用初始化tokens和Jacobi Trajectory來構造一系列n-gram,這些 n-gram 可能會在后面的解碼步驟中被使用,從而加速解碼過程。
在Vanilla Jacobi Decoding里的\(\mathbf{y_i}\)的生成只依賴\(\mathbf{y}_{i-1}\),Lookahead Decoding 嘗試把依賴擴展到前N-1步。比如我們在當前解碼后回溯3個迭代輪次,就會構成一組每個位置的3-grams。Lookahead Decoding會在迭代中緩存這些n-grams,在執(zhí)行Jacobi decoding的同時也并行驗證緩存中的n-grams。接受一個N-grams使得我們一次推進N個token。這樣就通過并行生成N-grams的能力克服了Jacobi decoding的缺陷。
2.1.2 并行
為了加速解碼過程,每個Lookahead Decoding步驟被分為兩個并行分支:生成n-gram的lookahead分支和驗證n-gram的verification分支,兩者都在一個前向傳播過程中執(zhí)行。
- Lookahead(前瞻)分支:這是原始雅可比解碼的過程。因為不一致性的問題,此過程不會被用作主要投機驗證的機制,而是作為一種采樣收集或者說生成 n-gram 的并行解碼過程。Lookahead 分支的目的是生成新的 N-Grams,加上其中新生成的 Token 就可以用于構建下一次 Verify 分支的候選序列。
- Verification(驗證)分支:這個分支從n-gram集合中匹配的多個candidates作為投機驗證輸入,完成具體的投機采樣過程。verification分支會選擇并驗證有希望的 n-gram ,并且會將其用于更新下一次 Lookahead 分支的序列。
2.1.3 數據結構和超參數
2D Window
與最Jacobi解碼(只使用最后一步的歷史token,或等效地生成2-gram)不同,Lookahead Decoding通過使用n-1個過去步驟的歷史token并行生成許多n≥2的n-gram,有效地利用了軌跡中的更多信息。因此,Lookahead Decoding 的關鍵設計是跟蹤Jacobi解碼的雅可比迭軌跡,并從該軌跡生成新的n-gram。這是通過維護一個大小不固定的2D window來實現的。2d window由兩個重要參數定義,分別是代表序列維度的window大小W和代表時間維度的n-gram大小N,以并行地從Jacobi迭代軌跡生成多個不相交的n-grams。
- W:W用作展望,是希望向前生成的 tokens 的數量,就是在未來的 Token 位置上再向前多遠可以并行解碼(Lookahead Decoding增加的計算量與 W 成正比,因此要設置 W 的上限來控制計算成本)。
- N:N用作回溯,即看多少步之前的 Jacobi 迭代來檢索 N-Gram。當N=2時,Lookahead Decoding退化為Jacobi decoding。
因此,2d window一共由W列,N-1行。解碼時會額外解碼 W 個 tokens,與這 N-1 行湊成 W 個 N-grams。另外,從論文作者博客的圖來看,第一個 window size 是包含了 input_ids 的最后一個在內的。2d window 具體如下圖所示,橫軸W就是每次Jacobi Decoding窗口的長度,縱軸就是歷史每一步的\(\mathbf{y}_i\)的生成結果。每步迭代生成W個N-gram。

n-gram pool
為了提高效率,Lookahead Decoding引入了一個n-gram池來緩存到目前為止,所有沿軌跡生成的歷史 n-gram。這些 n-gram 候選者稍后會通過驗證分支進行驗證,以保持LLM的輸出分布;如果通過驗證,那些不相交的n-gram將被整合到序列中。這樣,Lookahead Decoding可以通過利用自回歸解碼未使用的計算資源來顯著減少LLM推理的延遲。
Guess set size
為了限制 N-gram pool 的大小,論文作者引入第三個超參 G,代表 Guess set size,即每個 key 最多對應 G 個 N-gram,并以 LRU 策略進行更新。
2.1.4 總覽圖
下圖給出了Lookahead 總覽。藍色 0 指的是 prompt 與之前已確定輸出的最后一位,即當前步 t 的輸入。這里取 window size W=5 ,N-gram size N=4 ,verification 數量 V=2 。橙色對應之前 t-3 步的結果、綠色對應之前 t-2 步的結果、紅色對應之前 t-1 步的結果。每個 Token 上的數字表示其與當前輸入 Token(藍色 0 )的相對位置。
對于Lookahead 分支來說,在當前階段,我們遵循前3個步驟形成的軌跡,執(zhí)行修改后的Jacobi迭代算法,為所有5個位置生成新的token。生成后,我們將它們收集并緩存在n-gram pool中(n=4)——例如,4-gram由位置1處的橙色token、位置2處的綠色token、位置3處的紅色token和新生成的token組成。兩個維度(時間和序列)中最過時的token將被刪除,新生成的token將附加到 Lookahead 分支,以保持每個步驟的固定窗口大小。例如,我們將刪除圖中的所有橙色和位置1的綠色token。然后,我們用索引為2、3、4、5的綠色token、所有紅色token和下一步新生成的所有token形成一個新的 Lookahead 分支。這里要注意依賴關系,例如紅6依賴綠5和橙色的的所有token。
Verification Branch 選取樣本的方案很簡單,是直接在 N-gram Pool 里選取第一位是 藍色 token 最后一位的 N-gram。這其中驗證之后被接受的即可作為本次的輸出。

2.2 示例
我們用一個示例來展示下Lookahead Decoding:給定輸入"ABC",要預測英文字母表。
如果是自回歸解碼方案,生成流程會是如下:\(ABC \rightarrow ABCD\rightarrow ABCDE \rightarrow ABCDEF\)。
如果是Lookahead Decoding,則流程如下:
- 假設現在n-gram pool為4-gram pool,其中包括:CDEF,CDFE,CDFG三個4-gram。
- 將n-gram候選加入現有序列進行驗證,即輸入序列為:[ABC] [DEF, DFE, DFG],拼接得到ABCDEFDFEDFG,計算得到輸出ABCDEF。
- 現在已經一次性通過驗證n-gram得到DEF,實現了并行解碼。但是,這些需要驗證的n-gram從哪里來?或者說怎么繼續(xù)生成?因此需要加入2d window 這個數據結構,其用來生成n-gram序列。2d window內容是FGH,FGE,FGJ三個。把這三個也填到輸入序列。
- 因此,輸入序列由三部分組成:[現有序列,用于生成和維護n-gram的序列,用于驗證n-gram的序列],這樣才能實現并行的循環(huán)迭代。新的輸入序列是:[ABC] [FGH, FGE, FGJ] [DEF, DFE, DFG],拼接之后是ABCFGHFGEFGJDEFDFEDFG。[FGH, FGE, FGJ]會并行生成新的n-gram為后續(xù)驗證服務,[DEF, DFE, DFG]會驗證n-gram。
- 假設DEF 被接受,接下來新的輸入序列就是 [ABCDEF] [XXX,即由FGH, FGE, FGJ生成的新n-gram對應的行] [FGH, FGE, FGJ],計算會預測得到ABCDEFGH。
我們接下來會對這個流程再進行詳細解讀。
0x03 實現
我們接下來基于llama.cpp和論文作者提供的原始代碼(后文簡稱原始代碼)來看看Lookahead Decoding的一些具體技術細節(jié)。我們在示例中,對超參數做如下設置:
- N=4,所以2D Window有3行
- W=5,所以2D Window有5列,即每步可以收集5個n-gram。
- G=2,所以可以從N-gram pool中最多找出兩個匹配上的序列。
下圖中,上方是論文中的原始圖,下方是筆者基于原始圖做的解讀。

3.1 mask
由于 LLM 解碼主要受內存帶寬限制,因此我們在同一步驟中合并前瞻和驗證分支,利用 GPU 的并行處理能力來隱藏開銷。 mask就是并行解碼的關鍵。本示例的mask具體如下圖所示。圖中標記為 0 的藍色 Token 表示當前步 t 輸入。橙色對應之前 t-3 步的結果、綠色對應之前 t-2 步的結果、紅色對應之前 t-1 步的結果。每個 Token 上的數字表示其相對當前輸入 Token(藍色 0 )的位置。
該掩碼遵循兩個規(guī)則:
- Lookahead Branch 與 Verification Branch 中的 tokens 互相不可見。舉例,對于verification序列的最后一個3,它只能看到輸入的藍色token 0,和它前面的天藍色2,3。
- 每個token只能看到它前面的token和它自己,就像causal mask那樣。例如綠色 token 5(圖上紫色圈)只對 紅色 token 6(圖上天藍色圈) 可見;而紅色token 6(圖上天藍色圈)只能看到綠色token 5(圖上紫色圈)和藍色token 1(圖上紅色圈),橙色token1~token4(圖上綠色圈)。
在每個時刻t,我們利用前N-1步軌跡,執(zhí)行Jaccobi迭代生成 window size=5個位置的token,從而得到同一位置的多組N-gram,例如在當前輸入token位置的藍0-綠1-紅2。在驗證時,首先通過字符串匹配識別出第一個token與最后一個 input token匹配的 n-gram,將識別到的n-gram添加當前輸入后,并通過 LLM 正向傳遞對其進行驗證,從而一次生成N個token。

3.2 推理
3.2.1 推理序列
推理序列指的是在解碼過程中,發(fā)送給LLM進行推理的batch。推理序列分為三部分:
- 輸入token。或者說是prompt,即當前已經生成的 token。
- lookahead序列。從2D window中提取出來的所有行構成了一個序列。
- verification序列。拼接的 guess tokens,就是從N-gram中提取出來G個匹配的序列拼接在一起。每個N-gram是\(g_i\),其中\(g_i^{k}\)代表了\(g_i\)的第k個位置的token。
總結來說,每個解碼步驟中最終輸入給模型的推理序列是:[輸入token, lookahead序列, verification序列]。但是具體實現上原始代碼和llama.cpp略有出入,下面會用二者的變量進行說明。
原始代碼
下圖是結合論文和原始源碼做的解讀。圖中最上方是論文的原始圖,下方是筆者基于論文原始圖做的二次注釋。

輸入
對于原始代碼,每個解碼步驟中最終輸入給模型的推理序列是 [input_ids, past_tokens, guess_tokens]。
- input_ids就是輸入token,形狀是(batch_size, sequence_length)。
- past_tokens是lookahead序列。
- guess_tokens是verification序列。
從代碼可知,最終拼接為 [input_ids, past_tokens, guess_tokens]進行推理。
for ll in range(fill_level + 1):
all_past += past_tokens[ll]
if guess_tokens is not None:
# 此處會拼接
input_ids = torch.cat((input_ids, torch.tensor(all_past + guess_tokens, device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)), dim=1)
guess_ids = list(range(lst_id + 1, lst_id + 1 + guess_size)) * (len(guess_tokens) // guess_size)
position_ids = torch.cat((position_ids, torch.tensor(ids_list + guess_ids, device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)), dim=1)
attention_mask = torch.cat((attention_mask, torch.ones(1, attn_size + len(guess_tokens), \
device=input_ids.device, dtype=input_ids.dtype)), dim=1)
前向傳播
執(zhí)行 forward 解碼會生成 output,output 包含以下內容:
- out_logits:即正常解碼步驟中輸出的下一個 token 的 logits;
- inp_logits:根據 2D Window 生成的 W 個 tokens 的 logits,這 W 個 tokens 會和 2D Window 拼接成 W 個 N-grams;
- guess_logits:如果有匹配的 guess tokens,則會生成這些 guess tokens 的 logits 以供驗證。
相應的縮減版代碼如下。
如何生成
lguess = len(guess_tokens)
ret.out_logits = ret.logits[:,prefill_size - 1,:].to(input_ids.device) #decode logits
if lguess > 0:
window = len(past_tokens[fill_level])
start = ret.logits.size(1)-window-lguess
end = ret.logits.size(1)-lguess
ret.inp_logits = ret.logits[:,start:end,:] #lookahead branch logits
ret.guess_logits = ret.logits[:,-lguess:,:] #verification branch logits
如何使用
next_token_logits = outputs.out_logits #outputs.logits[:, -1, :]
if past_tokens[1] is None:
past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist() #fill window with argmax
elif past_tokens[LEVEL - 2] is None:
past_tokens[fill_level + 1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()[1:] #fill window with argmax
else:
guess_logits = logits_warper(input_ids, outputs.guess_logits[0])
llama.cpp代碼
下圖是結合論文和llama.cpp的lookahead.cpp源碼做的解讀。圖中最上方是論文的原始圖,中間來自llama.cpp源碼的注釋,下方是筆者基于論文原始圖做的二次注釋。

從上圖中間部分紅色對Logits(代表要輸出logits)做的解讀可以看出來,llama.cpp最終給模型的輸入推理序列總共是\(W+ G+1\)個。針對上圖則具體為:
- 輸入tokens為一個序列:藍色0。
- lookahead共5個序列:藍色0+綠色1+紅色2;藍色0+橙色1+綠色2+紅色3;藍色0+橙色2+綠色3+紅色4;藍色0+橙色3+綠色4+紅色5;藍色0+橙色4+綠色5+紅色6;
- verification共兩個序列。第一個n-gram:藍色0+天藍色1+天藍色2+天藍色3。第二個n-gram:藍色0+天藍色1+天藍色2+天藍色3。
模型的輸入由三部分組成:[現有輸入token,2d window,guess token],prompt生成next token;2d window生成每個n-gram分支的next token;guess生成token,并且驗證。三部分可以并行執(zhí)行,互不干擾。
- 現有輸入token。
- 輸入tokens一個:藍色0,對應代碼圖的batch 0。
- 直接基于自回歸解碼生成當前序列的next token。該過程生成的next token會與guess token中的天藍色1進行比對。
- 2d window提取出的行。
- lookahead序列共5個序列:藍色0+綠色1+紅色2;藍色0+橙色1+綠色2+紅色3;藍色0+橙色2+綠色3+紅色4;藍色0+橙色3+綠色4+紅色5;藍色0+橙色4+綠色5+紅色6;對應代碼圖的batch 1~ batch 14。
- 上述序列對應的也是各個n-gram,不同n-gram序列分支生成不同的next token,從而生成新的n-gram組合。該過程的目的是維護和更新n-gram pool,和當前要驗證的tokens無關。
- n-gram,也叫guess token。
- verification序列第一個n-gram:藍色0+天藍色1+天藍色2+天藍色3。對應代碼圖的batch 15~batch 17。
- verification序列第二個n-gram:藍色0+天藍色1+天藍色2+天藍色3。對應代碼圖的batch 18~batch 20。
- 對于guess token,會和現有序列合并計算,生成各個位置的logits,用于和guess token進行逐個對比,滿足sampling要求的便可以加入到現有序列中。該過程的目的是驗證現有的n-gram pool中是否有符合要求的tokens,從而為現有的序列添加新的tokens。
另外,為了更好的說明。下圖是從llama.cpp中截取的注釋。對圖中的術語解讀如下。
- Batch:并行執(zhí)行的原始推理序列。數字代表token在原始推理序列中的位置。
- T:假如當前step是t,則0代表t-1個step,-1代表t-2個step。即0是上一時刻新生成的N-gram的最新token。
- Info:I 代表輸入token,L代表lookahead分支,V代表verify分支。
- Pos:用于掩碼設置。T會確定時間步順序,同一個T中的token由pos確定相對順序,因此每個token只能看到當前位置之前的掩碼。
- Logits:推理生成的logits。llama.cpp對原始代碼進行了優(yōu)化,實際進行推理的序列會比原始推理序列少。
- Seq:W+G+1=8,共有8個分支,每個分支是獨立推理,互相不干擾。Seq就是這8個分支對應的掩碼。
- j_tokens 和 id:具體代碼中的變量。

對應代碼如下,其中tokens_j是2d window。
// the input token belongs both to all sequences
std::vector<llama_seq_id> seq_id_all(W + G + 1);
for (int i = 0; i < W + G + 1; i++) {
seq_id_all[i] = i; // W+G個序列都有prompt
}
// current token - first token of the first level
// 輸入token對應的推理,共一個推理序列。n_past代表位置,輸入token需要輸出logits
common_batch_add(batch, id, n_past, seq_id_all, true);
// verification分支對應的推理
// verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
{
const int g_cur = ngrams_observed.cnt[id];
ngrams_cur.resize(g_cur);
for (int g = 0; g < g_cur; g++) {
ngrams_cur[g].active = true;
ngrams_cur[g].tokens.resize(N);
ngrams_cur[g].i_batch.resize(N);
ngrams_cur[g].seq_id = W + 1 + g;
ngrams_cur[g].i_batch[0] = 0;
ngrams_cur[g].tokens [0] = id;
}
// 一共最多G個推理
for (int j = 0; j < N - 1; j++) {
for (int g = 0; g < g_cur; g++) {
const int idx = id*(N - 1)*G + g*(N - 1);
const llama_token t = ngrams_observed.tokens[idx + j];
ngrams_cur[g].tokens [j + 1] = t;
ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
// 放到prompt后j+1處;對應的序列是第{ W + 1 + g }個;這些token需要輸出logits,對應上圖,就是天藍色的1,2,3都需要輸出logits
common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
}
}
}
// 依然是輸入token對應的推理,填補余下W-1給token。n_past + i代表位置
// fill the remaining W - 1 tokens for the first level
for (int i = 1; i < W; i++) {
seq_id_look.resize(W - i);
for (int j = 0; j < W - i; j++) {
seq_id_look[j] = i + j + 1;
}
// tokens_j[0][i]代表從2d window第一行提取W-1個token,塞到prompt后面1~i處,這些token不需要輸出logits,對應上圖,就是橙色和綠色不需要輸出logits;對應序列是 i + j + 1
common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
}
// lookahead分支對應的推理
// fill the rest of the levels
for (int j = 1; j < N - 1; j++) {
for (int i = 0; i < W; i++) {
// tokens_j[0][i]代表從2d window第j行提取W-1個token,塞到prompt后面(1~i)xj處,如果是第N-2行(N-2就是2d window的最后一行的下標)就需要輸出logits,對應序列是{ i + 1 }
common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
}
}
common_batch_add()代碼如下圖所示。
void common_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
batch.token [batch.n_tokens] = id; // 新加入token的id
batch.pos [batch.n_tokens] = pos; // 新加入token的位置
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i]; // 新加入token屬于哪個序列
}
batch.logits [batch.n_tokens] = logits; // 是否要輸出logits
batch.n_tokens++; // 本batch有多少個token
}
3.3 總體流程
一個 Decoding Step 中大概包含如下幾個步驟:
- Parallel Decoding:為lookahead分支中的每個位置生成一個token,即經過一次前向推理,生成候選 Token 對應的待驗證 Token 序列。
- Verify:使用上一步生成的待驗證 Token 與候選 Token 對比,確定最長的正確序列。n-gram的pool存儲了歷史所有的n-gram(實際選擇了G個),選取\(g_i^k\)的第1個位置\(g_i^1\)恰好是當前輸出序列最后一個token的所有n-gram。
- Collect N-Grams:從lookahead分支軌跡中收集并緩存新生成的n-gram。具體是使用未驗證通過的候選 Token 和對應生成的 Token 組成 N-Gram 序列,并添加到 N-Gram Pool 中。
- Update:用生成的待驗證 Token 序列更新候選序列(lookahead 分支),以保持固定的窗口大小,即2D窗口整體向右滑窗。
- Match N-Grams:使用候選序列中的 Token 依次從 N-Grams 中匹配對應 Token,并替換候選序列。
下圖給出了W=5、N=3和G=2的LOOKAHEAD解碼工作。

每個解碼步驟中做如下操作,每個解碼步驟的 input 應該包含:當前已經生成的 tokens、壓縮的 2D Windows、拼接的 guess tokens:
- 將 2D Window 中的 token 拼接到輸入中。
- 如果上一步生成的 token (最后一個token) 在 N-gram pool 中有匹配的 N-gram,將這些 guess tokens 拼接到輸入中。例如,如果上一步生成的 token 是 "機" ,N-gram pool是{"機": ["器學習","關槍!"]},則將以下內容拼接進輸入中:["器學習關槍!"]。
- 構造 Attention Mask,其特點是:每個 token 只對其 position index 大于自己的其他 tokens 可見;Lookahead Branch 與 Verification Branch 中的 tokens 互相不可見。
- 執(zhí)行 forward 解碼,生成 output,output 包含以下內容:
out_logits:即正常解碼步驟中輸出的下一個 token 的 logits;inp_logits:根據 2D Window 生成的 W 個 tokens 的 logits,這 W 個 tokens 會和 2D Window 拼接成 W 個 N-grams;guess_logits:如果有匹配的 guess tokens,則會生成這些 guess tokens 的 logits 以供驗證。
其算法如下。

3.4 初始化
在官方博客里沒有介紹初始的 Guess Tokens (圖中的 “who”、“is”、“the”、“he”、“just”、“great”)從哪里來,在 Github Issue 作者有解答。其中第一級可以從輸入 Prompt 中隨機選取,甚至可以從詞表中隨機選取,然后通過 (N - 2)次 warmup 就可以生成多級的 n-Gram Pool,甚至也可以全部隨機。作者選擇的是通 Prompt 對應的 Token 列表中隨機選擇第一級,然后 warmup 后幾級。由于 N 相對整個生成過程的 step 數來說比較小,所以一般經過幾次迭代之后就會變得有效。
3.4.1 warm up
在進行解碼之前,n-gram pool和2d window都是空的,要進行初始化。需要構造 N-gram pool 和填充 2D Window。讓我們繼續(xù)假設 W=5, N=4。
3.4.2 填充 2D Window
T=0時刻,3-gram的第一個位置從prompt中隨機采樣,3-gram 的第二個位置來自2-gram并行解碼prefill。隨后每個step會并行解碼3-gram的最后一個位置。并且每到下個step時,滾動3-gram位置(隨著解碼的進行,軌跡中最早的 Token 會被刪除,因此會丟掉第一個位置,留下后兩個位置作為并行解碼輸入)。N-gram的概念就是以此類推。具體操作如下。
- T=0 時刻,從 prompt 中隨機選取 W+N-3=6 個 tokens 填充 2D Window 的第一行,此時 2D Window 為,里面只有一行,假設是:
? E相對 prompt 最后一個 token 的偏移為1,F相對 prompt 最后一個 token 的偏移為2,以此類推。
-
T=1 時刻,將E,F,G,H,I填入到輸入中,假設輸入是A,則拼接之后是 AEFGHI,執(zhí)行一次 forward,除了得到正常解碼步驟的下一個 token B,還能得到EFGHI對應的下一組token,假設是KLMNO。
更新 2D Window,正常解碼得到的 B 會取代 E ,因此需要移除 E ;另外,需要用新生成的 tokens 填充 2D Window,最終得到的 2D Window 如下:
-
T=2 時刻,將F,G,H,I,J,K,L,M,N,O填入到輸入中則拼接之后是 ABFGHIJKLMNO,執(zhí)行一次 forward,除了得到正常解碼步驟的下一個 token C,還能得到FGHIJKLMNO對應的下一組token,假設是PQRST。
更新 2D Window:正常解碼得到的 C 會取代 F ,因此需要移除 F ;另外,需要用新生成的 tokens 填充 2D Window,最終得到的 2D Window 如下:
\[[G,H,I,J]\\ [K,L,M,N,O]\\ [P,Q,R,S,T] \]此時 2D Window 已經填充完畢。值得注意的是,初始時應將 2D Window 的第一行填充為 W+N-3 個。因為每填充一行,需要將之前每一行的第一個 token 移除;一共需要填充 N-2 次,填充完后第一行最終變?yōu)?W+N-3-(N-2)=W-1 個 tokens,其余行均為 W 個 tokens。N-2的意思是:一共N次,prompt算一次,隨機填充的算一次,剩下N-2次。
3.4.3 填充n-gram
如果設置了 POOL_FROM_PROMPT,則會從 prompt 中構造 N-gram pool??梢员闅v prompt,以當前 token t 為 key,在 list 中存儲以 t 開頭的 n-gram。假設 prompt 為"BOOK,BUS!" ,則 N-gram pool 中"B"對應的 value 為:{"B": ["OOK","US!"]}
如果沒有設置,則在N-2次推理之后,此時已經生成了2d window。需要用 輸入 + 2d window 單獨做一次前向傳播,即可以生成一份完整的n-gram,借此對pool進行初始化。因此,廣義的warm up包括 N-1 次前向傳播。
Lookahead Branch 需要 N?2 次前向傳播才能完全搭建好。在此之前, N-gram Pool 為空,此時是沒有 Verification Branch 的。
3.5 lookahead分支
lookahead分支維護一個固定大小的2維窗口,以根據雅可比迭代軌跡生成新的 n-gram。
具體來說,就是循環(huán)生成不同 fill_level 的 past_tokens;最終期望的形狀是 [WINDOW-1, WINDOW, ..., WINDOW],長度為 LEVEL-1。之所以只有 LEVEL-1 個而不是 LEVEL 個,是因為這 LEVEL-1 個是被用作輸入來考慮;decode 時,還有額外的一個 WINDOW 長度的 ids,合起來是 LEVEL 個,構成 LEVEL-gram。
以論文圖例來說,在當前步驟 t 中,使用之前步驟形成的軌跡進行一次 Jacobi 迭代,為所有 5 個位置生成新的 Token。然后收集 4-gram(例如,一個 4-gram 可以包括位置 1 的橙色、位置 2 的綠色、位置 3 的紅色 Token,以及當前 step 中新生成的黃色 Token 4)。隨著解碼的進行,軌跡中最早的 Token 會被刪除,以保持 N 和 W 的恒定。
3.6 verification分支
n-gram會保存在n-gram pool中,verification分支在n-gram pool中識別和確定有希望的 n-gram。在此分支中,作者使用 N-Gram 中的第一個 Token 來匹配輸入的最后一個 Token,這一步是通過簡單的字符串匹配來確定的。一旦識別,這些 n-gram將被添加到當前輸入token后,并通過 LLM的正向傳播對其進行驗證。隨著 N-Gram 緩存的增加,會有多個相同 Token 開頭的 N-Gram 出現,并且越來越頻繁,這增加了驗證成本。為了降低成本,作者將驗證分支中候選 N-Gram 的數量上限設置為 G。通常設置為與 W 成正比,比如 G=W。
如果有匹配的 guess tokens,進入 Verification Branch 以驗證是否接受 guess tokens。目前 Lookahead Decoding 支持 Sampling 和 Greedy Search。
Greedy Search 方案會根據 guess_logits驗證所有的備選 N-gram,作最長匹配即可。
Sample算法如下圖所示,我們概述如下:
- 假設有 K 個備選 N-grams 匹配當前步驟解碼的 token,且有
prob = Softmax(out_logits); - 沿著 N-gram 的維度遍歷這 K 個 N-grams,假設當前遍歷到 N-gram 的第 i 個位置,此時還剩下 k 個備選 N-gram,遍歷這 k 個備選的 N-grams:
- 取其第 i 個位置的 token \(t_i\) ;
- 采樣 r~U(0,1) ;
- 若$ r < prob(t_i)$,接受該 token,從備選集中移除所有第 i 個位置不是該 token 的 N-gram;同時更新
prob = Softmax(guess_logits[i]); - 否則,繼續(xù)驗證下一個 N-gram;
- 結束遍歷后,從
prob中采樣一個新的 token。

3.7 Prepare for next iteration
當前迭代結束之后,會為下一次迭代做好準備。具體是:
- 更新 2D Window。用后一層替代前一層(最后一層由最新輸出得到的logits填充),并根據被接受的 tokens 的數量截斷每一層。在當前的序列中隨機采樣填充被截斷的部分。
- 更新 n-gram。如何生成新的n-gram?其實就是就是在2d-window里面從上往下找。
- 更新下一次前向的 Attention Mask 和 KV Cache。假設接受了 k 個 tokens,就據此擴展 Attention Mask,并將這 k 個 tokens 的 cache 拼接到 KV Cache 上。
- 當滿足退出條件,例如生成的 tokens 長度達到
max_length時,返回結果。
3.7.1 原始代碼
原始代碼中會用后一層來更新前一層。
if past_tokens[1] is None: #filling multi-level window, the very first step is different
past_tokens[0] = past_tokens[0][1:]
past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
fill_level += 1
elif past_tokens[LEVEL - 2] is None: #filling multi-level window
for level in range(fill_level + 1):
past_tokens[level] = past_tokens[level][1:]
current_past_tokens = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
past_tokens[fill_level + 1] = current_past_tokens[1:]
fill_level += 1
else:
# 用后一層來更新前一層
if ALWAYS_FWD_ONE:
past_tokens[0] = past_tokens[1][1:]
for level in range(1, LEVEL - 2):
past_tokens[level] = past_tokens[level + 1][:]
past_tokens[LEVEL - 2] = new_results
else:
past_tokens[0] = past_tokens[1][1 + max_hit:]
for level in range(1, LEVEL - 2):
past_tokens[level] = past_tokens[level + 1][max_hit:]
past_tokens[LEVEL - 2] = new_results[max_hit:]
3.7.2 llama.cpp
llama.cpp會用logits更新最后一行。v是逐個校驗n-gram的循環(huán),此循環(huán)把對2d window的更新和對n-gram的校驗都封裝在一起。
// update lookahead tokens
{
for (int i = 0; i < W; i++) {
tokens_j_prev[i] = tokens_j[0][i];
}
// 用后一層來更新前一層
for (int j = 0; j < N - 2; j++) {
tokens_j[j] = tokens_j[j + 1];
}
if (v == 0) {
// sample from the last level
// 用logits更新最后一行。v是逐個校驗n-gram的循環(huán),此循環(huán)把對2d window的更新和對n-gram的校驗都封裝在一起
for (int i = 0; i < W; i++) {
tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
}
} else {
for (int i = 0; i < W; i++) {
// there are different ways to init these tokens
if (0) {
// random init
tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
} else {
// init from the previous level
tokens_j[N - 2][i] = tokens_j[0][i];
}
}
}
}
0xFF 參考
https://github.com/ggerganov/llama.cpp
Lookahead Decoding 圖文詳解 Sjrrr大蛇
Break the Sequential Dependency of LLM Inference Using Lookahead Decoding
hao-ai-lab/LookaheadDecoding - Github
萬字綜述 10+ 種 LLM 投機采樣推理加速方案 AI閑談
jacobi decoding 論文速讀 Bruce 仗劍走天涯
[2401.07851] Unlocking Efficiency in Large Language Model Inference: A Comprehensive Survey of Speculative Decoding
Lookahead Decoding & 6 種 LLM 加速解碼對比 AI閑談
https://lmsys.org/blog/2023-11-21-lookahead-decoding/
https://github.com/hao-ai-lab/LookaheadDecoding/issues/8
https://github.com/hao-ai-lab/LookaheadDecoding
https://jalammar.github.io/illustrated-gpt2/
https://jalammar.github.io/how-gpt3-works-visualizations-animations/
https://arxiv.org/abs/1811.03115
https://arxiv.org/abs/2211.17192
浙公網安備 33010602011771號