解密Prompt系列8. 無需訓練讓LLM支持超長輸入:知識庫 & unlimiformer & PCW & NBCE
這一章我們聊聊有哪些方案可以不用微調直接讓大模型支持超長文本輸入,注意這里主要針對無限輸入場景。之前在BERT系列中我們就介紹過稀疏注意力和片段遞歸的一些長文本建模方案長文本建模 BigBird & Longformer & Reformer & Performer,不過以上方案無一例外都需要在訓練階段引入。針對當前大模型微調成本高的問題,更多研究放到如何在模型外部支持長文本輸入。先討論下為啥當前的大模型會在推理時存在輸入長度的限制,主要有以下幾點原因
-
Attention矩陣計算復雜度:在不引入稀疏注意力的條件下,Attention矩陣的內存和計算復雜度是\(O(序列長度^2)\),文本長度的上升會帶來顯存的指數增長。
-
訓練耗時:訓練階段的文本長度會顯著影響訓練速度, 因此2048一般是當前預訓練常見的最大長度。
- 位置編碼的外推性: 這里的外推性是指推理長度超過訓練長度。包括推理會出現沒訓練過的位置編碼,以及注意力機制需要處理比訓練更長的輸入。已有的旋轉位置編碼等相對位置編碼已經具有了外推性,既推理長度可以超過訓練長度,但在ALibi位置編碼的測試中,這種外推性是以大幅性能損失為代價的。
針對以上問題本章介紹4種方案:顯式搜索的知識庫外掛方案,隱式搜索的Unlimiformer, 并行輸入的pcw和并行解碼NBCE。
顯式搜索: 知識庫外掛
- paper: Unleashing Infinite-Length Input Capacity for Large-scale Language Models with Self-Controlled Memory System
- 看到最無敵的應用,文本和表格解析超厲害https://chatdoc.com/?viaurl=ainavpro.com
- ChatGPT代碼實現: https://github.com/arc53/DocsGPT
- ChatGLM代碼實現: https://github.com/imClumsyPanda/langchain-ChatGLM
- 適用于大規模知識問答場景
這塊可能是GPT后比較火的方向,有一陣每天都能看到類似的新應用,從GPT讀論文,再到百科問答,搭配langchain框架,在DocQA,KBQA的場景簡直無往不利, 以上分別給出了基于ChatGPT和ChatGLM的兩個實現方案。
實現的步驟基本可以被下圖概括
- 長文本解析切分成chunk: 實際使用過程中發現文本解析竟然是最核心的部分,能否把需要保留語義完整性的段落拆成整段,能否高質量的解析表格,和結構化數據,對后續QA的影響最大
- 文本向量化:中文可用的embedding模型有不少,也可以基于simcse,consert在垂直領域做進一步的微調。在向量化階段主要的問題是文本截斷帶來的上下文損失會影響召回,因此可以嘗試重疊切分,拼接摘要/標題等方式
- 向量入庫:需要高效向量檢索的數據庫,Milvus、Pinecone,這塊最近也火了一波初創公司
- 用戶問題改寫:在多輪QA的場景,對話歷史有兩種使用方式,其一使用歷史對話對當前query進行改寫再召回,其二種是使用原始用戶query去召回文本,在回復階段引入對話歷史
- 召回:基于用戶query或改寫query進行向量化檢索,topK或者閾值召回。除了考慮相關性,在部分場景也要考慮時效性,文本質量等等
- 答案生成:使用召回文檔拼接用戶query進行答案生成,這一步往往還需要用到模型摘要,Refine等能力,核心是對以上召回的長文本進行壓縮
搜索法最大的優點是實現簡單,不過也有許多限制就是只能支持NLU任務,以及會破壞輸入文本的上下文連續性,和文本順序。但在大規模知識問答這塊算是現在看到最好的方案。
隱式搜索:Unlimiformer
- Unlimiformer: Long-Range Transformers with Unlimited Length Input
- https://github.com/abertsch72/unlimiformer
- 適用于Encoder-Decoder模型,長文本摘要等場景
特意起了個隱式搜索的標題,是因為和上面的文本搜索實現有異曲同工之妙,本質的差異只是以上是離散文本塊的搜索。而Unlimiformer是在解碼階段對超長輸入,token粒度的輸出層embedding進行檢索,選擇最相關的Top Token計算Attention。
首先對于超長輸入,unlimiformr采用以上提到的重疊切分的方法,重疊率50%,這樣可以更好保留上文和文本連貫性,例如第一段文本是1-500字,第二段重疊250字取250-750字。然后使用Encoder對每段文本進行獨立編碼,繞過Attention的平方復雜度問題。最后輸出每段文本的Embedding,注意這里不是文本整體embedidng, 而是后半部分(250~500字)每個Token最上層的Embedding,并寫入向量索引,這里用的是Faiss。
在解碼層,每一步解碼,query都會檢索注意力最高的Top-k個輸入Token,作為編碼器部分的信息用于解碼器的解碼。這里簡單回憶下Attention計算, Top-K個Token就是讓以下注意力取值最高的key。
考慮Decoder的每一層(N層)中的每一個head(L個頭)都需要和Encoder的輸出層進行交互, 檢索Top Key,如果存儲每一層每個head的Key,需要構建\(O(L*N*seqlen)\)的向量存儲。對此作者進行了優化,改變了以下QK的計算順序,用每一層每個頭Key的映射矩陣對Q進行映射,這樣只需要存儲一份seq_len的編碼向量(\(h_{encoder}\)),在每一層檢索時用映射后的Q進行檢索既可,其實就是時間換空間
unlimiformer提供了代碼實現,核心代碼抽出來看下有兩塊
- 超長文本編碼:對文本進行切塊,分別編碼,取后半部分
for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
chunk = input_ids[:, context_start_ind:context_end_ind]
chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
hidden_states = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels, return_dict=True)
last_hidden = hidden_states.encoder_last_hidden_state # (batch, chunked_source_len, dim)
to_add = last_hidden[:, update_start_ind:update_end_ind].detach()
to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]
- 向前計算檢索Top-key用于Attention矩陣的計算
def attention_forward_hook(self, module, input, output):
# output: (batch, time, 3 * heads * attention_dim)
with torch.no_grad():
query = self.process_query(output)[:,-1] # (batch * beam, head, dim)
query = query[:, self.head_nums] # (batch * beam, head, dim)
#這是前面提到的計算優化使用每層每個head的Key映射矩陣對Query進行映射用于搜索
attention_layer_list = self.attention_layer_to_capture(self.layer_begin, self.layer_end)
k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index]
# modify query by k_projs
k_proj = k_proj_layer.weight
k_proj = k_proj.view(1, self.num_heads, query.shape[-1], k_proj.shape[0]) # (1, num_heads, attn_dim, embed_dim)
datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim)
datastore_query = torch.matmul(datastore_query, k_proj) # (batch * beam, num_heads, 1, embed_dim)
datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim)
datastore_query = datastore_query.view((self.datastore.batch_size, -1, datastore_query.shape[2])) # (batch, beam * num_heads, embed_dim)
# 這里進行Top Key的檢索:得到Key的索引,Embedding和得分
top_search_key_scores, top_search_key_indices = self.datastore.search(datastore_query, k=self.actual_model_window_size)
embeddings = torch.take_along_dim(input=self.embeddings.unsqueeze(1),
indices=top_search_key_indices.unsqueeze(-1).to(self.embeddings.device), dim=-2)
##后面就是常規的對Embedding進行Key和Value的映射然后做Attention了
和前面的文本檢索對比,unlimiformer的存儲成本會更高,因為要存儲token粒度的Embedding信息,更適用于on-the-fly的長文本推理使用,例如針對單一文檔的QA,只存儲當前文檔,而前面文本塊檢索方案更適合一些大規模知識,批量的文檔的存儲。
但其實unlimiformer直接對Token進行離散召回,這一點我讓我有些困惑,這樣單一token的檢索召回,真的不會破壞上文連續性么?還是說Encoder編碼方式已經保證了檢索召回大概率會召回成段的Token,又或者說每個Token的Embedding內已經充分編碼了連續上下文的信息,召回離散Token也不會出現割裂的語義信息?哈哈考慮unlimiformer只支持Encoder-Decoder的框架,和我們用的Decoder框架不適配,我決定不細糾結了!有在中文嘗試過效果的童鞋可以分享下~
并行輸入:PCW
- Parallel Context Windows for Large Language Models
- https://github.com/AI21Labs/Parallel-Context-Windows
- 適用于Decoder模型,以及小規模內容理解場景
同樣是對超長文本進行切塊,然后獨立編碼,PCW使用的是Decoder框架。和unlimiformer只使用Top-Key進行解碼,PCW在解碼過程中對全部輸入上文進行Attention。對比Encoder-Decoder框架,因為輸入和輸出都在Decoder側,PCW需要解決兩個問題:位置編碼和注意力矩陣如何調整, 下圖基本概括了這兩個細節
1. 位置編碼:輸入文本截斷后,每段文本的位置編碼相同。考慮所最長的文本長度為C,則輸入文本最大的位置編碼id是$P_C$,則解碼器第一個字的位置編碼id是$P_{C+1}$,然后順序向后編碼。其實就是丟棄了上文多段文本之間的位置關系,解碼時只知道上文多段文本都是在解碼器之前,但無法區分文本之間的位置。不過因為上文每段文本復用了相同的位置編碼,因此位置編碼的長度大幅降低,也就降低了對位置編碼外推性的需求。
position_ids = attention_mask.long().cumsum(-1) - 1
n_task_tokens = position_ids.shape[1] - sum_windows_size
# 保證解碼器的位置編碼比最長上文要長度+1
position_ids[0, -n_task_tokens:] = torch.arange(max_window_size, max_window_size + n_task_tokens, 1)
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values: # i.e., first token is already generated
position_ids = position_ids[:, -1].unsqueeze(-1)
elif windows_key_values: # i.e., we are in the first token generation #其實就是取-n_task_tokens:
position_ids = position_ids[:, sum_windows_size:]
- 注意力矩陣
- 輸入文本進行截斷后各自獨立通過Decoder進行編碼。因此每一段輸入的文本的注意力矩陣是相互獨立的。這塊不需要修改注意力矩陣的實現,只需要文本chunk后分別過模型即可。得到每段文本的past-key-values直接進行拼接
def combine_past_key_values(past_lst: List[Tuple[Tuple[torch.Tensor]]],
contains_bos_token: bool = True) -> Tuple[Tuple[torch.Tensor]]:
# 這里past_lst是每段文本的past-key-value
# GPT是n_layer * 2(key+value) * tensor(seq_len,batch,n_head,n_hidden)
# 注意不同模型past-key-value的shape不同
# Chatglm是n_layer * 2(key+value) * tensor(seq_len,batch, n_head, n_hidden)
return tuple(
(torch.cat([c[i][0] for c in past_lst], dim=2),
torch.cat([c[i][1] for c in past_lst], dim=2))
for i in range(len(past_lst[0])))
- 解碼器對全部上文進行Attention計算:這里需要修改Attention把上文的全部Attention進行拼接,讓解碼器的每一步可以對全部上文計算Attention
res['past_attention_mask'] = torch.cat([window['attention_mask'] for window in windows], dim=1)
combined_attention_mask = torch.cat((cache['past_attention_mask'], encoded_task_text['attention_mask']), dim=1)
考慮ChatGLM本身是二維的Attention矩陣和位置編碼,特殊的BOS和GMASK,我重寫了PCW,但是在長文本QA問題上表現比較一般,表現在當上文多段文本無明顯關系的時候例如多個完全無關的新聞,在進行問答的時候,正確答案中會混雜很多無關的文本變短,以及這個問題當上文片段變多,或者指令問題變多的時候會變得越來越嚴重,直到開始完全胡說八道。當然不排除我寫bug了哈哈哈,但我自己是真的沒查出來。
不過也有一種可能,是PCW是在輸入層就開始對超長上文進行Attention,因為不同上文的位置編碼相同,一定程度上會讓解碼注意力變得非常分散,導致注意力的熵值變高,解碼的不確定性變大,更容易出現亂碼。
并行解碼:NBCE
- 蘇劍林. (May. 23, 2023). 《NBCE:使用樸素貝葉斯擴展LLM的Context處理長度 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9617
- 蘇劍林. (May. 31, 2023). 《關于NBCE方法的一些補充說明和分析 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9632
- https://github.com/bojone/NBCE
- 適用于Encoder-Decoder模型,長文本內容理解如摘要問答等場景
壓軸的必須是蘇神的NBCE!這里我把看完博客后的理解進行簡單的總結,詳細推理請看去蘇神的科學空間!答應我一定要去看!每次看蘇神推導,都會覺得數學之魂在燃燒!
NBCE的原理簡單解釋如下圖,和PCW相同是對每段上文進行獨立編碼,但差異在于PCW是在輸入層進行融合,而NBCE是在輸出層對每一個Step輸出的預測token的概率矩陣進行融合,更大程度上避免了注意力被分散,保證了解碼的合理性。
這里我們簡單說下如何在輸出層進行融合,把找超長文本chunk成多段文本后($s_1,s_2,...s_k$),基于樸素貝葉斯的簡化假設, 基于多段文本進行并行解碼的預測概率可以簡化如下,也就是每段文本條件解碼概率之和減去無條件解碼概率 $$ log(P(T|s_1,..s_k)) = \sum_{i=1}^Klog(p(T|s_i)) -(n-1)log(p(T)) + const $$
既然說了是簡化假設,因此可以對上式進行一些調優,核心是讓模型對上文的解碼更加準確,降低無關上文帶來的解碼噪聲,比較重要的優化包括
- 準確率優化解碼
以上解碼概率求和,其實是對k段文本生成的\(vocab * K\)的概率矩陣,沿K做AvergePooling,得到最終\(vocab*1\)的解碼概率。但考慮LM訓練其實是擬合one-hot(出現概率最高的詞),也就是除了概率最高的幾個token之外其余token的預測概率都不靠譜。如果直接取平均的多路打分,很容易投出一個在各段文本上打分都不高不低的token,上文越多這個問題越明顯。但其實在閱讀理解例如抽取,QA問題的解碼策略上我們要的是在某段文本上打分置信度最高的token,因為答案往往只來自一個上文片段。
因此蘇神給出了兩種準確率更高的解碼方案,一個是MaxPooling+GreedySearch,其實就是對\(vocab*k\)的概率矩陣取全局概率最高的token,另一個是最小熵+RandomSampling,也就是從多段上文中取1個預測置信度最高的上文進行解碼。這里其實是和PCW最大的差異,也就是在解碼層進行融合,并通過熵值較低的融合策略來保證解碼的準確率。
以及后面蘇神還通過Top-P來進一步過濾尾部的噪聲,以及通過控制每一步解碼的轉移概率,來讓解碼器不會在不同上文片段之間反復切換,而是保證連續的解碼片段大概率來自相同的上文片段。
- Context-aware解碼
基于上文來進行解碼的一個核心是為了降低模型回答胡說八道的概率。例如在金融場景我們直接問chatgpt基金贖回費用是多少 vs 我們基于某個基金的介紹問模型該基金的贖回費用是多少,后者得到的答案一定是更準確的。而其實以上二者的差異在于條件(上文)解碼和無條件解碼, 因此可以通過diff無條件編碼的方式來提高解碼對上文的依賴程度(reliablity)。如下圖
因此蘇神把把n變成超參Beta, 控制條件概率和無條件概率的占比,Beta越高解碼和上文的關聯度越高,QA等場景的解碼準確率越高,生成自由度越低。
當前NBCE的局限性在于無法處理上文片段之間的位置關系,以及無法處理解碼需要依賴多個上文片段的場景。后者感覺可以通過預測概率矩陣的相關性修改Pooling方式,而前者
基于蘇神提供的代碼,在chatglm上做了嘗試,只需要簡單調整下輸入輸出的部分就可以直接使用。我在論文,書籍,和新聞上進行摘要,實體抽取和QA問答后發現,INT8量化的模型效果似乎要略優于FP16, 顯著優于INT4。INT8量化下,10K左右的輸入,顯存占用基本可以限制在單卡A100(40g),大家可以自行嘗試下~
@torch.inference_mode()
def generate(max_tokens):
device = torch.device('cuda')
"""Naive Bayes-based Context Extension 演示代碼
"""
inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
input_ids = inputs.input_ids
n = input_ids.shape[0]
with torch.no_grad():
for i in range(max_tokens):
# 模型輸出
model_input = model.prepare_inputs_for_generation(input_ids)
outputs = model(**model_input,
return_dict=True,
use_cache=True
)
"""
中間代碼不變
"""
# 把唯一的回答擴充到每一個batch進行下一輪的解碼
next_tokens = next_tokens.unsqueeze(-1).tile(n, 1)
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
# 更新past-key-values, 更新attention_mask, 更新position_ids
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
)
想看更全的大模型相關論文梳理·微調及預訓練數據和框架·AIGC應用,移步Github >> DecryptPropmt
Reference
- https://blog.langchain.dev/langchain-chat/
- https://blog.frankzhao.cn/build_gpt_bot_for_doc/
- https://zhuanlan.zhihu.com/p/616620170
- ALiBi:Train short, test long:attention with linear bias enables input length extrapolation
- https://github.com/ofirpress/attention_with_linear_biases
- Trusting Your Evidence: Hallucinate Less with Context-aware Decoding
未經許可請勿轉載哦~

這一章我們聊聊有哪些方案可以不用微調直接讓大模型支持超長文本輸入,分別介紹顯式搜索,unlimiformer隱式搜索,并行輸入的PCW,和并行解碼的NBCE方案
浙公網安備 33010602011771號