探秘Transformer系列之(28)--- DeepSeek MLA
探秘Transformer系列之(28)--- DeepSeek MLA
0x00 概述
MLA(Multi-head Latent Attention / 多頭潛在注意力)的基本思想是將注意力輸入\(h_t\) 壓縮成一個低維的潛在向量 \(c^{KV}_t\) ,維度為 \(d_c\),且 \(d_c\) 遠小于原始的維度(\(h_nd_h\))。在需要計算注意力時,可將這個潛在向量 \(c^{KV}_t\) 映射回高維空間。因此,只需要存儲潛在向量 \(c^{KV}_t\) ,就可以顯著減少內存的占用。
這個過程可以通過以下公式更正式地進行描述。其中 \(c^{KV}_t\) 表示潛在向量;\(W^{DKV}\) 是壓縮矩陣(上標 D 代表"下投影",即降維操作),負責將 \(h_t\) 的維度從(\(h_n·d_h\))壓縮到\(d_c\);\(W^{UK}\)和 \(W^{UV}\) 是上投影矩陣,負責將共享的潛在向量 \(c^{KV}_t\) 映射回高維空間。只需要存儲這個潛在向量 \(c^{KV}_t\) ,就能獲得對應不同文本特征的Key和Value,而不需要對每個文本特征都存儲對應的Key和Value。
類似地,我們也可以將查詢向量映射到一個潛在的低維向量,然后再將其映射回原始的高維空間。而且,MLA又結合了權重吸收技術,減少了計算開銷。

注:
- 全部文章列表在這里,估計最終在35篇左右,后續每發一篇文章,會修改此文章列表。cnblogs 探秘Transformer系列之文章列表
- 本系列是對論文、博客和代碼的學習和解讀,借鑒了很多網上朋友的文章,在此表示感謝,并且會在參考中列出。因為本系列參考文章實在太多,可能有漏給出處的現象。如果原作者或者其他朋友發現,還請指出,我在參考文獻中進行增補。
0x01 原理
1.1 問題
標準Transformer的一大障礙就是KV Cache的空間占用問題:多頭注意力機制需要為每個注意力頭單獨存儲歷史生成的Key和Value向量(即KV緩存)。隨著序列長度增加,KV緩存的存儲需求呈指數級增長,導致內存占用急劇上升。而GPU的顯存空間往往非常有限,較大的KV Cache會導致同時處理的request數量變少,也即batch size較??;為了減少KV緩存需求,研究人員提出了像Multi-Query Attention(MQA)和Group-Query Attention(GQA)這些方法。這些方法雖然降低了緩存要求,可模型的性能也受到影響。MQA或GQA算子在計算注意力的過程中,所有KV Cache中的數據讀取后都僅參與一次或幾次計算,導致該算子的MFU極低,并且由于每個request有自己的KV Cache,這一問題無法通過提高batch size的方式解決。
因此,如何減少推理過程的KV Cache,從而實現在更少的設備上推理更長的Context,或者在相同的Context長度下增大batch size,實現更快的推理速度或者更大的吞吐總量,最終降低推理成本。是一個關鍵問題。
1.2 當前狀況
我們首先總結下當前各種方案的情況來看看有沒有可以改進的空間。下圖給出了MHA、GQA、MQA 以及 MLA 做法。

圖上從左到右依次是 MHA、GQA、MQA 以及 MLA 。圖中有陰影的長條表示會緩存到顯存的結果。MHA、GQA、MQA 都需要將 KVCache 緩存到顯存。幾種方案特點如下。
- MHA:MHA KVCache 在注意力頭這個維度和 Q 矩陣一樣,屬于“一對一”。MHA把一個注意力計算拆成多個注意力頭,每個注意力頭使用獨立的Q、K、V進行計算,需要把K、V都存儲下來,KV Cache中每個token需要緩存的參數量為\(2??_???_???\)。而GQA、MQA 在注意力頭的維度比 Q 矩陣小。
- MQA:所有查詢頭共享相同的單一鍵和值頭,因此只需要存儲共享的K和V,KV Cache中每個token需要緩存的參數量為\(2d_hl\)。在計算注意力時,會把共享的單一K頭和V頭廣播給每個查詢頭,然后分別一一計算。
- GQA:將所有的Q頭分成g組,同一組的Q頭共享一個K頭和一個V頭,因此KV Cache中每個token需要緩存的參數量為\(2??_g??_???\)。在計算注意力時,會把KV頭復制給所在組的所有Q頭進行計算。
\(n_h\)是注意力頭數量,\(n_g\)是GQA分組數,\(d_h\)是隱藏層維度,\(l\)為模型層數,\(?_??\in??^??\) 表示第 ?? 個token在一個attention層的輸入。
1.3 改進思路
MLA是對MHA、GQA、MQA方案的改進,其思路是加強信息壓縮能力(對應下圖標號1)和豐富信息表達能力(對應下圖上標號2),其實,兩個標號也對應了從輸入到Q、K、V的數據流上兩個關鍵點,也是硬幣的兩面:增強了矩陣的表現能力的同時,也會使得壓縮能力更大。

于是就是研究人員經常遇到的困境了:既要壓縮更低(降低推理過程中的KV Cache資源開銷),又要表現力更強(緩解MQA、MGA對性能的損耗),或者說新方案的表現力要盡可能接近MHA。
1.3.1 增強信息壓縮能力
思路
從某個角度考慮,MQA和GQA 也屬于低秩壓縮的思想,MQA將 \(2n_?\) 壓縮到2,GQA 則壓縮到 \(2n_?/g\)。但是壓縮能力和性能難以兼顧,所以GQA效果要好于MQA。
因此我們要思考,是不是可以在“增強信息壓縮能力且兼顧效果”之上再進一步?因為MQA在KV頭上已經幾乎做到了極致,因此我們沒法從KV頭數量上做減少。那就勢必得從KV本身思考。目前,不管是GQA還是MQA,都需要緩存K、V兩個值,兩個值不一樣。那么,是否可以把兩個值合并為一個值?有沒有可能每個緩存的KV都比之前???從LoRA那里得到啟發,一個\(M \times N\)的矩陣可以近似成兩個\(M\times k\)和\(k \times N\)矩陣的乘積,如果我把一個K或者V矩陣拆成兩個小矩陣的乘積,就可以減少KV Cache的顯存占用。
方案
MLA的核心是對注意力鍵和值進行低秩聯合壓縮,以減少推理期間的鍵值(KV)緩存大小,從而提高推理效率。與 GQA、MQA 直接壓縮 KVCache 頭維度不同,MLA通過使用下投影矩陣 \(W^{DKV}\)將多個注意力頭的Key和Value投影到一個低維的共享潛在向量空間中,取代了傳統的逐頭存儲方式。
具體而言,MLA將KV矩陣轉換為低秩形式:將原矩陣表示為兩個較小矩陣(相當于潛向量)的乘積。具體而言,
- 對輸入矩陣的 HiddenState 會先做低秩轉換,將一個 Shape 為 [S,H] 的 HiddenState 壓縮到 Shape 為 [S,CH] 的潛在向量\(c_t^{KV}\),其中 CH?H 。H是token維度。
- 將壓縮后的KV向量\(c_t^{KV}\)作為 KVCache 存儲到顯存中,這樣就達到了降低 KV 大小的目的。在V2的論文中, \(K_t\) 的表達從 \(W^Kh_t\) 變為 \(W^{UK}W^{DKV}h_t\) , 原來緩存的是 \(W^Kh_t\),而現在緩存的是 \(K_t\) 的一部分 \(W^{DKV}h_t\)。
問題
但這有一個問題,如果單純的把一個大的K/V矩陣拆成2個小矩陣進行緩存,在推理的時候,還是需要計算出完整的K矩陣,這樣就失去了緩存的意義,畢竟緩存的意義就是減少計算。
問題就變成:有沒有一種方法,即能減少緩存大小,又不增加推理時候的計算?
1.3.2 豐富信息表達
思路
我們可以注意到,在MQA和GQA計算注意力時,只用到了簡單的廣播或者復制機制把KV頭復制給對應的Q頭進行計算。我們以GQA為例,GQA 目的是減少KV Cache占用,存儲的是KV,即\(C^{KV}\)。下面公式是如何得到k(這里省略了v的操作)。
- 首先它將向量對半分為兩份分別作為K、V。
- 然后每一份又均分為g份。
- 每一份復制h/g次,以此來“湊”夠h個Attention Head所需要的K、V。
這里的\(W^{UK}\)是一組簡單線性變化(比如簡單復制)的組合,其表現能力是有限的,所以其壓縮維度不大。
既然MQA和GQA的信息表達能力不強,那么我們是不是可以引入一個矩陣變換來替代這些簡單的線性變換操作(切分、復制)?比如通過針對每個 ?? 都去自適應學習,這樣就可以讓這一層的信息表達更加豐富。
方案
我們已經得到了潛在向量\(c_t^{KV}\),那么就可以在推理期間使用每個頭的上投影矩陣\(W^{UK}\)(用于“鍵”)和\(W^{UV}\)(用于“值”)從這個潛在向量中\(c_t^{KV}\)重建K和V。
具體而言,MLA 在 Decode 階段將:
- 加載壓縮的KVCache潛在向量 \(c_t^{KV}\)。
- 然后通過上投影矩陣\(W^{UK}\)和\(W^{UV}\)做兩個升秩轉換,分別轉換為 Shape 均為[S,H] 的 K、V 矩陣,即從潛在向量中恢復出每個頭的Key和Value(將這個潛在向量映射回高維空間)。上投影矩陣\(W^{UK}\)和\(W^{UV}\)做兩個升秩轉換起到的作用比GQA 的簡單線性變化(比如簡單復制)的組合要大得多。
- 進行 MHA 計算。這樣,MLA在推斷過程中僅緩存潛向量,而不緩存完整的鍵KV。
MLA的本質是對KV信息的有損壓縮,但MLA可以通過訓練學習到如何在提高存儲信息密度的同時盡可能保留關鍵細節。這規避了分組查詢注意力和多查詢注意力的查詢的信息損失,從而在降低KV緩存的前提下獲得更好的性能。從MLA算子的計算特征來看,同時解決了這兩方面的問題:
- 一方面,通過低秩壓縮大幅降低了推理過程中的KV Cache資源開銷。減少推理過程的KV Cache,從而實現在更少的設備上推理更長的Context,或者在相同的Context長度下增大batch size,實現更快的推理速度或者更大的吞吐總量,最終降低推理成本。
- 另一方面,MLA解壓縮后的多頭注意力機制能夠提供較高的計算強度(正比于 Head 數),有助于充分利用GPU的算力資源,緩解MQA、MGA對性能的損耗。MLA 通過低秩轉換方式壓縮 KVCache,從公式來看引入了額外的升秩轉換計算,并且需要存儲升秩轉換計算的激活值結果。但可以根據矩陣乘的交換率特性,將升秩轉換的矩陣乘權重和其他權重融合,然后在 attention kernel 直接完成 attention 計算,無需引入額外的計算開銷以及存儲開銷。
1.3.2 解決位置編碼沖突
然而,壓縮和RoPE位置編碼是沖突的,即矩陣吸收后的\(c_t^{KV}\)沒有了位置相關信息(原因:RoPE對key和query都是位置敏感的)。在這種情況下,只依靠\(c_t^{KV}\)來壓縮KV-Cache的路是行不通的,所以需要額外的信息來表達qk之間位置關系。為了走出這個困境,DeepSeek提出了一種折中的方法:使用\(W^{QR}\)和\(W^{KR}\)兩個矩陣來表征跟ROPE相關的特征提取,為q和k都增加一個額外的維度\(d^R_h\)來添加ROPE編碼,之前的\(d_h\)維度不使用ROPE編碼,總長度變為\(d_h+d_r\)。即,MLA采用了MQA的思想,構造了所有head共享的cache變量\(c_t^{KV}\)和 \(k_i^R\),這樣才大幅降低了KV Cache。其中 \(c_t^{KV}\)是參數低秩分解中Down處理后Up處理前的低維向量,而\(k_i^R\) 可視作是MQA版本的RoPE。
具體參見下圖。

1.4 架構圖 & 流程
作為對比,下圖給出了MHA的數學公式,對于每個token需要緩存\(2n_hd_hl\)個元素。如果是千問72B,則需要$2 \times 80 \times 64 $。在這里 \(??_{??,??}\),\(??_{??,??}\),\(??_{??,??}\) 都是用列向量表示。t是第t個token,j是迭代第1到t個token的序號,i是迭代head的序號。

下圖給出了MLA的架構圖,以及公式。

