解密prompt系列54.Context Cache代碼示例和原理分析
Context Cache的使用幾乎已經是行業共識,目標是優化大模型首Token的推理延時,在多輪對話,超長System Prompt,超長結構化JSON和Few-shot等應用場景,是不可或缺的。這一章我們主要從原理、一些論文提出的優化項和VLLM開源項目入手,分析下context Cache的實現和適合場景。
重溫KV Cache
Context Cache的本質其實是KV Cache在多次請求之間的復用,所以我們先重溫下KV Cache的原理。
以下是Transformer的基礎模型結構由多層layer串聯構成,而KV緩存的是每一層用于Self-Attention計算的歷史序列的Key和Value取值,所以緩存向量維度是batch_size * seq_len * num_head * head_dim,(以下可視化來自LLM Visualization和transformer-explainer )

只所以對self-attention的KV進行緩存,因為在Transformer的眾多計算單元中只有self-attention是上下文依賴的,也就是在計算第k個token的輸出時,需要使用第k個token(Query)和前面K-1個token的(K&V)進行內機計算,使得self-attention的計算復雜度隨序列長度平方增長。而如果對歷史序列中的KV進行緩存后,每次生成新token只需要計算當前token,這樣時間復雜度就可以降為線性(O(n)),顯著降低計算量。

而其他Linear, FFN層的計算都是針對dim層,每個token的計算獨立,和歷史token無關和序列長度無關,因此也沒有緩存的必要。
當然KV Cache也不全是優點,雖然能顯著降低推理延時,但是會帶來較大的顯存占用,占用顯存和序列長度、模型層數成正比。
那現在眾多大模型API廠商都支持的Context Cache能力和KV cache有哪些區別呢?
首先從cache共享上,因為KV cache只用于單一序列的推理過程中,因此沒有任何共享問題,本次推理完成即釋放,而context cache會在多次推理請求之間共享,因此對于如何命中Cache,管理Cache, 提升Cache的存儲和使用效率,就需要更多的考慮。
其次從使用時機上,KV cache是用于首Token之后的增量預測(auto-regressive Phrase),而Context Cache是用于首Token之前的存量計算(Prompt Phrase)。在context cache出現之前這兩個階段其實有明確的劃分,計算prompt的階段需要對全部序列進行attention計算屬于數據計算密集的任務,而解碼階段因為kV Cache的存在更多是存儲密集型任務。因此Context Cache是面向首Token延時的優化方案。
KV Cache只是Context Cache的基礎使用形式,下面我們會分別就Contxt Cache的幾個核心問題包括命中率低,等討論一些優化方案
并行效率更高: Chunk Attention
- https://github.com/microsoft/chunk-attention/tree/main
- ChunkAttention: Efficient Self-Attention with Prefix-Aware KV Cache and Two-Phase Partition
微軟的Chunk Attention通過對KV Cache分段存儲,實現self-attention的并發計算,提升長序列self-attention計算效率,經評估自注意力內核的速度可以提高3.2-4.8倍。下面我們結合源碼簡單說下,源碼chunk attention是C語言寫的,這里簡單轉換成python pseudo
- Prefix Aware KV Cache
傳統的KV緩存以密集tensor形式存儲,大小是batch * head * seq length * dim,當多個序列共享相同的前綴(prompt)時這些kv緩存就是相同的,因此可以進行共享存儲,共享存儲的大小則是head * chunk length * dim,這里論文選擇使用前綴樹進行存儲。用前綴不用其他存儲方式,主要是為了保證位置編碼的可用性。這部分就不細說了,就是Trie樹,child節點存儲的是對應chunk 序列的kv tensor。而這里的分塊存儲,為后面推理階段并行self-attention的計算提供了可能性。
- Two-Phase Partition(TPP)
兩階段的推理算法則是如何更高效的使用以上共享前綴的KV Cache。第一步是共享前綴的部分,讀取前綴樹種存儲的chunk KV Cache,進行chunk粒度的并行self-attention計算并存儲多個partial attention。
然后剩余未命中前綴的部分,按序列計算剩余attention,之后通過attention合并把把一個序列的多段attention進行合并得到最終的attention.
以下是C轉換成python的偽代碼
# 偽代碼體現兩階段推理核心邏輯
class PartialAttention:
def __init__(self, chunk_size=64):
self.chunk_size = chunk_size # 分塊大小
self.scale = 1 / sqrt(d_head) # 縮放因子
def chunk_first_phase(self, Q_shared, K_shared, V_shared):
"""
分塊優先階段計算(對應論文Algorithm 1)
Q_shared: [n_shared_seqs, n_heads, d_head]
K_shared: [n_shared_seqs, chunk_size, d_head]
V_shared: [n_shared_seqs, chunk_size, d_head]
"""
# 步驟1: 計算分塊注意力得分
attn_scores = torch.einsum('bhd,bsd->bhs', Q_shared, K_shared) * self.scale
# 步驟2: 在線計算局部softmax
max_values = attn_scores.max(dim=-1, keepdim=True).values # [b, h, 1]
exp_scores = torch.exp(attn_scores - max_values) # 數值穩定
# 步驟3: 保存中間結果(對應論文公式1)
partial_attn = {
'exp_scores': exp_scores, # [b, h, s]
'max_values': max_values, # [b, h, 1]
'sum_exp': exp_scores.sum(-1), # [b, h]
'partial_v': torch.einsum('bhs,bsd->bhd', exp_scores, V_shared)
}
return partial_attn
def sequence_first_phase(self, Q_private, partial_results):
"""
序列優先階段(對應論文Algorithm 2)
Q_private: [1, n_heads, d_head] (單個序列的查詢)
"""
# 步驟1: 初始化累計變量
global_max = -float('inf')
global_sum = 0
merged_output = 0
# 步驟2: 合并所有分塊結果(對應論文公式2)
for partial in partial_results:
local_max = partial['max_values']
local_exp = partial['exp_scores']
local_sum = partial['sum_exp']
local_v = partial['partial_v']
# 調整指數差值
adjust_factor = torch.exp(local_max - global_max)
# 更新全局統計量
new_global_max = torch.max(global_max, local_max)
new_global_sum = global_sum * adjust_factor + local_sum * torch.exp(local_max - new_global_max)
# 更新輸出
merged_output = merged_output * adjust_factor + local_v * torch.exp(local_max - new_global_max)
# 保存新全局值
global_max = new_global_max
global_sum = new_global_sum
# 步驟3: 最終歸一化
return merged_output / global_sum.unsqueeze(-1)
而之所以attention可以通過局部計算再合并后效果和原始attention一致,不過是因為指數計算的變換,我們拋開論文里面不太好理解的pseudo code,看一個具體的case
- 步驟1:分塊計算中間結果
- 塊1(s?, s?):
- 計算局部最大值:m? = max(s?, s?)
- 調整后指數:e^{s? - m?}, e^
- 局部分母:sum? = e^{s? - m?} + e^
- 塊2(s?, s?):
- 計算局部最大值:m? = max(s?, s?)
- 調整后指數:e^{s? - m?}, e^
- 局部分母:sum? = e^{s? - m?} + e^
- 步驟2:合并中間結果
- 全局最大值:M = max(m?, m?)
- 調整因子:
- 塊1調整因子:α = e^
- 塊2調整因子:β = e^
- 合并的分母
total_sum = α * sum? + β * sum?
= e^{m? - M}*(e^{s? - m?} + e^{s? - m?}) + e^{m? - M}*(e^{s? - m?} + e^{s? - m?})
= e^{s? - M} + e^{s? - M} + e^{s? - M} + e^{s? - M}
- 合并后的attention計算結果
total_sum = α * sum? + β * sum?
= e^{m? - M}*(e^{s? - m?} + e^{s? - m?}) + e^{m? - M}*(e^{s? - m?} + e^{s? - m?})
= e^{s? - M} + e^{s? - M} + e^{s? - M} + e^{s? - M}
空間利用更高:Radix Attention
- SGLang Efficient Execution of Structured Language Model Programs.
- https://github.com/sgl-project/sglang

