模型算法-MHA-MQA-GQA(1)
1. 介紹:
基于最近對大模型 KV_cache,及 Attention 變種學習中遇到的問題和理解記錄下來,幫助大家解決一點疑惑。
2. kv_cache 顯存對比:
參數說明
- batch_size:B
- seq_len:L
- head_num:H
- head_dim:D
- layer_num:N
- group_size:G,每組 Q_head 數量
- embedding_dim:D_em = H * D
MHA : 2 * BLHDN * sizeof(DataType)
MQA:2 * BLDN * sizeof(DataType)
GQA:2 * BLDN * (H/G) * sizeof(DataType)
3. MQA和GQA計算量沒有減少,為什么能夠加速?
- 因為頭的數量減少,WK WV矩陣參數量減少,帶來前置計算量減少。
4. MQA 多頭Q與單頭 KV 計算如何組織數據?
MQA:
- Q_mul_heads 從 (B, S, H, D) reshape 為 (B, H, S, D);
- K_head 從 (B, S, 1, D) reshape 為 (B, 1, D, S);
matmul(Q_mul_heads, K_head) = (B, H, S, S) ,matmul 將 K_head 復制 H 份與 Q_head 計算。
GQA:
- Q_mul_heads 從 (B, S, H, D) reshape 為 (B, H, S, D);
- K_head 從 (B, S, H/G, D) -> (B, S, H/G, 1, D) ,再 expand 復制最后一個維度為 (B, S, H/G, G, D), reshape 為 (B, S, H, D) 與 Q_mul_heads 大小一致, 再 reshape 為 (B, H, D, S) 可以進行 malmul 計算。

浙公網安備 33010602011771號