圖中,黃色區域公式主要是為了計算Q(即Attention中的Q矩陣)。綠色區域主要是為了計算K的位置不敏感部分。紫色區域是計算K的位置敏感部分;灰色是把K聚合起來;紅色是計算V。具體流程如下:
- 查詢(Q)的降維壓縮:輸入序列中的 t 個Token(\(h_t\))通過一個下投影矩陣\(W^{DQ}\)壓縮為壓縮潛在向量\(c_t^{Q}\)(其維度遠遠小于輸入Token的維度)。此處對應圖上標號37。
- 鍵(K)和值(V)的聯合壓縮:輸入序列中的第t個Token(\(h_t\))通過一個下投影矩陣\(W^{DKV}\)壓縮為壓縮潛在向量\(c_t^{KV}\)(其維度\(d_c\)遠遠小于輸入Token的維度d)。在推理階段,MLA僅需要緩存\(c_t^{KV}\),即KV緩存僅\(d_c \times l\)個元素,其中l為模型層數。此處對應圖上標號41。
- 解耦RoPE策略:為提高模型對序列中上下文信息的敏感性,MLA中應用了解耦旋轉位置編碼(RoPE)技術。因RoPE與低秩KV壓縮矩陣不兼容,故MLA引入額外的查詢向量\(q_t^R\)和共享鍵向量\(k_t^R\)來攜帶RoPE信息,避免了RoPE與低秩壓縮矩陣之間的耦合問題,解決了位置信息與推理效率之間的矛盾。此處大致對應圖上標號39和標號43。
- 恢復信息:進行注意力計算時,進行注意力計算時,\(c_t^{KV}\)分別通過上投影矩陣\(W^{UK}\)和\(W^{UV}\)還原出鍵和值,此處對應圖上標號42和45。每個注意力頭上的鍵再與攜帶了RoPE信息的共享鍵向量\(k_t^R\)拼接形成MHA的鍵值輸入,此處對應圖上標號44。\(c_t^{Q}\)通過上投影矩陣\(W^{UQ}\)和\(W^{QR}\)升維還原并生成查詢向量\(q_t^C\)(對應圖上標號38)和攜帶RoPE信息的查詢向量\(q_t^R\)(對應圖上標號39),二者拼接形成MHA的查詢向量輸入,此處對應圖上標號40。
- 注意力計算。此處對應圖上標號46。
- 最終多個頭的輸入拼接在一起,并經過線性映射\(W^O\)得到最終的輸出。此處對應圖上標號47。
從圖上可以看出MLA的特色:
從定性角度看,可以節約內存,因為:
- 在進入標準MHA算法之前,用壓縮的向量來替代之前更長的KV向量。之前是緩存K和V兩個向量,現在只存儲壓縮后的一個向量。
- 不僅僅壓縮了KV,而且還能重建成K和V(不是標準MHA下面的K和V)。
如果定量來可看,每個Transformer層,只緩存了上述公式藍框的向量: \(??_??^{????}\) 和 \(??_??^??\) ,其它的都可以利用“矩陣吸收”,重新恢復過來。 \(??_??^{????}\) 和 \(??_??^??\) 這兩個向量的大小分別為:
\(??_??^{????}\) : 維度為 \(??_??=4×??_?\)。\(d_h\)是單個頭的向量維度。 \(??_??^{????}\) 是多頭共享的。
\(??_??^??\) :維度為 \(??_?^??=??_?/2\)。\(??_??^??\) 是多頭共享的。
對比MQA(每層有一個\(??_?\) 維度的 ?? 和 一個 \(??_?\) 維度的 ?? ,共 2\(??_?\) 個元素),MLA相當于增加了2.25倍的存儲。對比MHA的\(2n_hd_h\),則\(n_h\)會大于2.25,所以肯定減少緩存。
1.5 代碼
下圖給出了DeepSeek V3源碼中MLA的定義部分。
class MLA(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
#隱藏層維度
self.dim = args.dim
# 注意力頭的總數量
self.n_heads = args.n_heads
# 計算每個并行進程的本地注意力頭數量
self.n_local_heads = args.n_heads // world_size
# 對應 query 壓縮后的隱向量的維度 d'_c
self.q_lora_rank = args.q_lora_rank # q的低秩壓縮的維度
# 對應 key-value 壓縮后的隱向量維度 d_c
self.kv_lora_rank = args.kv_lora_rank # kv的低秩壓縮的維度
# 表示query和key的向量中,不應用旋轉位置編碼部分的頭維度, $d_h$
self.qk_nope_head_dim = args.qk_nope_head_dim
# 對應$d_h^R$,表示應用了旋轉位置編碼的query和key的一個頭的維度。
self.qk_rope_head_dim = args.qk_rope_head_dim
# $d_h + d_h^R$, 注意力頭大小為非rope部分大小加上rope部分大小
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
# value 的一個注意力頭的隱藏層維度
self.v_head_dim = args.v_head_dim
if self.q_lora_rank == 0:
# 不適用低秩分解,回歸到傳統MHA
self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
else:
# 其實就是$W^{DQ}$,用來生成$c_t^Q$
# 下采樣矩陣,得到壓縮后的q向量
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
# $W^{UQ}$
# 上采樣矩陣,用來恢復q向量
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
# $ [W^{DKV}; W^{KR}] $
# 下采樣矩陣,得到壓縮后的kv向量
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
# 上采樣矩陣,用來恢復kv向量
# $ [W^{UK}; W^{UV}] $
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
# output權重矩陣
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
# 計算1/sqrt(d_k)
self.softmax_scale = self.qk_head_dim ** -0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
if attn_impl == "naive": # native模式下,kvcache存儲的是沒有壓縮的數據,大小為d_h + d_h^R, 不但沒有節省,反而增加了顯存消耗
self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
else:
# 在非native模式下,存儲的是壓縮的c,大小為d_c
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
很明顯,MLA算子是針對現代GPU硬件特點“量體裁衣”定制的一個注意力機制,通過對存儲和計算的再平衡,能夠充分發揮現代GPU的各項優勢。我們接下來就對MLA的幾個核心實現要點進行仔細分析。
0x02 核心要點
MLA的核心要點如下:
- 通過低秩KV聯合壓縮(Low-Rank Key-Value Joint Compression)降低了KV Cache的資源占用。在計算注意力時,對壓縮后的向量進行升維變換,進而增強模型的表達能力。
- 通過權重吸收減少了向上投影的計算量。
- 通過解耦RoPE策略(Decoupled Rotary Position Embedding)來解決RoPE和權重吸收的沖突。
2.1 低秩KV聯合壓縮
2.1.1 低秩分解
低秩矩陣分解(Low-Rank Matrix Factorization)是一種特別有效的矩陣分解方法,用于發現數據中的低維結構。低秩矩陣分解的核心思想是將一個大矩陣分解為兩個或多個更小、更簡單的矩陣的乘積,這些小矩陣通常具有更低的秩。
在神經網絡層中使用低秩分解一般都是用內存成本換取計算成本,這種方法的變體在LoRA微調等場景中很受歡迎,因為這些場景受限于總內存成本,而不是計算開銷或推理速度。其好處是壓縮后的矩陣使用的參數更少,并且在某種程度上更具表現力(層數增多)。它們最終可以大致近似或等同于一個更大的矩陣,因此在理論上,我們可以將這些矩陣的權重相乘,以恢復原始矩陣的近似值。
其缺點是,我們現在每次使用這些矩陣時都必須執行兩次操作(即,對于每個壓縮和解壓縮的層,我們將矩陣乘法的總數翻倍,以換取使它們變得更?。2⑶乙驗閷⑺鼈兿拗茷橹萺或更低的矩陣,顯然會損失原始矩陣的一部分表示能力。
2.1.2 思路
傳統的注意力機制直接將輸入X映射到QKV的注意力頭維度,MQA和GQA通過共享機制來變相壓縮KV Cache的頭維度。MLA的核心思想是采用類似LoRA的方式表示KV。具體而言,是在prefill期間構建一個壓縮空間,對輸入矩陣的HiddenSize 維度進行壓縮。即先將輸入X映射到隱向量c存儲起來。簡單理解就是,假設有個矩陣的維度是?????,那么可以將其分解為兩個?????的矩陣相乘,而?????。這樣就降低了存儲量。在decode階段計算注意力之前,會通過上投影矩陣將c恢復到QKV的原始維度,這樣可以減少注意力鍵(keys)和值(values)在推理過程中的緩存,從而提高推理效率。
其實這里還有一個問題:按照這種低秩方案,傳統意義上的\(W^Q, W^K, W^V\)全部變成了低秩矩陣。既然存在了低秩矩陣對滿秩矩陣的替換,就可能存在性能問題。既然DeepSeek做了替換且效果不錯,就說明\(W^Q, W^K, W^V\)這幾個原來的滿秩矩陣可能就是冗余的,具備較大的低秩特性。

在實現過程中,Q、K 和 V 的權重矩陣通常會進行融合以提高 GPU 的計算和內存效率。與分別進行獨立投影不同,采用組合權重矩陣可以優化運算過程。
2.1.3 向下投影

上圖給出了向下投影的具體流程,其中 \(?_t\) 作為輸入向量, \(W^{DKV}\) 和\(W^{DQ}\)為壓縮矩陣,用于降維, \(c_t^{KV}\) 和 \(c_t^Q\) 分別是壓縮后的KV潛向量和Q潛向量(潛向量的維度遠小于輸入向量的維度與自注意力頭數之積)。這個 \(c_t^{KV}\) 是和具體的哪個head(索引為i)無關的,需要被緩存,相當于說,我們不再直接緩存key/value這兩個維度和\(?_??\) 一樣的向量,而是緩存 \(c_t^{KV}\) ,并通過計算來動態的恢復 \(k_??\)和 \(v_t\) 。
- 對于KV,構建一個共享的降維映射矩陣\(??^{??????}\)用來對模型輸入進行降維。
- \(??^{??????}\)會將輸入\(?_??\)(hidden state)投射到隱向量\(??_??^{????}\),這是key和value的聯合隱向量。即將一個 Shape 為 [S,H] 的 HiddenState 壓縮到 Shape 為 \([S,d_c]\),其中 \(??_??^{????}\)的維度\(??_??\)遠小于多頭key和value的原始維度\(d_h\)。MLA 不保留完整的隱藏維度,而是縮小了它們的尺寸。
- 將壓縮后的KV向量作為 KVCache 存儲到顯存中。推理的過程中只需要緩存每一層的隱向量\(??_??^{????}\)(因為每一層的注意力頭共享該參數)。由于\(??_??^{????}\)的維度遠小于K、V。因此在MLA中,每一步token推理產生的KV Cache參數由之前的\(2??_???_???\)變成\(??_????\),從而大大減少 KV 緩存的內存占用。
- 對于Q,使用降維映射矩陣\(??^{??Q}\)用來對模型輸入進行降維。這與減少KV Cache無關,主要是為了減少訓練期間參數量和相應激活所占的顯存。這是因為在模型訓練的過程中,每個輸入的token會通過多頭注意力機制生成對應的query、key和value。這些中間數據的維度往往非常高,因此占用的內存量也相應很大。
2.1.4 向上投影
當 Decode 階段需要進行 MHA 時,會將加載KVCache,然后利用\(??^{????}\)和\(W^{UV}\)對\(??_??^{????}\)向上投影以恢復更大的尺寸。這個更大的尺寸既可以與原始輸入 \(h_t\)的維度匹配,也可以根據注意力頭的配置來調整。DeepSeek是將KV的維度擴展回\(??=??_???_?\),從圖上也可知,新的\(k_t^C,v_t^C\)分別被均分為\(n_h\)個向量,即每個注意力頭有一個單獨的 ??,?? (跟MHA的KV數量一致)。
具體參見下圖。 \(W^{UK}\) 和 \(W^{UV}\) \(W^{UQ}\)均為投影矩陣,用于升維。注:此處忽略了RoPE,后續會結合RoPE再進行擴充和更新。

結合向下投影和向上投影,我們可以看到, \(W^Q,W^K,W^V\) 的矩陣實際上分別被拆分成了兩個,做成了lora的形式進行信息壓縮,這個形式下MLA就是MQA加上lora形式的擴展,并且計算量從dxd的復雜度減少為 2 x d x c。這種信息壓縮后再恢復原維度的方式相比于之前只有一個矩陣的形式,能很好的幫助網絡進一步學習到更有效的信息。實現了同樣的低秩分解下更好的效果,這就是MLA比GQA更進一步壓縮KV Cache的根本原因。
下圖給出了如何拆分,上方和MLA,下方是作為比對的MQA。

實際上,論文“TransMLA: Multi-Head Latent Attention Is All You Need"就對MLA的表達能力做了相關分析。論文指出傳統的GQA模型在計算注意力的時候,同一組里頭的頭都會共享相同的鍵值對,這就導致它在表達能力上有點受限。而MLA就不一樣啦,它通過低秩分解,再加上獨特的投影矩陣設計,突破了這個限制。
具體參加下圖,在MLA里,就\(W_K^b\)拿來說,如果這里面的向量是正交的,那么每個通道在乘以\(XW_k^a\)之后,輸出在各個通道間都不一樣??蒅QA呢,同一組里所有頭的輸出都是相同的。就因為這種結構上的差別,在KV緩存大小一樣的情況下,MLA的表達能力更強。說白了,MLA通過調整網絡結構,優化參數更新策略,讓注意力計算過程更高效,這樣就能更好地捕捉復雜的語義關系,提升模型的能力。

2.1.5 完整過程
完整的對比過程如下圖。圖中上方是總體思路。下方是MLA和GQA的對比,其中又分為兩部分,上部分是通過公式看看MLA如何增強表現力;下半部分是完整的流程。

2.2 權重吸收
2.2.1 當前狀態
我們目前已經通過向下投影將壓縮的隱向量進行保存,這減少了KV Cache的內存占用。也通過向上投影矩陣增強了表達能力。然而,MLA強調的是激活值也隨之減少。當前我們還看不出來怎么減少激活值的數量的。因為雖然壓縮之后的KV占據內存比較少,但是在每次推理的時候,都必須通過 \(??^{????},??^{????}\) 來根據緩存的\(c_t^{KV}\)重新計算出 \(??_{??,??},??_{??,??}\),單從KV的數量和維度上看跟MHA是一個量級,比GQA和MQA都要多,上采樣后的 kv cache巨大,可能導致OOM。不但內存不少(\(??_{??,??},??_{??,??}\)依然存在),還引入了新的計算量,會處于計算瓶頸。沒有達到用時間和算力來換取空間的目的。
2.2.2 權重吸收
既然每次計算量太大,DeepSeek就想是否可以在保存壓縮的隱向量的基礎上來減少這個計算量(其實也減少了新KV的內存占用),于是他們給出了權重吸收這個法寶。即其作者利用矩陣結合律的特性對這些公式進行了優化,避免了針對每個query重新計算key與value,下面是文章中的原文:

備注:矩陣吸收計算是指利用矩陣乘法的結合律或低秩分解等線性代數技巧,改變矩陣的乘法順序,重新組合某些矩陣因子,使原本需要獨立計算的矩陣乘積合并在一起,避免生成大矩陣,從而降低計算復雜度和內存開銷的過程。
比如,給定三個矩陣 \(A \in R^{m,k}\), \(B \in R^{k,p}\), \(C \in R^{p,n}\),通過矩陣乘法的可知\((A \times B) \times C = A \times (B \times C)\),但是二者的計算復雜度是不一樣的。 \((A \times B) \times C\)的計算復雜度是 $2\times m\times k\times p+2\times m\times p\times n=2\times m\times p\times (k+n) $, \(A \times (B \times C)\) 的計算復雜度是\(2\times m\times k\times n+2\times k\times p\times n=2\times n\times k\times (m+p)\)。當 n 相比 m 和 p 都顯著更小的時候,第二種計算順序的性能會遠好于第一種。假設 ,m=k=p=4096,n=1 ,那么第一種計算順序的計算復雜度是 \(2\times 4096\times 4096\times 4097\),第二種方式的計算復雜度是 \(2\times1\times4096\times8192\),顯著低于第一種。
但是,具體要如何用矩陣吸收,如何使用矩陣乘法結合律,需要權衡計算量,memory讀寫量和瓶頸,可以套用典型的Roofline Model進行分析。這里的核心就是 AC x CB 矩陣的最終效果和 AB 矩陣的效果有多少差異。
2.2.3 推導
KQ合并
我們來結合Dot-Attention的具體形式,看看如何通過一個簡單但不失巧妙的恒等變換來規避這個問題。首先,在訓練階段還是照常進行,此時優化空間不大;然后,在推理階段,我們利用如下公式(不帶位置編碼)可以看到,在推理階段,我們把\({W^{UQ}}^{\top} W^{UK}\) 合并為一個(和位置無關的)矩陣W作為Q的投影矩陣,就可以用\(c_t\)代替原本的\(k_t\)。這樣就避免了重復計算中間結果q和k。

其中轉置 ? 表示交換張量形狀中最后兩個維度。各個張量的形狀如下,這里注意 num_heads 要拎出來成為一個維度,因為最后 attention weight 的結果是頭間獨立的。
-
\(C^Q : [batch\_size,1,q\_len, q\_lora\_rank]\)。
-
\(W^{UQ}:[num\_heads, q\_lora\_rank, qk\_nope\_head\_num]\)。
-
\(W^{UK}:[num\_heads, kv\_lora\_rank, qk\_nope\_head\_num]\)。
-
\(C^K : [batch\_size,1,kv\_len, kv\_lora\_rank]\)
我們每次緩存的 \(c_t^{KV}\)都可以直接參與計算,而不需要顯式的計算出K。而且,W 矩陣是可以事先就通過\({W^{UQ}}^{\top} W^{UK}\)計算出來的(其實就是被神經網絡自動計算出來)。
代碼表述如下:
"""來源:https://mathmach.com/8b428574/"""
# 消融W_UK
W_UQ = tf.reshape(W_UQ, [q_lora_dim, num_head, head_dim])
W_UQ = tf.transpose(W_UQ, perm=[1, 0, 2]) # [num_head, q_lora_dim, head_dim]
W_UK = tf.reshape(W_UK, [kv_lora_dim, num_head, head_dim])
W_UK = tf.transpose(W_UK, perm=[1, 2, 0]) # [num_head, head_dim, kv_lora_dim]
W_UQUK = W_UQ * W_UK # [num_head, q_lora_dim, kv_lora_dim]
# 計算qk內積
c_Q = tf.reshape(c_Q, [batch_size, q_seq_len, q_lora_dim])
c_KV = tf.reshape(c_KV, [batch_size, kv_seq_len, kv_lora_dim])
c_KV = tf.transpose(c_KV, perm=[0, 2, 1]) # [batch_size, kv_lora_dim, kv_seq_len]
c_Q_product_W_UQUK = tf.einsum('bij,hjk->bhik', c_Q, W_UQUK) # [batch_size, num_head, q_seq_len, kv_lora_dim]
q_product_k = tf.einsum('bhik,bkj->bhij', c_Q_product_W_UQUK, c_KV) # [batch_size, num_head, q_seq_len, kv_seq_len]
VO合并
另外,傳統方法需要先計算 Value 向量 \(v_t^C\) ,然后再進行注意力計算并投影到最終的輸出層。我們可以直接將 \(W^{UV}\)吸收到 \(W^{O}\)里,簡化最終的輸出計算。吸收公式如下(此處提取劇透了rope和nope分離的模式):
可以用代碼描述為:
q_pe = W_QR(c_q)
q_nope = W_UQ_UK(c_q)
output = W_UV_O(MQA(q_pe, q_nope, c_kv, k_pe))
注意我們需要小心的通過轉置等手段保證數學上的恒等。參見下面圖,每個注意力頭都可以消融成一個矩陣,因此,實際代碼中可以使用高維矩陣將所有head消融在一個矩陣里,代碼表述見下面。

代碼表述:
"""來源:https://mathmach.com/8b428574/"""
# 消融W_UV
W_O = tf.reshape(W_O, [num_head, head_dim, hidden_dim])
W_UV = tf.reshape(W_UV, [kv_lora_dim, num_head, head_dim])
W_UV = tf.transpose(W_UV, perm=[1, 0, 2]) # [num_head, kv_lora_dim, head_dim]
W_OUV = W_UV * W_O # [num_head, kv_lora_dim, hidden_dim]
# 計算u
q_R = RoPE(c_Q * W_QR) # [batch_size, q_seq_len, num_head, rope_dim]
k_R = RoPE(h * W_KR) # [batch_size, kv_seq_len, rope_dim]
q_product_k_rope = tf.einsum('bijk,bhk->bijh', q_R, k_R) # [batch_size, q_seq_len, num_head, kv_seq_len]
q_product_k_rope = tf.transpose(q_product_k_rope, perm=[0, 2, 1, 3]) # [batch_size, num_head, q_seq_len, kv_seq_len]
attention_weight = tf.softmax((q_product_k + rope_score) / tf.sqrt(head_dim + rope_dim)) # [batch_size, num_head, q_seq_len, kv_seq_len]
c_KV = tf.transpose(c_KV, perm=[0, 2, 1]) # [batch_size, kv_lora_dim, kv_seq_len]
attention_weight_product_c_KV = tf.einsum('bijk,bhk->bijh', attention_weight, c_KV) # [batch_size, num_head, q_seq_len, kv_lora_dim]
u = tf.einsum('bijh,ihd->bjd', attention_weight_product_c_KV, W_OUV) # [batch_size, q_seq_len, hidden_dim]
結合
把目前的合并結合起來,我們得到如下:
這樣,在推理時期\(W^{UK}\)可以和\(W^{UQ}.W^{DQ}\)結合,\(W^{UV}\)和\(W^{O}\)結合,最終只有\(W^Q\)和\(W^O\)。矩陣合并以后,對KV的整個計算過程都在低維空間進行,不會出現再把\(C^{KV}\)解壓縮回高維空間的情況。 而且,上述矩陣全都是模型的權重,再推理過程重是不會變的,可以看作常量。如果是部署推理服務的話,再加載模型的時候就可以把這兩個矩陣乘好,為以后的每次推理節省兩次矩陣乘法。實際上并無額外的算力開銷。MLA就達到了克服以往方法中KV Cache過大的問題并且保留的KV Cache該有的減少重復計算的功能。
2.2.4 討論
訓練
論文中一直提到在推理階段使用權重吸收,這點很好理解,因為此時權重矩陣固定了。
那么什么不在訓練階段直接結合\(W^{UK}\)和\(W^{UV}\),其原因大致如下:
- 從梯度更新的角度來看,不做權重吸收會使得優化更加簡單,即遵從下面的方式進行訓練更好\(\nabla (\phi \psi) =\psi \nabla (\phi ) + \phi\nabla ( \psi)\)。
- 從投影的角度來看,KV共享\(W^{DKV}\)某種意義上對于空間構成了一種約束,Weight Tying 使得模型能夠更好的收斂,并且提高其泛化能力,還可以提高模型的穩定性。
所以,MLA在訓練階段和MHA類似。除了多一步低秩投影以及只在部分維度加RoPE之外,MLA與Q、K的頭維度由\(d_k\) 換成\(d_k+d_R\)的MHA一樣。
MHA
其次,既然權重吸收這么好,為什么MHA沒有做權重吸收?
我們先看看推理階段的特點。
首先,MHA中的計算公式如下(為了演示方便,這里先討論單頭),在標準的MHA實現中,quey、key、value的embedding是分別計算的,然后通過query embedding和key embedding來計算self-attention的權重矩陣,之后將這個權重矩陣和value embedding進行相乘得到最終的結果。但是如果我們展開公式如下。
此處看起來,\((W^Q)^TW^K\)和\(W^VW^O\)都有吸收的可能。
其次,Decode 計算時,輸入的 Q 往往只有一個 token,這就天然給我們一個簡化計算的機會。即這個順序是可以交換的,即從query的embedding出發,一直向下進行計算,得到最終的結果。因為首先將比較小的query embedding參與計算,因此看起來整體計算復雜度會明顯降低。而且看起來和MLA的思路非常類似,即將
K 的 projection 放到 Q projection 之后,將 V projection 放到 attention 之后,output projection 之前。

目前看起來MHA做矩陣吸收的好處頗多。然而,事實并非如下簡單。我們通過\(q_t^Tk_i\) 為例來進行分析為何MHA不適合吸收,以及為何MLA可以提高效率。
對于單個頭,\(n_h\)=1,對應矩陣乘是\([1,d] \times [d, d_h] \times [d_h, d] \times [d, 1]\)。我們來看看這個矩陣乘哪些可以計算,哪些可以存儲。有以下幾種可能:
- 標準KV Cache。
- 存儲角度:我們把\(W^Kh_i\)存儲起來,就是存儲k(v和k一致),則KV Cache大小為:\(2n_hd_hl\).
- 計算角度:每個頭實例化參數是\(W^Q\),\(W^K\),\(W^V\),\(W^O\),大小為\(4dd_h\)。
- 把\((W^Q)^TW^K\)結合到一起,并把結合后的權重施加到x上
- 存儲角度:存儲\((W^Q)^TW^Kh_i\)作為新的cache,其大小為\(2dn_hl\),與KV Cache相比擴大了\(n_h\)倍。
- 計算角度:每個頭實例化參數是\((W^Q)^TW^K\) 和 \(W^VW^O\)。大小為\(2d^2\)。
- 把\((W^Q)^TW^K\)結合到一起,但是只cache x,不cache k和v的權重。
- 存儲角度,需要存儲的cache大小是\(dl\),相比標準kv cache減少了一半;
- 計算角度,每個頭實例化的參數為\((W^Q)^TW^K\) 和 \(W^VW^O\)。大小為 \(2d^2\)
結合上面的分析,標準的kv cache已經相對而言在空間開銷上和計算上是最優的了,盡管我們可以通過只 cache x減少一半的kv cache,但是結合后的矩陣放到運行時計算也增加了計算量,權衡下并不是好的方案。
我們再來看看MLA。\(W^K\)做了低秩變化后,從\([d_h,d]\)變成了\([d_h,r] \times [r,d]\), $ h_tT(WQ)TWKh_i\(變成了\) h_tT(WQ)TWW^{DKV}h_i$。
對應矩陣乘是\([1,d] \times [d,d_h] \times [d_h, d_c] \times [d_c, d] \times [d, 1]\)。我們來看看這個矩陣乘哪些可以計算,哪些可以存儲。 ,那么有以下幾種可能:
- 從存儲的角度:此時存儲的kv cache就是 \(W^{DKV}h_i\), cache大小是 \(d_cl\) ,加上旋轉位置編碼的部分,總的kv cache是$ (d_c+d_h^R)l$ ,和MHA進行比較,則是$ (d_c+d_h^R)/2d$ =(512+64)/(2?5120) =5.58%
- 從計算的角度: \(W^{UK}\) 可以被合并(merge)到 \(W^Q\) 中,類似地,\(W^{UV}\) 可以被合并(merge)到 \(W^O\)中。這樣實例化的權重就變成了原來的 \(d/r\) 分之一
- 無論是存儲還是計算的角度,MLA的拆分方法都優于MHA。
所以到這里我們就明白了,MLA的好處來源于兩個方面,一個是kv cache的顯著降低,另一個是權重的合并和吸收。
不合并
具體實施過程中需要依據實際情況進行抉擇,比如 李偉華大神 在https://developnotes.readthedocs.io/zh-cn/latest/deepseek.html#id1 有精彩論述。
考慮如下運算:\(Y=XAB,C=AB\)。其中\(X \in R^{m\times d}\)是輸入的hidden states,\(A \in R^{d\times d_c}\)和\(B \in R^{d_c \times n}\)是權重矩陣,\(C\in R^{d \times n}\)是吸收后的矩陣。
直接計算\(Y=XAB\)的flops是 \(2mdd_c + 2mnd_c = 2md_c(d+n)\),合并后計算\(C=AB\)的flops是\(2mdn\)。如果\(d_c\)較小,則\(dn \gt d_c(d+n)\),計算量太大,所以不一定需要進行權重吸收。
或者我們使用MLA的實際代碼來看。已知配置如下:
"hidden_size": 5120, # 隱藏層的大小
"kv_lora_rank": 512, # KV壓縮維度
"q_lora_rank": 1536, # Query壓縮維度
"qk_rope_head_dim": 64, # 解耦Query和Key的每個頭部維度
"qk_nope_head_dim":128 #
兩種情況的計算量如下:
- \({c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK}\)的計算量是:\(2 \times (q\_lora\_rank \times hidden\_size \times qk\_nope\_head\_dim + kv\_lora\_rank \times hidden\_size \times qk\_nope\_head\_dim) = \\ 2 \times hidden\_size \times qk\_nope\_head\_dim(q\_lora\_rank + kv\_lora\_rank ) = \\ 2 \times 5120 \times 128 (1536 + 512 )\)
- \({c_t^Q}^{\top}W^{UQK}\) 的計算量是:$2 \times hidden_size \times q_lora_rank \times kv_lora_rank= 2 \times 5120 \times 1536 \times 512 $。
可以看到,把\(W^{UQ}W^{UK}\)合并后計算量反而增大很多。prefill 的時候其實是不要做“吸收”的,可以按 $ ({c_tQ}{W{UQ}} )(W^{UK} c_t^{KV})$ 或者$ ({c_tQ}{W{UQ}} W^{UK} )c_t^{KV}$來計算。
因此,他認為,Absorb 的真實含義其實是矩陣乘法結合律,優先結合某些矩陣,并緩存 compressed latent vector \(c_t^{KV}\), 并不是合并權重矩陣,用 Absorb 命名有一定誤導性。如果吸收,也是\(W^{UK}\)被吸收到\(Q^C\),而非\(W^{UQ}\)。
2.3 解耦RoPE
為提高模型對序列中上下文信息的敏感性,MLA中應用了解耦旋轉位置編碼(RoPE)技術。而迄今為止,我們在分析中丟失了一個非常重要的步驟,即位置編碼。這是因為RoPE與低秩KV壓縮矩陣不兼容(與權重吸收會沖突),此時還無法無縫切換。為了解決這個問題,MLA引入額外的查詢向量\(q_t^R\)和共享鍵向量\(k_t^R\)來攜帶RoPE信息。從架構圖中可以發現,DeepSeek的q和k各自都有2個部分,分別是\([q_t^R,q_t^C]\)和\([k_t^R,k_t^C]\)。
- 1個部分是壓縮部分:\([q_t^C]\)和\([k_t^C]\)。
- 1個部分則加上了RoPE位置編碼。即有獨立一路做RoPE:\([q_t^R]\)和\([k_t^R]\)
最終兩個部分拼接成Q,K矩陣。這樣就把RoPE與低秩壓縮矩陣之間做了解耦,解決了位置信息與推理效率之間的矛盾。
我們接下來仔細進行剖析。
2.3.1 RoPE背景
下面代碼是Llama 3計算注意力的摘要。RoPE 旋轉位置編碼中Query和Key都是位置相關的。在進行注意力計算前,代碼是先應用\(W^K\)等矩陣得到Q和K,然后在Q和K上施加RoPE(乘以一個旋轉矩陣),以此在Q和K中融入相對位置信息。
class Attention(nn.Module):
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor],):
bsz, seqlen, _ = x.shape
# 獲取Q、K和V
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# 施加RoPE
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# 處理KV Cache
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# 計算注意力,分開計算了RoPE部分的q和k的注意力計算再求和
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
2.3.2 問題
無法直接應用到低秩壓縮
我們先看看是否可以把RoPE 施加到低秩壓縮向量上,即RoPE直接被低秩壓縮向量K和V所吸收。
因為K和V的低秩表示已經是壓縮了的狀態,壓縮操作可能已經丟失了某些信息,而RoPE矩陣對key和value是位置敏感的,直接在\(??_??^??\) 和 \(??_??^{????}\) 上應用 \(??_??\) 和 \(??_??\) 不再等價于在完整的Q和K上應用位置編碼,不能直接和有效地反映原始Q和K的相對位置關系。換言之,RoPE與低秩KV壓縮不兼容(RoPE is incompatible with low-rank KV compression),只能作用到原始K和V上。即只能從低秩KV壓縮先還原成原始的KV,然后在原始KV上施加RoPE。之前已經學習過,這樣做對性能有損失,所以采用了權重吸收。
與權重吸收不兼容
我們仔細看看RoPE作用到原始K和V上時,是否可以被權重吸收。
在RoPE的實現中,如果我們要讓Q、K帶上位置信息,會分別乘以相應的位置編碼矩陣。
如果計算\(??^T??\)時,就變成了
DeepSeek-V2對Q和K都進行了壓縮,則整個過程變成:
這里,\(??^{????}\) 和 \(??^{????}\) 分別是用于從低秩表示恢復到原始維度的解壓縮矩陣。目前公式中間多了一個與token位置差t-i相關的矩陣\(R_{m-n}\),該矩陣隨著相對位置變化而變化,并不是個固定矩陣,無法提前計算好。并且矩陣乘法不遵循交換律,沒辦法把\(R_{m-n}\)挪到公式的其它地方,因此在推理時,\(??^{????}\) 和 \(??^{????}\) 無法直接進行交互,\(??^{????}\) 就無法整合到 \(??^??\) 中。即\(??^{????}\) 和 \(??^??\) 無法合并為一個固定的投影矩陣。如果要強行降低KV Cache,則必須將參數簇\(R_{m-n}^s, s=1,2,...,head_{num}\)全部全部緩存下來。這個參數簇包含了\(O(sequence\_length^2)\)個參數張量,實在太大。因此,這就導致DeepSeek-V2原定的權重吸收無法實現,在推理過程中需要對所有前置tokens對應的Key進行旋轉位置編碼的計算,這會降低推理速度。
下圖給出了更加精確的闡述,上方是NoPE,下方是RoPE。