同樣是使用樹形存儲,SGLang使用了Radix Tree結合LRU的存儲策略(論文還有更多提高cache命中率之類的策略這里不予贅述)。所以其實核心就在Radix Tree和Prefix Tree的對比了。
像前綴樹每個節點只能是單個Token,而Radix支持可變長的Token列表,因此可以節省大量節點指針是空間效率更高的存儲模式,這里我們直接舉個例子
存儲以下鍵:"test", "team", "slow", "slowly"。
- Trie Tree
root
├─ t
│ ├─ e
│ │ ├─ s → t
│ │ └─ a → m
└─ s
└─ l → o → w
└─ l → y
- Radix Tree
root
├─ te
│ ├─ "st" → [leaf: "test"]
│ └─ "am" → [leaf: "team"]
└─ "slow"
└─ "ly" → [leaf: "slowly"]
整體上對比下ChunkAttention是固定長度的分塊,樹結構是靜態的,對于類似超長systemt prompt、multiple few-shot、超長contxt的場景更加合適,因為能通過多chunk實現self-attention的并發計算。而Radix Attention基于Radix Tree,支持動態可變長的前綴,自然更適合類似多輪對話,multi-step思維鏈推理等動態前綴場景。
顯式管理Cache:PML
- PROMPT CACHE: MODULAR ATTENTION REUSE FOR LOW-LATENCY INFERENCE