2.3.3 解決方案
為了解決MLA中的RoPE與低秩KV聯合壓縮不兼容的問題,DeepSeek團隊提出了解耦RoPE的策略:對于一個head,用一個高維度的向量表示其文本信息,以及一個低維度的向量來表示其旋轉位置編碼信息。前面的高維度向量稱為nope,后面的低維度向量稱為rope。具體而言是,把Query和Key進行拆分為\([q_t^R,q_t^C ]\)和\([k_t^R,k_t^C]\),其中一部分小向量進行了旋轉位置編碼( \(q_t^R,k_t^R\) ),一部分大向量進行壓縮( \(q_t^C,k_t^C\))。
- 信息存儲部分( \(q_t^C,k_t^C\))。這部分存儲了大部分的業務信息,是被壓縮的。下圖的紅圈和紫圈表明,我們有\(n_h\)個注意力頭,因此,我們需要把\(q_t^C,k_t^C\)??分別均分為\(n_h\)份。下標 i 表示的是第 i 個頭。
- 位置信息部分( \(q_t^R,k_t^R\) )。具體又分為兩部分。
- 使用共享的鍵(shared keys)\(??_??^??∈??^{??_?^??}\) 來攜帶RoPE信息,\(??_?^??\) 表示解耦的queries和key的一個head的維度。共享的\(??_??^??\)指的是每個頭的K都用這同一個\(??_??^??\)。注意,此處是基于 \(h_t\)(輸入嵌入)而不是基于向下投影的 \(C_t^{KV}\) 來生成\(k_t^R\)。
- 使用額外的多頭查詢(multi-head queries) \(??_{??,??}^??∈??^{d_?^??}\) 來攜帶RoPE位置信息。注意,此處是基于\(c_t^Q\)生成\(q_t^R\),而且每個頭會有自己的\(??_{??,??}^??\)。
最后將這四個變量分別拼接起來進行注意力計算。從而在推理時不需要對Key進行位置編碼的計算,避免了RoPE與低秩壓縮矩陣之間的耦合問題,解決了位置信息與推理效率之間的矛盾,提高了推理效率。具體參見下圖。
最終乘積計算如圖中標號4.1,其中前一項(標號4.2)按照無RoPE的情況計算,推理時只需要緩存\(c_t^{KV}\),后者(標號4.3)則對于所有注意力頭只緩存一個共享\(k_t^R\)。即,在推理階段,單個Token產生的KV Cache包含了兩個部分。
- 需要緩存鍵值的壓縮潛在向量\(c_t^{KV}\)(維度為\(d_c\))。
- 攜帶RoPE信息的共享鍵向量\(k_t^R\)(維度為\(d_h^R\))。
一共是\((??_??+??_?^??)??\) 個元素,l是層數。這種折中的方法保證了KV Cache的顯存空間依然很?。m然在 ???? 的基礎上增加了64維的 ???? ),FLOPS上有增加但是代價不大。

經過Concat過程會增加 Q 和 K 向量的維度。為了處理增加的維度,模型可以選擇:
- 增加注意力頭的數量:這將保持原有的每頭維度,但需要更多的計算資源。
- 調整每個頭的處理維度:保持頭的數量不變,但提高每個頭的維度,以適應Concat向量。
下圖給出了清晰的對比。進行注意力計算時,\(c_t^{KV}\)分別通過上投影矩陣\(W^{UK}\)和\(W^{UV}\)還原出鍵和值,每個注意力頭上的鍵再與攜帶了RoPE信息的共享鍵向量\(k^R_t\)拼接形成MHA的鍵值輸入。\(c_t^Q\)通過上投影矩陣\(W^{UQ}\)和\(W^{UR}\)還原并生成查詢向量\(q_t^C\)和攜帶RoPE信息的查詢向量 \(q_t^R\),二者拼接形成MHA的查詢向量輸入。最終多個頭的輸入拼接在一起,并經過線性映射\(W^O\)得到最終的輸出。

2.3.5 和權重吸收結合
我們再看看結合權重吸收之后如何處理,這里就需要將nope和rope也加進來,公式演變如下。
2.4 資源占用
2.4.1 參數量
MLA的思路來自LoRA,LoRA強調的是參數量的減少,而MLA也確實做到了減少參數量。按DeepSeek-V3的參數配置,兩個低秩矩陣參數量: \(2 \times d_c \times d =2\times512\times7168\) ,而正常MHA的參數矩陣參數量: \(d \times d=7168 \times 7168\) 。
具體參數如下:
"vocab_size": 129280,
"dim": 7168,
"inter_dim": 18432,
"n_heads": 128,
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
各個矩陣的參數量如下:
-
\(W^{DKV}\):dim * kv_lora_rank = 7168 * 512
-
\(W^{UK}\):kv_lora_rank * qk_rope_head_dim * n_heads = 512 * 128 * 128
-
\(W^{UV}\):kv_lora_rank * qk_nope_head_dim * n_heads = 512 * 128 * 128
-
\(W^{KR}\): dim * qk_rope_head_dim = 7168 * 64
-
\(W^{DQ}\):dim * q_lora_rank = 7168 * 1536
-
\(W^{UQ}\): q_lora_rank * qk_nope_head_dim * n_heads = 1536 * 128 * 128
-
\(W^{QR}\):q_lora_rank * qk_rope_head_dim * n_heads = 1536 * 64 * 128
-
\(W^O\):n_heads * v_head_dim * hidden_size = 128 * 128 * 7168。
2.4.2 內存占用
但MLA強調的是KV-cache的減少,也就是KV的激活值減少。我們接下來繼續分析。與經典的MHA和GQA,MQA比較。MLA實際緩存的向量是:
- \(c_t^{KV}\),維度是\(d_c\)。
- \(k_t^R\),維度是\(d_h/2\)。
如下圖所示,我們可以看出,MLA在優化kv cache和保證模型效果上有很強的優越性。圖中\(n_h\)是注意力頭數量,\(n_g\)是GQA分組數,\(d_h\)是隱藏層維度(低秩壓縮后的維度),\(d_c\)是KV壓縮維度,\(l\)為block的塊數。和MHA相比,Q和K的頭維度變成了\(d_c+d_r\),V的頭維度變成了\(d_c\),對于DeepSeek-V2,\(d_c\) 被設置為\(4d_h\),而\(d_h^R\)被設置為\(\frac{d_h}{2}\)。KV Cache的數量以元素數量來衡量(不考慮存儲精度)。
- 在MHA中,推理階段針對每個Token,需要緩存其鍵向量和值向量,則每個Token的緩存參數個數為\(2 \times n_h \times d_n \times l\)。與MHA相比,MLA占用的token數\(\frac{9}{2}d_hl\) 通常要小于\(2n_hd_hl\),所以MLA能獲得比 MHA 更強的性能,顯著降低了KV緩存的大小。
- GQA 通過分組共享 K/V 矩陣(如 LLaMA-70B 設置 g=8)減少顯存占用,但壓縮率有限(僅減少到 g/h 倍)。與GQA相比,MLA相當于GQA中的組數量 ???? =2.25,小于大多數Model里的 group數量,由此可見,其kv cache的尺寸會大大減小。即,MLA 的 KVCache 存儲成本約等于GroupNum=2.25 的 GQA 的 KVCache 存儲成本。
- 與MQA相比,MLA相當于增加了2.25倍的存儲,但是MLA的性能和效果顯著優于MQA,甚至強于MHA和GQA,真正實現了即降低推理成本,又保證了模型性能。

2.4.3 計算量
和MHA相比,MLA的Q和K的頭維度變成了\(d_c+d_h^R\),V的頭維度變成了\(d_c\)。而 DeepSeek V3的一些超參數如下:
- \(d_k\)(hidden dimension/模型維度):7168。
- \(n_h\)(注意力頭數):128。因為MLA的KV Cache大小跟\(n_h\)無關,增大\(n_h\)只會增加計算量和提升模型能力,但不會增加KV Cache。
- \(d_h\)(每個注意力頭的維度):128。
- \(d_c\)(KV的壓縮維度):512,即\(4d_h\)。
- \(d_h^R\)(RoPE頭相關維度):64,即\(\frac{d_h}{2}\)。
既然MLA每個頭的Q/K的head size變大了不小,所以MLA的推理計算量增加了。那為什么還能提高推理效率呢?其實,MLA可以提高效率是因為結合了LLM推理的瓶頸時訪存而不是計算這一特性。我們可以將LLM的推理分兩部分:第一個Token的生成(Prefill)和后續每個Token的生成(Generation),Prefill階段涉及到對輸入所有Token的并行計算,然后把對應的KV Cache存下來,這部分對于計算、帶寬和顯存都是瓶頸,MLA雖然增大了計算量,但KV Cache的減少也降低了顯存和帶寬的壓力。Generation階段由于每步只計算一個Token,實際上它更多的是帶寬瓶頸和顯存瓶頸,因此MLA的引入理論上能明顯提高Generation的速度。另一方面,由于Compressed KV在每個head中都參與了計算,DeepSeek-V2的128個heads能夠提供足夠的計算強度(正比于 Head 數),這樣就把 LLM 解碼過程的訪存密集型,轉換為計算密集型的操作,因此Attention部分的MFU也得到了大幅提高。
我們假設 q 的形狀是\((b,n_h,s_q, d_h)\),\(c^{KV}\)的形狀是\((b,1,s_{kv},d_c)\),\(W^{UK}\) 的形狀是\((d_c,n_h,d_h)\)。prefill階段,\(s_q = s_{kv} =s\)。
- native的計算量是:\(2bsd_cd_hn_h + 2bn_hssd_h = 2bn_hd_hs(d_c+s)\)。
- 吸收后的計算量是:\(2bsd_cd_hn_h + 2bn_hssd_c = 2bn_hd_cs(d_h+s)\)。
兩者相比是:\((d_h(d_c+s)) / (d_c(d_h+s))\)。
decode階段,\(s_q=1,s_{kv}=s\)。
- 緩存K的計算量。\(2bd_cd_hn_h+2bn_hsd_h=2bn_hd_h(d_c+s)\)。
- 緩存潛向量時候的計算量。\(2bsd_cd_hn_h+2bn_hsd_h=2bn_hd_h(d_cs+s)\)。
- 吸收后的計算量。\(2bd_cd_hn_h+2bn_hsd_c=2bn_hd_c(d_h+s)\)。
2.4.4 信息轉移
有研究人員再讀MLA,還有多少細節是你不知道的認為,MLA的作用其實是"信息轉移“,即把KV頭中獨有的信息轉移到對應的Q頭上,而把KV頭中間共享的相同信息存儲到KV Cache中。具體思路如下:
- 改進目的:在盡量不壓縮head上K、V信息的情況下,節省kv cache。
- 改進背景:之所以要保存token對應的所有注意頭上的K、V值,是因為每個k_head附帶有不同的信息,它將用這份獨有的信息和對應的q_head進行注意力計算。
- 改進思路(下面以K頭為例,V頭類似):
- 把一個token中所有K頭中的共有信息抽取出來,壓縮到KV Cache中,因為這些共有信息會更少,只保存它們才能減少KV Cache的大小。這個相同信息是每個tokens的所有k_heads共享一份,同時在不同tokens間共享。
- 把K中每個頭上獨有的信息轉移到對應的Q頭上。因為Q頭需要承載更多信息,所以Q和K的頭維度變成了\(d_c+d_r\),\(d_c\) 被設置為\(4d_h\)。在V3上,\(d_c\)是512,相當于把緩存7168維的向量降低到了緩存512維。而q壓縮之后是1536維,之所以這么大,就是因為Q要承載更多的信息。
雖然從形式上來說,MLA和MQA/GQA很像,似乎都是通過壓縮k/v_heads的數量來節省KV cache大小的。但MLA是壓縮num_heads,不壓縮信息(把信息轉移到了q_heads上);而MQA/GQA則在一定程度上對信息做了壓縮。具體這些相同信息、相異信息存儲在何處?是在\(W_K\)矩陣中?還是存儲在原始token \(h_t\)中?筆者目前不能確定。所以只能用下圖展示。

另外,GQA 的分組數需嚴格匹配硬件規模(如 8 卡對應 g=8),限制了模型部署的靈活性。而 MLA 通過潛在空間投影和解耦式權重合并,可動態適配不同硬件配置(如單卡或多機集群)。GQA 為彌補性能損失需增大 FFN 層規模(如 LLaMA3-70B 的 FFN 參數量增加 20%),導致模型復雜度上升。MLA 則通過低秩投影和動態路由,無需額外補償即可維持性能。
2.5 并行
在大模型推理的decode階段,MLA無法使用張量并行。故在目前的一些開源實現中,主要還是基于數據并行來對MLA進行處理,即不同請求的KVCache存儲到不同的GPU中。DeepSeek-V3論文提到使用張量并行和序列并行。
- 張量并行:MHA通常對head_num維度進行切分來實現張量并行。而MLA則有自己的特點,如果采用 tp 并行時,部分權重和 kvcache 都無法按 head_num 劃分到不同的卡上。
- 使用張量并行部分:kv_b_proj、o_proj等模塊都包括了head維度,因此可以按照head維度切分執行張量并行,將MLA計算均勻的劃分到多卡上,實現并行加速。
- 難以使用張量并行部分。
- mla存儲KV Cache時,對于一個token存儲的是(1, 1, kv_lora_rank+qk_rope_head_dim),而不是常規MHA下的(1, kv_head_num, head_dim)。因此KVCache中只保存一份潛空間的壓縮向量,并不包含head維度,沒有辦法按照head進行劃分。導致每張卡上都要保存所有請求的的完整kvcache,其形狀是(bs, 1, seq_len, kv_lora_rank),這意味著KVCache 各個卡的存儲是冗余的。
- 部分權重由于head_num=1無法切分到不同的卡上,比如q_a_proj 和kv_a_proj_with_mqa不能按 head_num 切分。只有上投影矩陣才能考慮按列切分和最后輸出矩陣按行切分。
- 數據并行。即按照請求切分,不同請求的潛空間的壓縮向量存儲到不同的GPU中。但是因為不同GPU上的請求長度可能差異很大,這樣會導致顯存占用不均衡,也會導致不同GPU上計算時間差異較大,進而導致性能最差的GPU拖慢整體進度。
- 序列并行:MLA會用序列并行(Sequence Parallel)來進行輔助。即,對KVCache按照序列維度進行切分,每一張卡上都使用query來做local的attention計算,然后對結果進行規約。
0x03 計算過程
我們來梳理下MLA在推理階段的計算流程。
3.1 公式
首先,我們給出Q、K、V的變換過程對應的公式。后續會按照這個公式來進行解析。

3.2 原始流程
我們將上述公式轉換為流程圖,圖中細節如下:
- 從上到下分為Q、K、V三路。
- Q和K又都細分為兩路,“上路”綠色的權重和激活值對應隱向量/低秩部分;“下路”灰色漸變的權重和激活值對應decoupled RoPE。
- K的下路和V路的數據流向有所交錯。
- “緩存” 代表在推理階段會進行緩存的數據,具體分為兩部分:
- KV聯合隱向量 \(c_t^{KV}\)。
- 單獨施加了RoPE的鍵$k_t^R \(。K路位置編碼模塊接受的輸入還是原始的\)?_t\(而不是壓縮后的\)c_t$。
此處假設頭數\(n_h\)為2,矩陣大小并不是完全按照比例縮放。

3.3 吸收
3.3.1 過程
接下來第二步,將論文中所說的權重吸收過程施加進去,得到下圖:
- 推理階段要緩存的東西不變。
- \(W^{UK}\) 吸收進 \(W^{UQ}\) 之后。
- Q的上路計算邏輯沒有變,但是權重和激活值的形狀都有相應的調整。
- K的上路則直接少掉了一處線性映射的計算邏輯,變成了重復拷貝$n_h $份,與K下路類似。
- $W^{UV} $吸收進 \(W^{O}\)之后。
- V路由線性映射退化為重復拷貝的邏輯。
- 最后輸出映射的計算邏輯不變,但是權重和激活值的形狀有相應的調整。
- 紅色字體公式代表了吸收對應的公式。綠色箭頭表示有進一步吸收的可能。