Prompt Cache則是提出了標記語言PML,支持用戶顯示標記輸入內容中,哪些內容是需要被cache的。這里包含3個主要機制
- 通過XML語言標記的cache模塊
如上圖所示,每個module本身就是一個存儲塊,被XML包裹的內容會分別進行KV cache存儲。下次在推理時會根據XML的tag進行對應Key,Value Tensor Cache的獲取。
- 不連續的位置編碼ID
但這就和前面的chunk Attention,Radix Attention有了顯著的差異,就是PML支持非連續位置的緩存。這也引出了一個問題,只要不是從頭開始的Prompt Cache,都會存在緩存cache的位置編碼和使用該cache所在位置不同的問題。
論文給出的解釋是他們經過實際測試,發現模型對非連續位置編碼有包容性,只要cache的的模塊內容部相對位置編碼正確即可,即便下一次Cache出現的位置和它被緩存的絕對位置不同,也不會影響推理效果。但個人認為這依賴用戶PML定義的極端合理化,也就是每個XML包裹的內容都是一個完整的語義內容,且彼此之間相對獨立。例如多個few-shot使用這個模式就是可以的,一個system prompt里面<requirement>和<output foramt>用這種方式應該也可以。但要是放在多輪對話場景,或者multi-step思考推理場景,我感覺會出現問題。
- 緩存模塊的拼接與新內容的計算
當一個用戶提示由多個模塊(有些是緩存的,有些是新的)組成時,Prompt Cache 會:
- 從緩存中檢索已緩存模塊的KV狀態
- 對于提示中未被緩存的新文本段(比如插入到兩個緩存模塊之間的文本,或者在緩存模塊之后的新文本),系統會根據它們在最終完整提示中的實際位置,為這些新文本段計算新的KV狀態和相應的位置ID。例如,如果一段新文本插入在起始位置為0的模塊A和起始位置為110的模塊C之間,且模塊A長度為50,那么這段新文本的位置ID將從50開始。
- 最后,系統將這些來自緩存的、帶有原始(可能不從0開始)位置ID的KV狀態,與新計算的、帶有新分配位置ID的KV狀態,按照它們在完整提示中的正確順序拼接起來 。
VLLM源碼分析
最后我們來直接看一個生產級別的源碼實現。以下是VLLM中KV Cache的整個調用鏈路。整個調用鏈路如下
- 初始化階段
- LLMEngine 初始化時調用 _initialize_kv_caches()
- Worker 通過 determine_num_available_blocks() 計算可用的 GPU 和 CPU 塊數
- ModelRunner 初始化 KV 緩存配置
- BlockPool 創建并管理 KVCacheBlock 對象
- 推理階段
- 客戶端發送推理請求
- LLMEngine 接收請求并轉發給 Worker
- Worker 調用 ModelRunner 執行模型推理
- ModelRunner 通過 KVCacheManager 獲取或分配 KV 緩存塊
- KVCacheManager 從 BlockPool 獲取新的緩存塊
- Attention 層使用 KV 緩存進行注意力計算

核心的KV Cache存儲在KVCacheBlock類中,使用了和前面Chunk Attention相同的塊存儲機制,這里通過hash所有前綴token+當前塊token得到block_hash
以下為cache存儲、生成cache的哈希ID部分的核心代碼。
@dataclass
class KVCacheBlock:
block_id: int
ref_cnt: int = 0
_block_hash: Optional[BlockHashType] = None
prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None
def hash_block_tokens(
hash_function: Callable,
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
"""計算塊的哈希值,用于前綴緩存"""
if not parent_block_hash:
parent_block_hash = NONE_HASH
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType(
hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys)
VLLM還使用了Radix Attention的LRU最少使用驅逐策略,通過ref_cnt的引用計數追蹤不使用的cache。同時當計數歸0,cache不會直接被釋放,而是會被添加到evictor,當內存壓力大時再進行釋放,保證更大程度的cache復用,避免頻繁地內存分配和釋放,相關驅逐代碼如下
def _decr_refcount_cached_block(self, block: Block) -> None:
block_id = block.block_id
assert block_id is not None
refcount = self._refcounter.decr(block_id)
if refcount > 0:
block.block_id = None
return
else:
assert refcount == 0
# 將塊添加到 evictor 而不是直接釋放
self.evictor.add(block_id, block.content_hash,
block.num_tokens_total,
self._block_tracker[block_id].last_accessed)
class LRUEvictor(Evictor):
def evict(self) -> Tuple[int, int]:
while self.priority_queue:
last_accessed, _, block_id, content_hash = heapq.heappop(
self.priority_queue)
if (block_id in self.free_table and
self.free_table[block_id].last_accessed == last_accessed):
self.free_table.pop(block_id)
return block_id, content_hash
當用戶的Query進來,會按照固定chunk大小進行切分,然后從左到右進行去尋找cache,實現最長前綴命中。不過這里命中都是整個Block命中,也就是不會像Radix Tree一樣支持Token級別的命中。命中前綴部分的代碼如下
def find_longest_cache_hit(self, block_hashes: list[BlockHashType], max_length: int) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
max_num_blocks = max_length // self.block_size
for i in range(max_num_blocks):
block_hash = block_hashes[i]
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
return computed_blocks
閉源 context Cache文檔
- https://ai.google.dev/gemini-api/docs/caching?hl=zh-cn&lang=python
- https://platform.openai.com/docs/guides/prompt-caching
- https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
想看更全的大模型論文·微調預訓練數據·開源框架·AIGC應用 >> DecryPrompt

Context Cache的使用幾乎已經是行業共識,目標是優化大模型首Token的推理延時,在多輪對話,超長System Prompt,超長結構化JSON和Few-shot等應用場景,是不可或缺的。這一章我們主要從原理、一些論文提出的優化項和VLLM開源項目入手,分析下context Cache的實現和適合場景。
浙公網安備 33010602011771號