3.3.2 吸收結果
我們對上圖進行整理,得到吸收的結果如下。

3.3.3 MQA形式
MLA推理階段的計算邏輯其實很像一個MQA,我們進行比對下(不考慮 RoPE)。

MQA和MHA的最大區別在于 \(K,V\) 是所有 head共享的,因此能夠減少KV Cache的顯存占用。其中 $$ Q_iTK=HT(W_iQ)TW^KH $$。
對于MLA,單獨看 Attention 計算的前一部分,其中$ Q_iTK_i=HT(W{DQ})T(W_i{UQ})TW{UK}_iWH$,令 \(W_i^Q=(W_i^{UK})^TW_i^{UQ}W^{DQ}\),我們有 $$ Q_iTK_i=HT(W_iQ)TW^{DKV}H $$ 。可以看到這一計算公式和 Multi-Query Attention 其實是一樣的,都是使用的單獨的 \(Q\) 和共享的 \(K\)(\(C^{KV}\)),等價于將single-head的KV重復拷貝若干遍再執行正常的MHA。
區別在于,這里 \(W_i^QH,W^{DKV}H\in\mathbb{R}^{d_c\times l}\)。也就是說在進行 attention 計算的時候,向量點積的維度是 \(d_c\) 而不是 \(d\)。在論文中實際設置的是 \(d_c=4d\)。也就是說 Multi-Head Latent Attention 其實是 head dimension 提高到4倍的 Multi-Query Attention。在論文中也提到了在推理的時候 absorb \(W^{UK}\) into \(W^{UQ}\),其實就代表了這里的結合方式。因為每個head的維度提高了,所以能夠計算出更加復雜的 attention分布,從而相比起 Multi-Query Attention 取得性能提升。相比起直接提高 head dimension,其優點在于所有head的 \(W^{DQ},W^{UQ},W^{UK}\)的總參數量是 \(d\cdot d_c+d \cdot d_c+ d \cdot d_c=3d\cdot d_c=12d\cdot d_h\),而所有 head 的 \(W^Q\) 的參數量是 \(d \cdot d_c\cdot n_h=4d^2\),節省了參數量。也就是說對 \(W^Q\) 做了一個低秩分解。
但是這個提升并不是免費午餐,因為 head dimension 提高意味著 attention 的計算量也提高,而 attention 的計算量是 \(O(l^2)\) 的。為了處理長文本,現在大家一般都傾向于盡可能降低 attention 計算量的常數,而這個方法是會增加常數的。以上分析沒有考慮 RoPE,如果考慮 RoPE 的話,每個 head 的維度會從 \(4d\) 變成 \(4.5d\),其中\(4d\)是沒有 positional encoding的,\(0.5d\) 是使用 RoPE encoding的。其實 ChatGLM2-6B 中已經使用過類似的做法,即只在一半的 head dimension 上使用 RoPE ,目的是為了把 attention 計算分成位置相關和位置無關的兩部分,與性能提升的關系并不大。
了看得更明顯,我們可以把圖中的一些權重進一步吸收合并,得到下圖。
- Q的計算過程退化為普通multi-head線性映射
- 每個head一部分維度保持不動,對應綠色部分
- 每個head另一部分維度施加RoPE變換,對應紅色部分
- K的計算過程退化為single-head線性映射
- 同樣只對部分維度施加RoPE變換。
- 施加后進行重復拷貝(邏輯上如此呈現以便于理解,計算上當然可以優化掉)。
- V則直接使用K中未經施加RoPE變換的部分,同樣重復拷貝。
下圖與標準MQA的區別是:
- QK只有部分維度施加RoPE;
- V與未施加RoPE的K共享激活值。

0x04 代碼
我們主要使用V2的代碼來分析,因為條理更加清晰。也需要注意的是,DeepSeek的代碼在很多地方和論文不一致。V2中的DeepseekV2Attention的實現本質上和V3中的native一樣,其實并沒有節省KV-Cache,V3版本的非native版本是跟論文一致,節省了顯存。
4.1 配置
我們摘錄一些相關配置信息如下。在 Naive 實現中,512 維的 Latent KV \(c^{KV}\) 被映射回對應 128 個 head,每個 head 128 維的 \(k^C\) 和 \(v^C\),然后再拼接上位置向量 \(k^R\) ,最終形成標準的 q、k、v 輸入到標準的 Multi Head Attention 進行 Attetion 計算。另外,代碼中也使用了norm,在論文中也有相應提及。

具體配置信息如下。其中:
- 鍵和值的壓縮維度 \(d_c\) :設置為 512 ,原始嵌入維度 ??=5120,比例為 1/10。由于鍵和值在推理時需要緩存,因此采用較大的壓縮比例以顯著減少內存開銷。
- 查詢的壓縮維度 \(d'_c\) :設置為 1536 ,比例為 0.3 。查詢在訓練時需要頻繁計算,因此采用較小的壓縮比例以保留更多信息,確保模型性能。
"num_hidden_layers": 60, # Transformer層的數量
"hidden_size": 5120, # 隱藏層的大小
"num_attention_heads": 128, # 注意力頭的數量
"kv_lora_rank": 512, # KV壓縮維度
"q_lora_rank": 1536, # Query壓縮維度
"qk_rope_head_dim": 64, # 解耦Query和Key的每個頭部維度
"n_shared_experts": 2, # MoE層中的共享專家數量
"n_routed_experts": 160, # MoE層中的路由專家數量
"moe_intermediate_size": 1536, # 每個MoE專家的中間隱藏層的維度
"num_experts_per_tok": 6, # 每個token激活的專家數量
"routed_scaling_factor": 16.0, # 路由專家的縮放因子
"rms_norm_eps": 1e-06 # RMS歸一化的epsilon值
4.2 定義
給定輸入向量\(h_t \in \mathbb{R}^{B \times L \times 5120}\),其中\(B\)為batch size,\(L\)為sequence length。
class DeepseekV2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
# 對應 query 壓縮后的隱向量的維度 d'_c
self.q_lora_rank = config.q_lora_rank
# query和key的隱藏向量中,應用rope部分的維度,對應d_h^R
self.qk_rope_head_dim = config.qk_rope_head_dim
# 對應 key-value 壓縮后的隱向量維度 d_c
self.kv_lora_rank = config.kv_lora_rank
# value 的一個注意力頭的隱藏層維度
self.v_head_dim = config.v_head_dim
# 向量中不應用rope部分的維度
self.qk_nope_head_dim = config.qk_nope_head_dim
# 每一個注意力頭的維度應該是nope和rope兩部分之和
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.is_causal = True
# MLA 中對 Q 投影矩陣也做了一個低秩分解,對應生成 q_a_proj 和 q_b_proj 兩個矩陣,即兩階段投影:先將hidden_size投影到q_lora_rank,再投影到最終維度
# 對query進行壓縮,即down-projection。即,第一階段投影:hidden_size -> q_lora_rank,對應論文公式中的W^DQ
self.q_a_proj = nn.Linear(
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
# 對壓縮后的query映射成高維,即up-projection。對應上述公式中的W^UQ和W^QR合并后的大矩陣,僅僅只是內存放在一起。
# q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128 + 64
self.q_b_proj = nn.Linear(
config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
# KV向量的生成也是先投影到一個低維的 compressed_kv 向量(對應c_t^{KV}),再升維展開
# 對應論文公式中的W^{DKV}和W^{KR}
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
config.kv_lora_rank + config.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
# 對應論文公式中的W^{UK}和W^{UV},由于 W^{UK} 只涉及 non-rope 的部分,所以維度中把 qk_rope_head_dim 去掉了
self.kv_b_proj = nn.Linear(
config.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
# 對應論文公式的第 47 行
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
self._init_rope()
self.softmax_scale = self.q_head_dim ** (-0.5)
if self.config.rope_scaling is not None:
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.softmax_scale = self.softmax_scale * mscale * mscale
對應的一些信息如下。把整個計算流程拆成 q_nope, k_nope, k_pe, k_nope 這四個部分就是為了把RoPE進行解耦。兩個pe結尾的變量就是用于儲存旋轉位置編碼的信息。Deepseek-V2將kv cache壓縮到了同一個小矩陣中,后面再解壓縮出來。
# q = q.view(bsz, q_len, num_heads, q_head_dim).transpose(1, 2)
# q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)
q_pe : torch.Size([16, 128, 1, 64])
q_nope : torch.Size([16, 128, 1, 128])
# query_states = k_pe.new_empty(bsz, num_heads, q_len, q_head_dim)
query_states : torch.Size([16, 128, 1, 192])
# kv = .view(bsz, kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim).transpose(1, 2)
# k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
value_states : torch.Size([16, 128, 1024, 128])
k_nope : torch.Size([16, 128, 1024, 128])
# k_pe = k_pe.view(bsz, kv_seq_len, 1, qk_rope_head_dim).transpose(1, 2)
k_pe : torch.Size([16, 1, 1024, 64])
# key_states = k_pe.new_empty(bsz, num_heads, kv_seq_len, q_head_dim)
key_states : torch.Size([16, 128, 1024, 192])
self = {DeepseekAttention}
hidden_size = {int} 5120
kv_a_layernorm = {DeepseekV2RMSNorm} DeepseekV2RMSNorm()
kv_a_proj_with_mqa = {Linear} Linear(in_features=5120, out_features=576, bias=False)
kv_b_proj = {Linear} Linear(in_features=512, out_features=32768, bias=False)
kv_lora_rank = {int} 512
num_heads = {int} 128
o_proj = {Linear} Linear(in_features=16384, out_features=5120, bias=False)
q_a_layernorm = {DeepseekV2RMSNorm} DeepseekV2RMSNorm()
q_a_proj = {Linear} Linear(in_features=5120, out_features=1536, bias=False)
q_b_proj = {Linear} Linear(in_features=1536, out_features=24576, bias=False)
q_head_dim = {int} 192
q_lora_rank = {int} 1536
qk_nope_head_dim = {int} 128
qk_rope_head_dim = {int} 64
rotary_emb = {DeepseekV2RotaryEmbedding} DeepseekV2RotaryEmbedding()
softmax_scale = {Tensor} tensor(0.0723, dtype=torch.bfloat16)
v_head_dim = {int} 128
另外,https://github.com/sgl-project/sglang/discussions/3082 這里闡釋了為何使用norm。
4.3 操作Q
我們把Q相關的代碼都合并在一起進行分析。總的流程是:模型處理上一層計算出的隱藏狀態(hidden_size=5120)時,首先會將模型的q壓縮到 q_lora_rank 這一維度(設定為1536),再擴展到 q_b_proj 的輸出維度(num_heads * q_head_dim),最后切分成 q_pe 和 q_nope 兩個部分。
4.3.1 變量定義
MLA 中對 Q 投影矩陣\(W^Q\)做了一個低秩分解,對應生成 q_a_proj 和 q_b_proj 兩個矩陣。
- q_a_proj 大小為 [hidden_size, q_lora_rank] = [5120, 1536],對應公式中的 $$W^{DQ}$$,用來降維。
- q_b_proj 大小為 [q_lora_rank, num_heads * q_head_dim] = [q_lora_rank, num_attention_heads * (qk_nope_head_dim + qk_rope_head_dim)] = [1536, 128*(128+64)] = [1536, 24576] ,用來升維,對應公式中的 \(W^{UQ}\) 和 \(W^{QR}\)合并后的大矩陣。因為從公式來看這兩個矩陣都需要和\(c_t^Q\)計算,所以可以合并矩陣后再進行拆分。對于一個head,用一個128維度的向量表示其文本信息,以及一個64維度的向量來表示其旋轉位置編碼信息。前面的128維度,稱為nope,后面的64維度,稱為rope。
self.num_heads = config.num_attention_heads # 128
self.q_lora_rank = config.q_lora_rank # 1536
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim # 128 + 64
# 對query進行壓縮,即down-projection
self.q_a_proj = nn.Linear(
self.hidden_size, config.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
# 對壓縮后的query映射成高維,即up-projection
self.q_b_proj = nn.Linear(
config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
4.3.2 變量操作

在DeepSeek-V2中,Q向量也采用了低秩壓縮的方式。
- 首先,將輸入向量投影到一個1536維的低維空間:$$ c_t^Q = W^{DQ} ,h_t \in \mathbb{R}^{B \times L \times 1536} $$。對應論文第37號公式。
- 然后,將其投影到\(\mathbb{R}^{H \times 128}\)的多頭向量空間上(其中\(H=128\)是heads數),得到了Q向量的第一部分:$$ q_t^C = W^{UQ} c_t^Q \in \mathbb{R}^{B \times L \times H \times 128} $$。對應第38號公式。
- 再將其投影到\(\mathbb{R}^{H \times 64}\)上并使用RoPE嵌入位置信息,得到Q向量的第二部分:$$ q_t^R = \mathrm{RoPE}(W^{KR} h_t) \in \mathbb{R}^{B \times L \times H \times 64} $$。對應第39號公式。每個head有自己的旋轉位置編碼,每個head之間不共享。
- 將兩部分拼接的到最終的Q向量:$$ q_t = [q_t^C, q_t^R] \in \mathbb{R}^{B \times L \times H \times 192} $$。對應第40號公式。
在具體的實現過程中其輸入為 hidden_states 向量,對應公式中的 \(?_t\)。是一個大小為 [batch_Size, sequence_length, hidden_size] 的矩陣,其中 hidden_size 具體為 5120。后續的nope指代非rope。
# hidden_states對應公式中的h_t,hidden_states的shape是(batch_size, seq_length, hidden_size),其中 hidden_size為 5120,是num_head * q_head_dim
bsz, q_len, _ = hidden_states.size()
# 下面兩行代碼對應第37、38號公式,先降維再升維。q_b_proj維度是[1536, 24576],q_a_proj維度是[5120, 1536],是W^Q [5120, 24576]矩陣的低秩分解。即[5120, 24576] -> [5120, 1536] * [1536, 24576]
# 首先,使用全連接層(self.q_a_proj)對輸入的隱狀態(hidden_states)進行降維投影
# 然后,使用全連接層(self.q_b_proj)對壓縮的向量進行上投影
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
# 重塑為多頭形式,是第40號公式的前置準備操作,或者說是40號公式的反向操作
# q_pe 要扔給 RoPE模塊,所以需要重整下形狀
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
# 把最后一維切分成nope和rope兩部分
# 將最后一層 192 的hidden_states切分為 128 (qk_nope_head_dim) + 64 (qk_rope_head_dim),即將查詢表示(q)分為兩部分:沒有經過位置編碼的部分(q_nope)和經過位置編碼的部分(q_pe),q_nope表示不需要應用RoPE的,q_pe表示需要應用RoPE的
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# 第39號公式,給q和k施加RoPE
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
# 初始化查詢狀態(query_states)的張量,這個張量將用于存儲融合了解耦RoPE的查詢表示,其中q_head_dim = qk_nope_head_dim + qk_rope_head_dim = 128 + 64 = 192
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
# 下面兩行對應第40號公式
# 將未經過位置編碼的查詢表示(q_nope)復制到 query_states 張量的前一部分,即那些不包含位置編碼的維度。
# 這樣做可以有利于后續將原始的查詢表示與含有位置編碼信息的查詢表示分開來處理
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope # 128
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe # 64
4.4 操作KV
我們把KV相關的代碼都合并在一起進行分析。對于kv矩陣的設計,模型使用了kv壓縮矩陣設計(只有576維),在訓練時進行先降維再升維。在模型推理的時候,需要緩存的量變成 compressed_kv,經過 kv_b_proj 升高維度得到 k,v 的計算結果。
4.4.1 變量定義
KV向量和Q向量類似,也做了一個低秩分解,對應生成 kv_a_proj_with_mqa和 kv_b_proj 兩個矩陣。
- kv_a_proj_with_mqa 大小為 [hidden_size, kv_lora_rank + qk_rope_head_dim] = [5120, 512 + 64] = [5120, 576],對應上述公式中的 $$W^{DKV}$$ 和 $$W^{KR}$$的合并矩陣,用來把輸入先投影到一個低維的空間(對應 $$C_t^{KV}$$),同時做兩種降維操作(nope,rope的前置操作)。因為因為從公式來看這兩個矩陣都需要和\(h_t\)計算,所以可以合并矩陣計算后再進行拆分。輸出的維度則是512+64=576了。前面的512維度是給kv的,后面的64維度是給key的旋轉位置編碼的。
- kv_b_proj 大小為 [kv_lora_rank,num_heads * (q_head_dim - qk_rope_head_dim + v_head_dim)] = [512, 128*((128+64)-64+128)] = [512, 32768],對應上述公式中的 $$W^{UK}$$ 和$$W^{UV}$$的合并矩陣。由于 $$W^{UK}$$ 只涉及nope 的部分,所以維度中把 qk_rope_head_dim 去掉了。192-64是把key表示向量中的64維度的旋轉位置編碼向量從192維度中減去;然后的128維度是留給value的,因為value不需要考慮位置信息。需要考慮位置信息的只有query和key。
或者說,通過kv_a_proj_with_mqa 來對head脫敏,即得到的張量和具體的head無關;通過kv_b_proj來重新恢復成對每個head敏感,得到的是形如[1, 16, 26, 128]這樣的,和具體16個head分別相關的張量。
self.kv_lora_rank = kv_lora_rank # 512,key和value各占256維度
self.qk_rope_head_dim = config.qk_rope_head_dim # 64
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim # 128 + 64
self.v_head_dim = config.v_head_dim # 128
self.hidden_size = config.hidden_size # 5120
# 計算壓縮后的latent kv以及需要緩存的應用RoPE的k的部分:k_t^R(前置條件),即把隱向量的5120維度 映射到 config.kv_lora_rank + config.qk_rope_head_dim = 512 + 64維度
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
config.kv_lora_rank + config.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
# 計算up-projection后的不應用RoPE的k的部分 和 up-projection后的v的結果
self.kv_b_proj = nn.Linear(
config.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
4.4.2 變量操作

計算KV向量時,有幾個和公式中不同的地方,即把某些矩陣操作打包在一起執行(同時將K,V的向量一起產出了),后續再拆分開。
-
首先需要將輸入向量投影為512維的聯合壓縮表示:$$ c_t^{KV} = W^{DKV} h_t \in \mathbb{R}^{B \times L \times 512} $$,對應第41號公式。
-
與Q向量的計算過程類似,K向量的第一部分是將\(c_t^{KV}\)通過投影解壓縮到\(\mathbb{R}^{H \times 128}\)的多頭向量空間:$$ k_t^C = W^{UK} c_t^{KV} \in \mathbb{R}^{B \times L \times H \times 128} $$,對應第42號公式。注意:此處增加了一個頭維度。
-
K的第二部分是將輸入向量投影到64維向量空間并施加RoPE嵌入位置信息:$$ k_t^R = \mathrm{RoPE}(W^{KR} h_t) \in \mathbb{R}^{B \times L \times 64} $$,對應第43號公式。
-
與Q不同的是,完整的K是將K的第二部分廣播到每個head后與第一部分拼接得到:
\[ k_t = \begin{bmatrix} k_{t,1}^C & k_t^R \\ k_{t,2}^C & k_t^R \\ \vdots & \vdots \\ \end{bmatrix} \in \mathbb{R}^{B \times L \times H \times 192} \]也就是說,每個head的RoPE部分是完全相同的。此處對應第44號公式。再強調下:對于query,每個head有自己的旋轉位置編碼向量;key則是所有heads共享同一個旋轉位置編碼向量。
-
V向量的計算較為簡單,直接將\(c_t^{KV}\)解壓縮到\(\mathbb{R}^{H \times 128}\)即可:$$ v_t = W^{UV} c_t^{KV} \in \mathbb{R}^{B \times L \times H \times 128} $$,對應第45號公式。
通過維度分析可以看到 kv_lora_rank 是 qk_nope_head_dim 的 4 倍且 K 和 V 共享 latent state,qk_rope_head_dim 只有 qk_nope_head_dim 的一半,結合起來 4+1/2=9/2,是 正式下圖中 MLA KVCache per Token 大小的來源。

具體的代碼實現如下,可以發現除了在對q做計算時涉及到gemv之外,也就是q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))),其它地方的矩陣乘運算q_len維度都是和num_heads在一起做計算,而num_heads在Deepseek2的配置里面已經是128了,導致其它的Matmul幾乎都落在了計算密集的范疇。
# 使用MQA(Multi-Query Attention)對輸入的隱狀態進行處理,得到壓縮后的鍵值對表示(compressed_kv),對應41號公式和43號(還沒有加 rope)。此時compressed_kv就是公式中的c_t^{KV}+W^{KR}h_t,形狀是[B, q_len, kv_lora_rank + qk_rope_head_dim],kv_lora_rank是d_t
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
# 將壓縮后的鍵值對表示分為兩部分:低秩壓縮的鍵值對部分和經過位置編碼的鍵部分(k_pe),分別對于nope和rope。這是第44號公式的前置準備操作,或者說是44號公式的反向操作
# 此時compressed_kv才是公式中的c_t^{KV},k_pe是公式中的W^{KR}h_t
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# k_pe 要傳給 RoPE模塊,所以需要重整下形狀,增加注意力頭這個維度
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
# 計算得到k^C和v^C
# 1. 對壓縮后的鍵值對升維,包括RMSNorm(self.kv_a_layernorm)和全連接層(self.kv_b_proj,對應W^{UK}和W^{UV}),是42號和45號公式結合體的前半部分,得到W^{UK}c^{KV}_t(k^C_t)和W^{UV}c^{KV}_t(V^C_t),但此時k^C_t和V^C_t是拼接在一起的
# 2. 用view()和transpose()函數將MLA展開成標準MHA的形式。注意:此處增加了一個頭維度
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
# 使用torch.split函數將k^C_t和V^C_t分離開,是42號和45號公式結合體的后半部分。因為 kv_b_proj 包括 W^{UK} 和 W^{UV},因此要把它們的計算結果分離出來,分別在不同的地方吸收,最終拆分成兩部分:
# k_nope是沒有經過位置編碼的鍵部分,不包含位置信息。維度為[B, num_head, kv_seq_len, qk_nope_head_dim]
# value_states是值部分,用于后續的位置編碼和注意力權重計算,維度為[B, num_head, kv_seq_len, v_head_dim]
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
# 獲取key/value的序列長度,即包含當前位置可用上下文的長度
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# 調用self.rotary_emb函數,根據value_states和更新后的序列長度kv_seq_len計算RoPE的cos和sin值
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# 使用apply_rotary_pos_emb函數對W^{KR}h_t施加RoPE,得到k_t^R,即k_pe變量
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
# 初始化鍵狀態(key_states)的張量,存儲融合了解耦RoPE的鍵表示
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope # k^C_t
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe # k^C_t + k_t^R
4.5 注意力操作
4.5.1 變量定義
o_proj對應矩陣\(W^O\),大小為[num_heads * v_head_dim, hidden_states]=[128 * 128, 5120]。
self.v_head_dim = config.v_head_dim # 128
self.num_heads = config.num_attention_heads # 128
self.hidden_size = config.hidden_size # 5120
self.o_proj = nn.Linear( # 對應第47號公式
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
4.5.2 變量操作
生成 QKV 向量之后的流程就基本上等同于標準的 MHA 計算了。唯一的區別在于只有 q_pe, k_pe 這兩個部分給加上了 rope。具體流程如下:
首先計算attention score:
然后計算對V的加權和,并將所有head壓平,得到Attention輸出:
最后經過另一個矩陣的投影,就能得到MLA的最終輸出:
# 更新和拼接歷史 KVCache,將當前位置之前的壓縮后的kv以及應用過rope的k的部分拼接進去,可以看到這里存儲的是展開后的 MHA KVCache
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update( # 更新kv cache
key_states, value_states, self.layer_idx, cache_kwargs
)
# 后續就是標準的 MHA 代碼,代碼 Q^T*K*V*O
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
4.6 前向傳播
我們把完整的前向傳播代碼摘錄如下,大家可以更好的理解。
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None, # V2代碼中,kv cache存儲的是全部緩存,不是壓縮后的
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# hidden_states對應公式中的h_t,的shape是(batch_size, seq_length,hidden_size)
bsz, q_len, _ = hidden_states.size()
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
對應如下圖例。

4.7 V3 代碼
我們也給出V3代碼具體如下。V3中的 native 版本其實并沒有節省KV-Cache(甚至還多了存儲),V3版本的非native版本是跟論文一致,節省了顯存。
native 版本的實現直觀、適合學習,但是不適合Decode階段,因為Decode階段需要用到KV Cache。針對KV Cache,native 版本的實現有兩種選擇:
-
① 緩存 Latent KV。緩存規模小,矩陣運算是\((b,n_h,1,d_c) \times (b,1,s,d_c)\),假定是bfloat16精度,內存讀取量是\(2bn_hd_c + 2bsd_c = 2bd_c(n_h+s)\)。但 Latent KV 緩存不能直接送 MHA 計算,還得經過 \(W^{UK}\) 和 \(W^{UV}\) 的線性映射,這是兩個規模不小的矩陣計算,而且每輪都得重復計算。
-
② 緩存 KV。緩存規模大,不用重復計算,性能好。標準MHA \((b,n_h,1,d_h) \times (b,n_h,s,d_h)\)的內存讀取量是\(2bn_hd_h+2bn_hsd_h = 2bd_hn_h(1+s)\)。但 MLA 的一大好處就是 KV Cache 壓縮,這樣顯存內能緩存更多 token,支持更大的 batch 和 prefix cache。如果緩存 KV,在顯存上對比 MHA 就完全沒有優勢了。
native 版本最終的選擇是方案2。所以,Naive 實現可能會用于 Prefill階段,但在 Decode 計算時需要更好的計算方法,也就是非native版本。在非native版本最核心的 Attention kernel 計算中,“吸收“模式下 K/V tensor Shape 中不攜帶 num_attn_heads 信息,計算邏輯轉換成了類 MQA 計算,“不吸收”模式下 K/V tensor 仍攜帶 num_attn_heads,就是MHA計算。
# from: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MLA(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
# 對應 query 壓縮后的隱向量的維度 d'_c
self.q_lora_rank = args.q_lora_rank # q 壓縮后的維度
# 對應 key-value 壓縮后的隱向量維度 d_c
self.kv_lora_rank = args.kv_lora_rank # kv 壓縮后的維度
# 表示query和key的向量中應用rope部分的維度, $d_h$
self.qk_nope_head_dim = args.qk_nope_head_dim
# 對應$d_h^R$, 表示應用了rope的 queries 和 key 的一個 head 的維度。
self.qk_rope_head_dim = args.qk_rope_head_dim
# $d_h + d_h^R$, 注意力頭大小為非rope部分大小加上rope部分大小
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
self.v_head_dim = args.v_head_dim
if self.q_lora_rank == 0:
# 不適用低秩分解,回歸到傳統MHA
self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
else:
# 其實就是$W^{DQ}$,用來生成$c_t^Q$
# 下采樣矩陣,得到壓縮后的q向量
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
# $W^{UQ}$
# 上采樣矩陣,用來恢復q向量
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
# $[W^{DKV}; W^{KR}]$
# 下采樣矩陣,得到壓縮后的kv向量
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
# 上采樣矩陣,用來恢復kv向量
# $[W^{UK}; W^{UV}]$
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim ** -0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
if attn_impl == "naive": # native模式下,kvcache存儲的是沒有壓縮的數據,大小為d_h + d_h^R, 不但沒有節省,反而增加了顯存消耗
self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
else:
# 在非native模式下,存儲的是壓縮的c,大小為d_c
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
# 計算q
if self.q_lora_rank == 0:
q = self.wq(x)
else:
q = self.wq_b(self.q_norm(self.wq_a(x)))
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
# 分離nope,rope
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# 執行RoPE計算
q_pe = apply_rotary_emb(q_pe, freqs_cis)
kv = self.wkv_a(x)
# KV-Cache大小為wkv_a outputdim(self.kv_lora_rank + self.qk_rope_head_dim)
# 分離KV和K位置編碼
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
# 執行RoPE計算
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
if attn_impl == "naive":
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(self.kv_norm(kv))
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
self.k_cache[:bsz, start_pos:end_pos] = k # 存儲的是完全沒有壓縮的k
self.v_cache[:bsz, start_pos:end_pos] = v # 存儲的是完全沒有壓縮的v
# score = q^T \times k_cache
scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
else:
# 處理KV u-pprojection矩陣
wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
# q_{nope} = q_{nope} \times W^{UK}
# q中不需要位置編碼的先和K的不需要位置編碼的權重相乘
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) # 保存KV Cache
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) # 保存K的位置編碼Cache(pe cache)
# scores = q_{nope}^T \times kv\_cache + q_{pe}^T \times pe\_cache
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
if mask is not None:
scores += mask.unsqueeze(1)
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
if attn_impl == "naive":
# score \times v_cache
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
else:
# u = W^{UV} \times scores \times kv\_cache
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
# out = W^O \times u
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x
具體比對如下圖。

0x05 優化代碼
DeepSeek代碼并沒有給出某些功能的具體方案,比如壓縮優化和權重吸收。因此,我們主要以章明星老師給出的方案 https://github.com/madsys-dev/deepseekv2-profile/tree/main DeepSeek-V2 高性能推理 (1):通過矩陣吸收十倍提速 MLA 算子為例進行學習。
5.1 壓縮優化
目前V2代碼中,Attention中的KV Cache緩存的仍然是全量的key和value(從隱向量又解壓縮出來),而并非論文中所說的壓縮后的compressed_kv以及k_pe,導致其實沒有減少KV Cache的緩存。
主要原因可能是:一方面復用transformers原有的Cache邏輯,方便實驗和理解;另一方面這部分應該是訓練代碼,而推理代碼會針對這部分進行優化和改進。
我們可以做如下修改,也將RoPE后的k_pe一并緩存入KV Cache中。
# 將當前位置之前的壓縮后的kv(c_t^{kv})以及應用過rope的k的部分拼接到KV Cache前面
if past_key_value is not None:
# 得到的應該是
# compressed_kv: [B, kv_seq_len, d_c]
# k_pe: [B, 1, kv_seq_len, qk_rope_head_dim]
compressed_kv, k_pe = past_key_value.update(compressed_kv, k_pe)
章明星老師給出了更詳盡的方案。
# CacheCompressed
def forward(self, hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):
...
kv_seq_len = compressed_kv.size(1)
# 對應完整公式的 44 行反過來
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = self.kv_b_proj(compressed_kv) \
.view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) \
.transpose(1, 2)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
...
def compress_kv(self, hidden_states_kv: torch.Tensor, kv_position_ids: torch.LongTensor) -> torch.Tensor:
# return the RoPE'ed & compressed kv
bsz, kv_seq_len, _ = hidden_states_kv.size()
compressed_kv = self.kv_a_proj_with_mqa(hidden_states_kv)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(k_pe)
k_pe = apply_rotary_pos_emb(k_pe, cos, sin, kv_position_ids).view(bsz, kv_seq_len, self.qk_rope_head_dim)
return torch.cat([compressed_kv, k_pe],dim=-1)
5.2 權重吸收
在計算MLA的時候,仍然需要存儲解壓后的完整的KV Cache,這很可能引起OOM崩潰。DeepSeek-V2的論文中提出,可以將KV的解壓縮矩陣吸收到Q-projection和Out-projection中,從而可以在不解壓縮KV Cache的情況下直接計算最終的Attention結果。
實際上,把權重吸收理解成矩陣乘法交換律更合適。因為實際上是提前將兩個參數矩陣乘起來,即把 \((W^{UQ})^TW^{UK}\) 的計算結果做為新的參數矩陣,然后再跟中間張量乘,在性能上不一定比分開計算更好。
下圖分別給出了MHA、MLA和權重吸收的MLA的計算示例。最右側的兩個虛線箭頭,顯示了在優化的計算過程中,哪些參數矩陣被交換了位置。它們能交換的原因,就是從數學上這樣修改是等價的(矩陣乘法交換律)。此時,輸入注意力機制的 q、k、v 形狀發生了明顯的變化。q 的形狀由 $$[n_h \times (d_h+d_h^R)]$$ 變化成了 $$[n_h \times (d_c+d_h^R)]$$,k 的形狀由 \([n_h \times (d_h + d_h^R)]\) 變化成了 \([n_h \times (d_c + d_h^R)]\),v 的形狀由 \(d_h\) 變化成了 \(d_c\)。這樣一來,新的計算過程中只剩下 ① Latent KV 了。原來的 ② KV 就不存在了,變成可以用Latent KV表示。而且實際上 V 也不存在了,因為 V 就是 K 的前 512 維。這其實就是MQA,這實際上就是 FlashMLA 代碼庫解決的問題。

我們接下來依據章老師的代碼和文字來繼續學習。
5.2.1 absorbed_cache_compressed.py
與論文不同,此處將代碼中 kv_b_proj 中屬于 K 的部分權重(論文中對應\(W^{UK}\))吸收進 q_nope(論文中對應 \(q^C\),而且是在運行時做,非提前吸收);將代碼中 kv_b_proj 中屬于 V 的部分權重(論文中對應\(W^{UV}\))吸收進 attn_out。抽象一點的理解就是,將 Q 也映射到 KV 的低秩空間,然后在低秩空間做完整的 Attention,之后再映射回 Q 的原始空間。
\(W^{UK}\)
對于K的吸收,在注意力分數的計算公式中,非RoPE部分可以做如下展開:
也就是說,我們事實上不需要每次都將低維的\(c_t^{KV}\)展開為\(k_t\)再計算,而是通過矩陣乘法結合律,直接將 \(W^{UK}\) 通過結合律先和左邊做乘法,改為計算,避免了解壓縮出完整的K矩陣。即將前三者進行計算:
此外,在原始版本的解壓縮的過程中,由于每個token的key都需要與\(W^{UK}\)相乘才能得到,因此計算量較大;矩陣吸收后,\(W^{UK}\)只需要對\(q_t^C\)這一個向量相乘,也大大減少了浮點計算量。
# Absorbed_CacheCompressed
def forward(hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):
...
# 從 kv_b_proj 中分離的 W^{UK} 和 W^{UV} 兩部分,他們要分別在不同的地方吸收
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim,:]
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
cos, sin = self.rotary_emb(q_pe)
q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)
qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, qk_head_dim)
# 此處改變了q_nope的計算順序,把 W^{UK} 吸收到 W^{UQ}
query_states[:, :, :, : self.kv_lora_rank] = torch.einsum('hdc,bhid->bhic', q_absorb, q_nope)
query_states[:, :, :, self.kv_lora_rank :] = q_pe
...
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q_nope.dtype)
# 此處改變了attn_output的計算順序
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.einsum('bhqc,hdc->bhqd', attn_output, out_absorb)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
除了壓縮KV Cache之外,我們還可以觀察到上面涉及到的2個矩陣乘法實際上都來到了計算密集的領域,例如對于 torch.einsum('bhqc,blc->bhql', q_nope, compressed_kv) 。由于不同 head 的 q_nope 部分共享了共同的 compressed_kv 部分,實際計算的是 batch_size 個 [head_num * q_len, kv_lora_rank] 和 [past_len, kv_lora_rank] 的矩陣乘法。計算等價于一個 MQA 操作,計算強度正比于 head_num 的也就是 128。因此相比 MHA,吸收后的 MLA 計算強度要大得多,可以更加充分的利用 GPU 算力。
\(W^{UV}\)
對于V的吸收,情況稍微復雜。為表述的清楚性,我們采用Einstein求和約定描述該過程
v_t = einsum('hdc,blc->blhd', W_UV, c_t_KV) # (1)
o = einsum('bqhl,blhd->bqhd', a, v_t) # (2)
u = einsum('hdD,bhqd->bhD', W_o, o) # (3)
# 將上述三式合并,得到總的計算過程
u = einsum('hdc,blc,bqhl,hdD->bhD', W_UV, c_t_KV, a, W_o)
# 利用結合律改變計算順序
o_ = einsum('bhql,blc->bhqc', a, c_t_KV) # (4)
o = einsum('bhqc,hdc->bhqd', o_, W_UV) # (5)
u = einsum('hdD,bhqd->bhD', W_o, o) # (6)
5.2.2 Move Elision
不過,這樣還不能完全發揮出MLA的威力。在原始代碼中,query_states和key_states會通過拼接RoPE和非RoPE部分得到:
def forward(...):
...
# 更新和拼接歷史 KVCache,可以看到這里存儲的是展開后的 MHA KVCache
# 其中 q_head_dim 等于 qk_nope_head_dim + qk_rope_head_dim
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = k_pe.new_empty(bsz, self.num_heads, kv_seq_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
...
當我們采取了上述優化后,此處的拼接過程會產生大量無用的數據拷貝和廣播,同時也會占用大量顯存空間導致OOM,而且如果是concat放在框架做,但可能會增加IO,尤其是decode本就是IO瓶頸。而且,先對Latent解壓縮再計算,則Attn的計算是一個實打實的Multi Head Attention,會增大計算量。
為此,我們采用MoveElision優化策略,即省略此處的拼接RoPE部分和非RoPE部分的過程,而是直接分別計算量部分的Attention Score并相加(考慮\(q_t^\top k_t = {q_t^C}^\top k_t^C + {q_t^R}^\top k_t^R\))。即,將 RoPE 部分與 NoPE 部分分別做乘法,然后進行拼接的操作,改為 NoPE 部分 Attention 和 RoPE 部分 Attention 兩者結果相加,這樣做的好處在于節省了內存搬運操作,這種做法等效于ALiBi。我們具體推導如下。
具體對應下面代碼中的torch.matmul(q_pe, k_pe.transpose(2, 3))這行。即,分開計算了RoPE部分的q和k的注意力計算再求和。標準實現是將加上了 rope 的 q_pe/k_pe 和沒加 rope 的 q_nope/k_nope 拼接起來一起。
# Absorbed_CacheCompressed_MoveElision
def forward(...):
...
# qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim
# query_states = k_pe.new_empty(bsz, self.num_heads, q_len, qk_head_dim)
# query_states[:, :, :, : self.kv_lora_rank] = torch.einsum('hdc,bhid->bhic', q_absorb, q_nope)
# query_states[:, :, :, self.kv_lora_rank :] = q_pe
# key_states = k_pe.new_empty(bsz, self.num_heads, kv_seq_len, qk_head_dim)
# key_states[:, :, :, : self.kv_lora_rank] = compressed_kv.unsqueeze(1)
# key_states[:, :, :, self.kv_lora_rank :] = k_pe
# attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
# 吸收后 attn_weights 直接基于 compressed_kv 計算不用展開
attn_weights = torch.matmul(q_pe, k_pe.transpose(2, 3)) + torch.einsum('bhqc,blc->bhql', q_nope, compressed_kv)
attn_weights *= self.softmax_scale
...
代碼比對如下:

5.2.3 Materializing Projection Matrices
DeepSeek-V2的論文中說:

不過,似乎并沒有必要再改變順序,對模型參數進行預處理,將\(W^{UK}\)與\(W^{UQ}\)相乘,以及將\(W^{UV}\)與\(W^O\)相乘。這是因為,\(W^{UK}\)與\(W^{UQ}\)相乘后的結果可以視為\(H\)個大小為\(1536 \times 512\)的低秩(不超過128)矩陣,而\(W^{UV}\)與\(W^O\)相乘的結果可以視為\(H\)個大小為\(5120 \times 512\)的低秩矩陣。相比用這些特別大的低秩矩陣做投影,明顯不如按照低秩分解形式依次相乘來得劃算。因此,章老師認為這一步的優化并不是很有必要。
因為假設有矩陣 A[m,k],B[k,n],C[n,l],B 和 C 為低秩矩陣,依次相乘 A?B?C 需要的算力: 2mkn+2mnl=2mn?(k+l),而提前合并 D=(B?C),A?D 需要的算力:2mkl,當 n?(k+l)<kl 時,提前合并低秩矩陣,反而會引入更多計算。而在 LoRA 的推理階段,之所以能這樣做,是因為本身就已經存在一個大的 pre-train weight 的矩陣,因此提前做吸收,不會增加計算量。
具體代碼如下:
def forward(self, hidden_states_q: torch.Tensor, q_position_ids: torch.LongTensor, compressed_kv: torch.Tensor):
'''
Attention masks and past cache are removed.
Input:
- hidden_states_q: [bsz, q_len, hidden_size]
- compressed_kv: [bsz, kv_len, kv_lora_rank]
- position_ids: [bsz, q_len]
'''
bsz, q_len, _ = hidden_states_q.size()
q_b_proj_rope, q_absorbed, out_absorbed = self.get_absorbed_proj()
q = self.q_a_layernorm(self.q_a_proj(hidden_states_q))
q_nope = torch.einsum('bqc,hdc->bhqd', q, q_absorbed)
q_pe = torch.einsum('bqc,hdc->bhqd', q, q_b_proj_rope)
cos, sin = self.rotary_emb(q_pe)
q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids)
kv_seq_len = compressed_kv.size(1)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, 1, kv_seq_len, self.qk_rope_head_dim)
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * elf.softmax_scale
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q_nope.dtype)
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.einsum('bhqc,dhc->bqd', attn_output, out_absorbed)
return attn_output
5.3 融合算子
另外,如果針對prefill和decode階段進行不同處理,則在推理的時候Prefill 和Decode 走的邏輯不同。
-
推理的時候 Prefill 是不做矩陣吸收的(原因是Prefill做矩陣吸收會增加計算量),MLA計算與普通的MHA計算大致相同,唯一的區別在于需要支持q/k和v/o使用不同的head_dim。
-
Decode 是要做矩陣吸收的,矩陣吸收ops 遠小于矩陣不吸收。這是因為此時Q的長度是1,原本重復在KV 上做up projection的操作轉移到了Q 上,讓Q 投影到kv 的latent space 上,Q的長度遠小于KV的長度,不需要對KV做重復做up projection。或者說,MLA的主要思路就是通過交換矩陣計算順序,利用decode階段query seq_len比較小的特點,優化矩陣計算開銷,進而達到只存儲Multi-head attention中hidden states cache,而不是key和value兩個cache,進而降低一半KVCache存儲的目的。
因此Decode階段需要單獨設計高效的融合算子,以便高效地與低秩kv-cache進行attention計算。
權重吸收之后,公式如下:
可以用代碼描述如下,即可以設計一個MQA算子來實現。
q_pe = W_QR(c_q)
q_nope = W_UQ_UK(c_q)
output = W_UV_O(MQA(q_pe, q_nope, c_kv, k_pe))
FlashAttention最初設計的初衷是減少對softmax矩陣儲存的開銷,其大小正比于 \(l_q \cdot l_{kv}\),占整體I/O的比值為:
對于推理階段而言,\(l_q\) 其實是非常小的,不融合qk和pv兩階段的計算也能取得不錯的效果。但是對于MLA而言,融合是必要的,這是因為:
- MLA有較大的group ratio: \(??_{????}/??_{????}=128\) ,會增大softmax的占比。
- MLA復用了key和value矩陣,因此如果我們不融合兩階段的話,前后兩個算子將各自訪問一遍KV-Cache,如果硬件的cache不夠大的話,帶寬利用率將無法超過50%。
5.4 矩陣乘的重排序(增補@2025-04-19)
內容參考:DeepSeek V3推理: MLA與MOE解析 Arthur
具體特點如下:
- 方案來源:SGlang,應用于DeepSeek-V2。
- 方案特點:基于矩陣乘法結合律改變計算順序,從而優化注意力機制計算效率。在解碼階段,能夠有效減少計算量。
- 方案內容:
- 原始計算順序:\(q_{nope}k_{nope} + q_{rope}k_{rope}\)。其中\(q_{nope}k_{nope}\)的計算方式是\(q^T_{nope}(W^{UK}c)\)。FLOPs為 \(2d_c -1)hdn_k + (2d-1)hn_qn_k\)。
- 改進順序為:\((q^T_{nope}W^{UK})c\)。FLOPs為\((2d-1)hn_qd_c+(2d_c-1)hn_qn_k\)。
這種改變利用了矩陣乘法的結合律,使得計算可以在不同的維度上進行重組,在解碼階段(\(n_q=1\) ),優化后的方法可以顯著減少計算量。

0x06 轉換
6.1 GQA
Group Query Attention(GQA)是MHA的一種變體,旨在減少KV緩存的開銷。它將查詢頭分成多個組,每個組共享一個鍵和值對。這種方法通過減少鍵和值頭的數量來降低KV緩存的大小,但可能會犧牲模型的表達能力??梢詫QA看作是MLA的一種特例。由于GQA是通過復制產生的,而MLA不受這種限制,表達能力更強。
盡管MLA在Deepseek V2/V3/R1中已經證明了其效率和有效性,但許多主要的模型提供商仍然依賴GQA。為了促進MLA的更廣泛應用,論文“TransMLA: Multi-Head Latent Attention Is All You Need"提出了TransMLA,這是一種后訓練方法,可以將廣泛使用的基于GQA的預訓練模型(例如LLaMA、Qwen、Mixtral)轉換為基于MLA的模型。轉換后,模型可以進行額外的訓練以增強表達能力,而不會增加KV緩存的大小。
6.1.1 思路
論文首先證明了對于相同的KV緩存開銷,MLA的表達能力總是大于GQA。具體來說,任何GQA配置都可以等價地轉換為MLA表示,但反之不然。這一結論為將基于GQA的模型轉換為基于MLA的模型提供了理論基礎。
在等價轉換過程中,TransMLA方法首先將GQA中的鍵矩陣進行復制,以匹配查詢頭的數量。然后,它將這個復制后的鍵矩陣分解為兩個較小矩陣的乘積,從而得到MLA中的低秩表示。通過這種方法,TransMLA可以在不增加KV緩存大小的情況下,將基于GQA的模型轉換為基于MLA的模型。
6.1.2 方案
第一步是復制key矩陣,以匹配查詢頭的數量。在GQA中,為使標準多頭注意力計算時,??和??(以及??)具有相同數量的頭,需要對??進行擴展,從\(n_k\)個頭擴展到\(n_q\)個頭。這其實也有兩種方法。
- 定義復制因子\(??=\frac{??_??}{??_??}\)(\(??_??\)為??的頭數,\(??_??\)為??的頭數),將??按列劃分為\(??_??\)個塊\(??^{(??)}\),通過將每個\(??^{(??)}\)復制??次并連接,得到擴展矩陣??′。具體見下圖(a)。
- 另一種方法是將復制操作移到參數側(其實也是使用MHA替代GQA的方法),即在計算K之前,先復制投影矩陣\(W_K\)。先將\(??_??\)按列拆分為\(??_??\)個部分\(??_??^{(??)}\),然后復制每個\(??_??^{(??)}\) ??次并連接,形成新的投影矩陣\(??'_??\),再應用\(??'_??\)到??直接得到\(??′=????′_??\),此方法與先計算??再復制其頭在數學上是等效的。具體見下圖(b)。

由于\(??'_??\)由復制\(??_??\)形成,其自由度最多為\(??_????_?\),因此它的秩最多為\(??_????_?\)。為了更正式地理解這一點,通過奇異值分解(SVD)對\(??'_??\)進行分解:\(??'_??=??_????_????_??^?\) ,其中\(??_??\)和\(??_??\)是??×??正交矩陣,\(??_??\)是包含奇異值的??×??對角矩陣。只有前\(n_kd_h\)(或更少)的奇異值可能是非零的。因此,可以截斷SVD,只保留前 r 個奇異值,其中$ r \le n_kd_h\(。則\)??'_??=??_??????_????\(且\)??′=????_??????_????$ 。這樣就將GQA的“重復KV”方案解釋為類似MLA的低秩分解形式,在實際緩存時,僅需存儲低秩表示\(????_??^??\),在注意力計算時通過乘以\(??_??^??\)恢復完整維度,增強了模型的表現力。

6.2 MHA
如何使原本為 MHA 訓練的 LLMs(如 Llama)快速適應 MLA 進行推理,而無需從頭開始預訓練,既具有意義又充滿挑戰。論文“Towards Economical Inference: Enabling DeepSeek’s Multi-Head Latent Attention in Any Transformer-based LLMs” 第一種數據高效的微調方法MHA2MLA,用于*從MHA轉換到MLA。該方法包含兩個關鍵組件:
-
對于partial-RoPE,論文從對注意力分數貢獻較小的查詢和鍵的維度中去除 RoPE。
-
對于低秩近似,論文基于鍵和值的預訓練參數引入聯合SVD近似。
這些精心設計的策略使 MHA2MLA 僅使用極少部分(3‰至 6‰)的數據就能恢復性能,顯著降低推理成本,同時能與 KV 緩存量化等壓縮技術無縫集成。

6.2.1 partial-RoPE
為實現從標準 MHA 到 MLA 的遷移,論文提出 partial-RoPE 微調策略,從目標比例的維度中去除 RoPE 并轉換為 NoPE。
MHA
MHA 的 Full-RoPE 通過特定頻率的旋轉將位置信息編碼到查詢和鍵中,具體如下圖所示。

拆解
MLA中,\(k_i\)由\([k_{i,nope};k_{i,rope}]\)組成,所以我們首先需要把MHA的\(k_{i,rope}\)也分解成這樣的無RoPE編碼和有RoPE兩部分。
DeepSeek的MLA里面其實是在原始的每個head的不使用RoPE編碼\(d_h\)維度上,再增加一個使用RoPE編碼的\(d_h^R\)維度。但是我們現在只能把全長為\(d_h\)維度的\(k_{i,rope}\)進行拆解,把里面\(d_r,dr \ll d_h\)部分做RoPE編碼。也就是\(r=\frac{d_r}{2}\)長度的2D子空間做旋轉編碼。
在注意力計算中,并非所有維度上的旋轉位置編碼(RoPE)都對結果有同等的貢獻。Partial-RoPE 技術通過去除對結果貢獻較小的維度上的 RoPE,減少了冗余計算。這就像是在一場考試中,抓住重點知識進行復習,避免在一些無關緊要的知識點上浪費時間。通過這種方式,Partial-RoPE 技術在不影響模型性能的前提下,有效提升了計算效率。
在從 Full-RoPE 轉換到 Partial-RoPE 時,我們選擇哪一部分子空間來做旋轉編碼呢?論文提出四種策略(主要是依據旋轉的頻率)來旋轉 RoPE 編碼的子空間。
- 高頻保留:保留 r 個旋轉最快(高頻)的子空間,即位置最靠前的個2D子空間。
- 低頻保留:保留 r 個旋轉最慢(低頻)的子空間。
- 均勻采樣:選擇間隔相等的 r 個子空間,即不管是高頻還是低頻,按照等距離采樣,這樣高低頻都分別有一部分。
- 根據每個頭2-norm貢獻選擇(Head-wise 2-norm Contribution):根據每個頭中各子空間的 2-norm分數對所有子空間進行排序,選擇前 r 個。第 r 個頻率子空間對最終的attention logits的貢獻有上界。

選擇好了\(d_h\)維度中的\(d_r\)維度做RoPE位置編碼,剩下的\(d_h - d_r\)部分我們就要當成當成MLA中的無位置編碼部分,也就是\(q_{nope}\)。但是要注意DeepSeek的MLA中這部分維度是\(d_h\),我們這里是\(d_h - d_r\)。
6.2.2 低秩近似
MHA中的\(k_i = W_kx_i,v_i=W_vx_i\)。我們已經使用上面的四種方法之一找到了需要做RoPE的部分,也就可以把\(W_k\)對應的部分取出來得到\(W^{KR}\)。
我們也把\(W_k\)中對應非RoPE的部分參數提取出來:
我們的目標是從\(W_{k,nope},W_{v,nope}\)中構造出MLA中的\(W^{DKV}\)。
從 Full RoPE 轉換到 Partial RoPE 后,為得到 MLA 中 KV 緩存的第二個組件\(c_{i,kv}\),論文提出兩種基于SVD的策略:解耦 SVD和聯合 SVD,具體參見下圖。
- 解耦 SVD(\(SVD_{split}\)):分別對\(W_{k,nope}\)和\(W_n\)進行截斷 SVD 分解,分配\(d_{kv}/2\)個維度給每個矩陣。
- 聯合 SVD(\(SVD_{joint}\)):為保留\(K_{nope}\)和V之間的交互,對連接矩陣\([W_{k,nope},W_v]\)進行聯合分解。這種分解方式更加貼合MLA的標準格式。

到這里,我們就處理完了key和value部分。query部分并不像DeepSeek里面的MLA一樣再做低秩分解,而是把得到的query對應key中的nope和rope部分也分解成兩部分。
0xFF 參考
DP MLA For DeepSeek In Sglang 是小肖啊
DeepSeek V3, R1, Janus-Pro系列模型方法解讀 榴蓮酥
【LLM算法】MLA 技術在 DeepSeek-R1 大顯神通,清華 TransMLA 將 GQA 一鍵轉換成 MLA SmartMindAI
首個參數高效微調框架:在任何LLMs中使用DeepSeek的MLA AcademicDaily00 [AcademicDaily](javascript:void(0)??
【LLM算法】MLA 技術在 DeepSeek-R1 大顯神通,清華 TransMLA 將 GQA 一鍵轉換成 MLA SmartMindAI
DeepSeekV2之MLA(Multi-head Latent Attention)詳解 一滴水的使命
DeepSeek模型解讀:Scaling Law,MLA,MoE JMXGODLZ
還在用MHA?MLA來了DeepSeek-v2的MLA的總結和思考 rainbow
一文通透DeepSeek-V2(改造Transformer的中文模型):詳解MoE、GRPO、MLA v_JULY_v
DeepSeekV2之MLA(Multi-head Latent Attention)詳解 一滴水的使命
大模型KV Cache節省神器MLA學習筆記(包含推理時的矩陣吸收分析) BBuf
用PyTorch從零開始編寫DeepSeek-V2 Deephub
圖解Mixtral 8 * 7b推理優化原理與源碼實現 猛猿
從MHA到MLA看Attention優化:談談DeepSeek拼多多級的推理價格 扎波特的橡皮擦 [zartbot](javascript:void(0)??
繼續談談MLA以及DeepSeek-MoE和SnowFlake Dense-MoE 扎波特的橡皮擦 [zartbot](javascript:void(0)??
關于 MHLA(Multi-Head Latent Attention)的一些分析 Zhengxiao Du
[LLM底座] 關于DeepSeek-V2中的MLA(含代碼) 莫冉
如何看待 DeepSeek 發布的 MoE 大模型 DeepSeek-V2? 鄭華濱
緩存與效果的極限拉扯:從MHA、MQA、GQA到MLA 蘇劍林
DeepSeek-V2 高性能推理 (1):通過矩陣吸收十倍提速 MLA 算子 ZHANG Mingxing
速讀 deepseek v2(一) —— 理解MLA Bruce 仗劍走天涯
還在用MHA?MLA來了DeepSeek-v2的MLA的總結和思考 rainbow
如何看待 DeepSeek 發布的 MoE 大模型 DeepSeek-V2? - 知乎 (zhihu.com)
Deepseek-V2技術報告解讀!全網最細! (qq.com) [包包算法筆記](javascript:void(0)?? 2
DeepSeek-V2高性能推理優化筆記:MLA優化 madsys-dev
LLM 加速技巧:Muti Query Attention deephub
大模型基礎|注意力機制|MHA|稀疏|MQA|GQA 養生的控制人
Attention優化:Flash Attn和Paged Attn,MQA以及GQA miangangzhen
大模型輕量級微調(LoRA):訓練速度、顯存占用分析 絕密伏擊
MLKV:跨層 KV Cache 共享,降低內存占用 AI閑談
繼續談談MLA以及DeepSeek-MoE和SnowFlake Dense-MoE 扎波特的橡皮擦 [zartbot](javascript:void(0)??
【深度學習】DeepSeek核心架構-MLA:剖析低秩聯合壓縮優化KV緩存、提升推理效率的技術細節 趙南夏 [南夏的算法驛站](javascript:void(0)??
DeepSeek-R1模型架構深度解讀(二)MLA [AI算法之道](javascript:void(0)??
SGLang DP MLA 特性解讀 BBuf [GiantPandaCV](javascript:void(0)??
【LLM論文詳解】MLA 技術在 DeepSeek-R1 大顯神通,清華 TransMLA 將 GQA 一鍵轉換成 MLA AI-PaperDaily [AI-PaperDaily](javascript:void(0)??
TransMLA: Multi-Head Latent Attention Is All You Need
SGLang DP MLA 特性解讀 BBuf [GiantPandaCV](javascript:void(0)??
從代碼角度學習和徹底理解 DeepSeek MLA 算法 chaofa用代碼打點醬油
全網最細!DeepSeekMLA 多頭隱變量注意力:從算法原理到代碼實現 懂點AI事兒
deepseek技術解讀(1)-徹底理解MLA(Multi-Head Latent Attention) 姜富春
[代碼學習]deepseek-v2的inference code學習-MLA-part 1 迷途小書僮
[代碼學習]deepseek-v2的inference code學習-MLA -part 3 迷途小書僮
[代碼學習]deepseek-v2的inference code學習-MLA -part 4 迷途小書僮
[代碼學習]deepseek-v2的inference code學習-MLA -part 2 迷途小書僮
緩存與效果的極限拉扯:從MHA、MQA、GQA到MLA 蘇劍林
DeepSeek開源FlashMLA之際從原理到代碼詳解MLA 杜凌霄 [探知軒](javascript:void(0)??
首個參數高效微調框架:在任何LLMs中使用DeepSeek的MLA [AcademicDaily](javascript:void(0)??
如何把預訓練好的模型中的MHA變為MLA? 杜凌霄 [探知軒](javascript:void(0)??
終于把 deepseek 中的多頭潛在注意力機制搞懂了!! 程序員小寒 [程序員學長]
DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model
DeepSeek-V2 高性能推理 (1):通過矩陣吸收十倍提速 MLA 算子
細說DeepSeek MLA矩陣消融 formath 2025-02-24
DP MLA For DeepSeek In Sglang 是小肖啊
SGLang DP MLA 特性解讀 BBuf
DeepSeek V2/V3中的MLA和Matrix Absorption ariesjzj
FlashInfer中DeepSeek MLA的內核設計 yzh119
終于把 deepseek 中的多頭潛在注意力機制搞懂了??! 程序員小寒 [程序員學長](javascript:void(0)??
DeepSeek 開源周第一天開源的項目 FlashMLA,有哪些亮點值得關注? SIY.Z
大模型KV Cache節省神器MLA學習筆記(包含推理時的矩陣吸收分析)
DeepSeek V2 “多頭潛在注意力”論文解讀 (上) 大模型咖啡時間
Deepseek MLA 一定要做吸收嗎? 代碼搬運工
DeepSeek V3推理: MLA與MOE解析 Arthur
DeepSeek MLA引發的一些記憶碎片 YyWangCS
[Deepseek v3技術報告學習] 1.MLA Duludulu
attention中的concat能不能換成相加? Zhai Feiyue
sglang mla 代碼解析 hcy
SGLang MLA 實現解析 BBuf
DeepSeek V3推理: MLA與MOE解析 Arthur
理解 FlashMLA 在 DeepSeek MLA 計算過程中的位置和作用 solrex [邊際效應]
MLA 吸收之謎 拉航母的小朱
DeepSeek-V3/R1推理效率分析(v0.17) zartbot
DeepSeek V3/R1 推理效率分析(2): DeepSeek 滿血版逆向工程分析 Han Shen
DeepSeek V3/R1 推理效率分析(3):Decode 配置泛化討論 Han Shen
DeepSeek V3/R1 推理效率分析(1):關于DeepSeek V3/R1 Decoding吞吐極限的一些不負責任估計 Han Shen
MoE Inference On AnyScale MoE-On-AnyScale
基于 chunked prefill 理解 prefill 和 decode 的計算特性 Chayenne Zhao
LLM PD 分離背后的架構問題 極客博哥
deepseek MLA推理優化 屈屈臣氏
prefill 和 decode 該分離到不同的卡上么? Chayenne Zhao
[1. deepseek模型學習筆記?](https://developnotes.readthedocs.io/zh-cn/latest/deepseek.html#id1) 李偉華
DeepSeek-V3 (671B) 模型參數量分解計算 ZihaoZhao
vLLM 深度解析:Deekseek and vLLM -1 stephenxi
DeepSeek MLA在SGLang中的推理過程及代碼實現 榴蓮酥
MHA->MQA->GQA->MLA的演進之路 假如給我一只AI
The Annotated Transformer https://nlp.seas.harvard.edu/2018/04/03/ention.html
Attention Is All You Need https://arxiv.org/pdf/1706.03762.pdf
Fast Transformer Decoding: One Write-Head is All You Need https://arxiv.org/pdf/1.02150.pdf
https://www.researchgate.net/figure/led-dot-product-self-attention-mechanism_fig1_363923096
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints https://arxiv.org/pdf/5.13245.pdf
How Attention works in Deep Learning: understanding the attention mechanism in sequence models https://theaisummer.com/ention/
A simple overview of RNN, LSTM and Attention Mechanism https://medium.com/swlh/imple-overview-of-rnn-lstm-and-attention-mechanism-9e844763d07b
淺談Transformer的初始化、參數化與標準化 https://spaces.ac.cn/archives/0
https://theaisummer.com/self-attention/ ps://theaisummer.com/self-attention/
https://zhuanlan.zhihu.com/p/626820422 https://zhuanlan.zhihu.com/p/626820422
Are Sixteen Heads Really Better than One? https://arxiv.org/pdf/5.10650.pdf
This post is all you need(上卷)——層層剝開Transformer https://zhuanlan.zhihu.com/p/420820453
The Illustrated Transformer https://jalammar.github.io/ustrated-transformer/
Multi-Query Attention is All You Need https://blog.fireworks.ai/multi
DeepSeek MLA的序列并行和張量并行 YyWangCS
DP MLA For DeepSeek In Sglang 是小肖啊
SGLang MLA 實現解析 BBuf
Multi-Head Latent Attention (MLA) 詳細介紹(來自Deepseek V3的回答) 銀翼的魔朮師
DeepSeek面試通關(1)|MLA如何讓推理效率飆升200%? 丁師兄大模型
浙公網安備 33010602011771號