MPK(Mirage Persistent Kernel)源碼筆記(1)--- 基礎(chǔ)原理
MPK(Mirage Persistent Kernel)源碼筆記(1)--- 基礎(chǔ)原理
0x00 概要
CMU 賈志豪老師團隊提出的MPK(Mirage Persistent Kernel)是依托 Mirage 編譯器生態(tài)的創(chuàng)新運行時系統(tǒng),其核心能力在于將多GPU環(huán)境下大語言模型(LLM)推理任務自動轉(zhuǎn)換為適配GPU架構(gòu)的高性能巨型內(nèi)核(megakernel)。MPK的關(guān)鍵優(yōu)勢在于將傳統(tǒng)由CPU負責的內(nèi)核調(diào)度和任務依賴管理工作轉(zhuǎn)移到GPU端,通過“長期駐留的巨型內(nèi)核(Persistent Kernel)”自主完成,同時統(tǒng)籌GPU內(nèi)部計算與跨GPU通信任務。這種設(shè)計不僅大幅削減了CPU-GPU交互帶來的內(nèi)核啟動開銷,還通過計算與通信的細粒度重疊,將推理延遲優(yōu)化至接近硬件物理極限,顯著提升推理效率。
0.1 傳統(tǒng)LLM推理框架的瓶頸
傳統(tǒng)LLM推理框架的流程存在固有瓶頸:CPU需逐個發(fā)起CUDA內(nèi)核調(diào)用(如矩陣乘法、激活函數(shù)計算),待GPU執(zhí)行完當前內(nèi)核并反饋后,再觸發(fā)下一個內(nèi)核。這種“CPU發(fā)起-GPU執(zhí)行-CPU等待”的循環(huán),會產(chǎn)生頻繁的CPU-GPU通信與內(nèi)核啟動開銷,尤其在自回歸生成場景中,單次token生成需多輪內(nèi)核調(diào)用,開銷會持續(xù)累積,嚴重拖累整體推理性能。
0.2 MPK的流程重構(gòu)
MPK徹底重構(gòu)了這一流程:它僅需CPU在推理初始化階段,向GPU提交一個“永不主動退出的persistent_kernel”,之后所有任務分派(如層間計算順序)、依賴管理(如等待前一層結(jié)果再執(zhí)行下一層)均由GPU內(nèi)部自主完成。此時CPU的角色從“實時調(diào)度的包工頭”轉(zhuǎn)變?yōu)椤皢映跏蓟拈T衛(wèi)”,僅負責觸發(fā)首次內(nèi)核啟動,后續(xù)不再參與任何具體調(diào)度。
0.3 MPK的關(guān)鍵優(yōu)勢
MPK通過將多GPU的LLM推理任務轉(zhuǎn)換為高性能的巨型內(nèi)核,從根本上改變了GPU的運行模式。它不僅減少了內(nèi)核啟動開銷,還通過細粒度的軟件流水線和計算通信重疊,顯著提高了推理效率。MPK提供了一種全新的思路,將性能優(yōu)化的重心從“如何調(diào)用優(yōu)化庫”轉(zhuǎn)移到了“如何為整個模型生成一個最優(yōu)的、原生的執(zhí)行體”,在多GPU環(huán)境下實現(xiàn)了更高的吞吐量和更低的延遲。
0x01 問題
重新設(shè)計類似 Mirage 的 MegaKernel的優(yōu)勢,是將所有計算和通信融合進一個單一的巨型內(nèi)核(也稱為持續(xù)內(nèi)核)是降低大語言模型推理延遲的最有效方法之一。這種方法通過啟動一個GPU內(nèi)核來執(zhí)行整個模型,從逐層計算到GPU間通信,整個過程無需中斷。盡管有這些優(yōu)勢,將LLM編譯成巨型內(nèi)核仍然極具挑戰(zhàn)性。
1.1 現(xiàn)有框架問題
現(xiàn)有框架難以支持單一的巨型內(nèi)核。
-
現(xiàn)有的高級ML框架,如PyTorch、Triton和TVM,并不原生支持端到端巨型內(nèi)核生成。
-
現(xiàn)代LLM系統(tǒng)由各種不同的專用內(nèi)核庫構(gòu)建而成,這種碎片化使得將整個推理流水線整合進一個單一的、統(tǒng)一的內(nèi)核變得非常困難。
-
高性能GPU內(nèi)核的手工編寫需要大量的專家知識,如何自動生成高性能內(nèi)核代碼是一個痛點問題。傳統(tǒng)做法依賴于專家編寫好的內(nèi)核或者手工融合規(guī)則,但這些方法維護成本高,容易漏掉跨內(nèi)核/層級組合優(yōu)化的機會。
1.2 編程抽象層級
從編程抽象層級上來看,也缺乏最優(yōu)系統(tǒng)。
1.2.1 GPU架構(gòu)
下圖展示了當今 GPU 的層次結(jié)構(gòu)。GPU 上的計算被組織為內(nèi)核,每個內(nèi)核都是一個函數(shù),以單程序多數(shù)據(jù)(SPMD)的方式在多個 GPU 核心上同時執(zhí)行。一個內(nèi)核包括一個線程塊網(wǎng)格,每個線程塊在一個 GPU 流式多處理器上執(zhí)行,并包括多個線程來對單個數(shù)據(jù)元素進行計算。每個線程都與一個每線程寄存器文件相關(guān)聯(lián),并且線程塊內(nèi)的所有線程都可以訪問共享內(nèi)存以啟用集體操作。最后,內(nèi)核的所有輸入和輸出都存儲在 GPU 設(shè)備內(nèi)存中。
下圖是GPU hierarchy。

下圖為GPU 計算架構(gòu)和編程抽象示意圖

1.2.2 編程視角
Triton 是一款高級 GPU 編程框架,其編程視角主要聚焦于塊(Block)級別。該框架的設(shè)計允許開發(fā)者以塊為單位進行編程,而塊內(nèi)部的優(yōu)化工作則由 Triton 編譯器自動完成。這種設(shè)計模式使開發(fā)者能夠?qū)⒕性诟邔舆壿嫷臉?gòu)建上,無需深入研究線程(Thread)級別的細節(jié)實現(xiàn)。Triton 的核心優(yōu)勢在于其簡潔的編程模型和自動化優(yōu)化能力,這使得它在處理復雜并行任務時具有更高的效率。
Cutlass 則屬于底層 GPU 編程庫,其編程視角覆蓋了塊(Block)、線程束(Warp)與線程(Thread)的完整層級。Cutlass 提供了豐富的 CUDA 模板和底層控制接口,開發(fā)者可以利用這些工具精細調(diào)控每個線程的行為,從而實現(xiàn)高度優(yōu)化的計算內(nèi)核。這種細粒度的控制能力讓 Cutlass 在對性能有極致要求的場景中表現(xiàn)出色,但同時也增加了編程的復雜性。
正是這種編程視角的層級差異,構(gòu)成了當前高性能 GPU 編程領(lǐng)域的核心挑戰(zhàn):缺乏一套能夠 “跨內(nèi)核(Kernel)、線程塊(Block)、線程(Thread)三個層級” 聯(lián)合搜索最優(yōu)計算方案,并自動驗證方案正確性的系統(tǒng)。現(xiàn)有框架要么局限于單一層級的優(yōu)化(例如 Triton 僅針對塊內(nèi)部邏輯進行優(yōu)化,而 Cutlass 則需要開發(fā)者手動協(xié)調(diào)全層級的適配),要么無法在多層級協(xié)同后確保計算結(jié)果的準確性。這一問題在大型語言模型(LLM)推理等復雜張量計算場景中,會顯著增加開發(fā)成本與優(yōu)化難度。
0x02 總體思路
MegaKernel 可被視為一種 grid 級(網(wǎng)格級)的內(nèi)核抽象。與 CUDA 的 thread 級(線程級)抽象、Triton 的 block 級(塊級)抽象不同,它提供了層次更高的抽象能力,允許開發(fā)者在 grid 級開展編程工作。這種抽象設(shè)計能讓開發(fā)者更靈活地管理 GPU 上的計算資源,進而實現(xiàn)更高效的內(nèi)核生成與執(zhí)行。
MPK 的工作原理主要包含以下兩部分:
- MPK 編譯器:負責將大語言模型(LLM)的計算圖轉(zhuǎn)換為經(jīng)過優(yōu)化的任務圖。
- MPK 運行時系統(tǒng):在單個巨型內(nèi)核內(nèi)部執(zhí)行任務圖,以此達成高吞吐量與低延遲的目標。
2.1 編譯過程
- 模型翻譯:將 PyTorch 框架下的模型翻譯為 MPK 的指令集,這一步驟本質(zhì)上相當于用 MPK 的指令重新構(gòu)建模型的過程。盡管 PyTorch 具備強大的自動微分與優(yōu)化能力,但要將模型完整轉(zhuǎn)換為 MPK 的指令集,仍需進行大量手動調(diào)整與優(yōu)化操作。
- 任務圖生成:編譯器會將翻譯后的模型進一步轉(zhuǎn)換為細粒度任務圖。該任務圖屬于有向無環(huán)圖(DAG),圖中每個節(jié)點代表一項具體任務,節(jié)點間的邊則代表任務之間的依賴關(guān)系。這一步驟要求編譯器能夠準確識別并優(yōu)化任務間的依賴關(guān)系,為后續(xù)高效調(diào)度奠定基礎(chǔ)。
2.2 執(zhí)行過程
- 任務調(diào)度:將生成的任務圖交付調(diào)度器執(zhí)行。調(diào)度器負責管理 GPU 中的流式多處理器(SM),并通過 warp specialization(線程束特化)技術(shù)將 SM 劃分為 worker(工作單元)與 scheduler(調(diào)度單元)。這種設(shè)計與數(shù)據(jù)處理領(lǐng)域的 actor 模型(角色模型)相似:scheduler 負責協(xié)調(diào)任務的執(zhí)行順序,worker 則負責具體執(zhí)行分配到的任務。
- 性能優(yōu)化:在小模型與低 Batch(批次)場景下,MPK 通過多種方式顯著降低延遲,具體包括:消除內(nèi)核啟動開銷、打破內(nèi)核邊界限制、實現(xiàn)細粒度的 SM 調(diào)度,以及對任務特定模式進行融合。
0x03 通過代碼來打通流程
我們以demo_chat.py為例來進行全局打通。在此文件中會將Python模型結(jié)構(gòu)映射為Mirage的計算圖表示,然后編譯為高效的持久化CUDA內(nèi)核執(zhí)行。
3.1 核心模塊說明
3.1.1 三層結(jié)構(gòu)化圖模型
Mirage 實現(xiàn)了多層次計算圖表示(μGraphs),通過 kernel-graph、block-graph 和 thread-graph 這三層結(jié)構(gòu)化圖模型,精確映射了 GPU 程序從內(nèi)核到線程的執(zhí)行邏輯與存儲層級。這種三層結(jié)構(gòu)與 CUDA 程序的執(zhí)行層級及 GPU 的存儲體系緊密對應,每層都清晰定義了“算子類型 - 張量存儲 - 核心功能”的關(guān)聯(lián)。
三層圖功能如下:
- Kernel Graph 是最高計算圖,定義整個執(zhí)行流程。通過自定義操作管理多個block graph
- Block Graph 是嵌套在自定義操作中,定義線程塊執(zhí)行序列
- Thread Graph是最低層,定義線程級別執(zhí)行細節(jié)
3.1.2 PersistentKernel
PersistentKernel 作為計算圖的容器和執(zhí)行器,提供了從計算圖構(gòu)建、優(yōu)化到執(zhí)行的過程。
persistent_kernel.py是 PersistentKernel的Python接口,本質(zhì)是Python到CUDA持久化內(nèi)核系統(tǒng)的橋梁,允許用戶用python定義復雜的計算圖,然后在GPU上高效執(zhí)行。
3.1.3 層級關(guān)系
計算圖與 PersistentKernel 的關(guān)系如下:
-
包含關(guān)系:PersistentKernel 內(nèi)部包含并管理一個 Kernel graph
-
構(gòu)建關(guān)系:通過 PersistentKernel 的各種layer方法構(gòu)建計算圖。
-
轉(zhuǎn)換關(guān)系:PersistentKernel 將計算圖轉(zhuǎn)換為可執(zhí)行的任務圖
-
執(zhí)行關(guān)系:PersistentKernel 是計算圖的執(zhí)行引擎。
3.1.4 數(shù)據(jù)流關(guān)系
數(shù)據(jù)流關(guān)系可以近似如下圖所示:
應用層:PersistentKernel.py(創(chuàng)建并管理kernel graph)
│
│
▼
輸入張量
│
│
▼
計算圖節(jié)點(各種layer方法添加)
│
│
▼
任務層:kernel graph(包括所有操作和計算流,即定義張量數(shù)據(jù)流)
│
│
▼
并行層:block graph(嵌套在自定義操作中,定義線程塊執(zhí)行序列,即定義內(nèi)存訪問模式)
│
│
▼
執(zhí)行層:task graph(kernel graph生成的可執(zhí)行任務圖,taskDesc是可執(zhí)行任務,EventDesc管理事件同步和依賴)
│
│
▼
運行時環(huán)境:PersistentKernel 執(zhí)行引擎
│
│
▼
硬件層:Thread graph,在實際GPU線程中執(zhí)行具體操作
3.2 main()代碼
demo_chat.py的main()如下。
def main():
world_size, rank = setup_distributed_environment()
model, tokenizer = load_model_and_tokenizer(rank)
tokens = torch.full((1, MAX_SEQ_LEN), 0, dtype=torch.long, device="cuda")
step_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")
mpk = None
if args.use_mirage:
# 構(gòu)建計算圖
mpk = build_mirage_graph(model, world_size, rank, args, tokens, step_tensor)
positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)
position_embeddings = model.model.rotary_emb(positions)
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
while True:
prompt_container = [None]
if rank == 0:
try:
prompt = input("> User: ")
prompt_container[0] = prompt
except EOFError:
prompt_container[0] = "exit"
if world_size > 1:
dist.broadcast_object_list(prompt_container, src=0)
prompt = prompt_container[0]
messages.append({"role": "user", "content": prompt})
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
new_prompt_len = model_inputs.input_ids.shape[-1]
tokens[0, :new_prompt_len] = model_inputs.input_ids[0]
if new_prompt_len < tokens.shape[1]:
tokens[0, new_prompt_len:] = 0
prompt_len = new_prompt_len
if args.use_mirage:
end_pos, run_time, generated_len = run_mirage_generation(
model, mpk, tokens, prompt_len, step_tensor, position_embeddings
)
else:
end_pos, run_time, generated_len = run_pytorch_generation(
model, tokens, prompt_len, step_tensor, position_embeddings
)
if rank == 0:
assistant_response_ids = tokens[0, prompt_len:end_pos]
assistant_response = tokenizer.decode(assistant_response_ids, skip_special_tokens=True)
if world_size > 1:
dist.destroy_process_group()
print("Exiting demo.")
總體過程如下:
- 模型定義階段
- 使用PyTorch/HuggingFace定義模型結(jié)構(gòu)。
- 加載預訓練權(quán)重
- 初始化輸入張量和相關(guān)參數(shù)。
- 任務圖構(gòu)建階段
- 通過KNOperator定義計算操作
- 構(gòu)建完整的計算圖結(jié)構(gòu)
- 設(shè)置任務配置參數(shù)
- 任務圖優(yōu)化階段
- 分析任務間的依賴關(guān)系
- 生成事件描述以管理依賴
- 對任務進行合理分組以優(yōu)化執(zhí)行
- 任務圖轉(zhuǎn)換階段
- 生成TaskDesc描述每個計算任務
- 生成EventDesc描述任務間同步事件
- 生成CUDA可執(zhí)行代碼
- 輸出JSON配置文件用于運行時加載。
- 運行時初始化階段
- 配置GPU資源(worker,調(diào)度器等)
- 分配GPU內(nèi)存給任務隊列和事件隊列
- 初始化工作隊列和調(diào)度隊列。
- 設(shè)置事件計數(shù)器和相關(guān)同步機制
- 持久化內(nèi)核運行階段
- worker執(zhí)行具體計算任務
- 調(diào)度器負責任務調(diào)度和事件管理
- 通過事件機制協(xié)調(diào)任務間的依賴關(guān)系
- 支持多GPU環(huán)境下的分布式執(zhí)行。
3.3 關(guān)鍵步驟
3.3.1 計算圖構(gòu)建過程
此處對應模型翻譯過程,即將 PyTorch 框架下的模型翻譯為 MPK 的指令集,這一步驟本質(zhì)上相當于用 MPK 的指令重新構(gòu)建模型的過程。盡管 PyTorch 具備強大的自動微分與優(yōu)化能力,但要將模型完整轉(zhuǎn)換為 MPK 的指令集,仍需進行大量手動調(diào)整與優(yōu)化操作。
模型轉(zhuǎn)換為計算圖的工作是在build_mirage_graph函數(shù)中,其主要步驟如下:
初始化持久化內(nèi)核
首先構(gòu)建PersistentKernel實例。
mpk = mi.PersistentKernel(
world_size=world_size,
mpi_rank=rank,
num_workers=96,
num_local_schedulers=48,
num_remote_schedulers=0,
max_seq_length=4096,
eos_token_id=model.config.eos_token_id,
meta_tensors=[step_tensor, tokens_tensor],
profiler_tensor=profiler_tensor,
)
定義張量
將模型權(quán)重和中間張量添加到計算圖中。
# 輸入張量
x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")
# 位置編碼
positions = torch.arange(MAX_SEQ_LEN).unsqueeze(0).to(model.device)
position_embeddings = model.model.rotary_emb(positions)
x = mpk.attach_input(torch_tensor=input_tokens, name="input_token")
cos_pos_embed = mpk.attach_input(
torch_tensor=position_embeddings[0][0, :MAX_CONTEXT_LEN, :],
name="cos_position_embedding",
)
sin_pos_embed = mpk.attach_input(
torch_tensor=position_embeddings[1][0, :MAX_CONTEXT_LEN, :],
name="sin_position_embedding",
)
# 計算圖的中間結(jié)果張量
embed_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="embed_out")
attn_in = mpk.new_tensor(dims=(batch_size, fused_outdim_1 // world_size), dtype=mi.bfloat16, name="attn_in")
attn_out = mpk.new_tensor(dims=(batch_size, num_local_q_heads * head_dim), dtype=mi.bfloat16, name="attn_out")
is_nvshmem = "nvshmem_tensor" if world_size > 1 else "cuda_tensor"
attn_proj_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_proj_out", io_category=is_nvshmem)
allreduce_buf = mpk.new_tensor(dims=(world_size, batch_size, hidden_size), dtype=mi.bfloat16, name="all_reduce_buf", io_category=is_nvshmem)
attn_allreduce_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="attn_allreduce_out", io_category=is_nvshmem)
mlp_mid = mpk.new_tensor(dims=(batch_size, fused_outdim_2 // world_size), dtype=mi.bfloat16, name="mlp_mid")
mlp_out = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_out", io_category=is_nvshmem)
mlp_final = mpk.new_tensor(dims=(batch_size, hidden_size), dtype=mi.bfloat16, name="mlp_final", io_category=is_nvshmem)
argmax_in = mpk.new_tensor(dims=(batch_size, vocab_size), dtype=mi.bfloat16, name="argmax_in")
argmax_part_value = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.bfloat16, name="argmax_part_value")
argmax_part_index = mpk.new_tensor(dims=(batch_size, 96), dtype=mi.int64, name="argmax_part_index")
argmax_out = mpk.new_tensor(dims=(batch_size, 1), dtype=mi.int64, name="argmax_out")
構(gòu)建計算層
通過調(diào)用各種layer方法將模型層添加到計算圖。此處會把HuggingFace模型權(quán)重映射到Mirage張量。也可以融合張量以提高計算效率。
# --- Define the Model Graph ---
w_embed = mpk.attach_input(torch_tensor=model.model.embed_tokens.weight, name="embed_tokens")
mpk.embed_layer(input=x, weight=w_embed, output=embed_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))
x = embed_out
for i, layer in enumerate(model.model.layers):
# Attention block
w_norm_attn = mpk.attach_input(torch_tensor=layer.input_layernorm.weight, name=f"layer_{i}_input_layernorm")
w_q = mpk.attach_input(torch_tensor=layer.self_attn.q_proj.weight, name=f"layer_{i}_q_proj")
w_k = mpk.attach_input(torch_tensor=layer.self_attn.k_proj.weight, name=f"layer_{i}_k_proj")
w_v = mpk.attach_input(torch_tensor=layer.self_attn.v_proj.weight, name=f"layer_{i}_v_proj")
w_qkv = mpk.fuse_tensors(inputs=[w_q, w_k, w_v], fused_dim=0, num_groups=num_local_kv_heads, name=f"layer_{i}_qkv_proj")
mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_attn, weight_linear=w_qkv, output=attn_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
w_q_norm = mpk.attach_input(torch_tensor=layer.self_attn.q_norm.weight, name=f"layer_{i}_q_norm")
w_k_norm = mpk.attach_input(torch_tensor=layer.self_attn.k_norm.weight, name=f"layer_{i}_k_norm")
k_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[0][i], name=f"layer_{i}_k_cache")
v_cache = mpk.attach_input(torch_tensor=model.model.kv_cache[1][i], name=f"layer_{i}_v_cache")
mpk.attention_layer(input=attn_in, q_norm=w_q_norm, k_norm=w_k_norm, k_cache=k_cache, v_cache=v_cache, cos_pos_embed=cos_pos_embed, sin_pos_embed=sin_pos_embed, output=attn_out, grid_dim=(batch_size, num_local_kv_heads, 1), block_dim=(128, 1, 1))
w_o_proj = mpk.attach_input(torch_tensor=layer.self_attn.o_proj.weight, name=f"layer_{i}_o_proj")
mpk.linear_with_residual_layer(input=attn_out, weight=w_o_proj, residual=x, output=attn_proj_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = attn_proj_out
if world_size > 1:
mpk.allreduce_layer(input=attn_proj_out, buffer=allreduce_buf, output=attn_allreduce_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = attn_allreduce_out
# MLP block
residual_mlp = x
w_norm_mlp = mpk.attach_input(torch_tensor=layer.post_attention_layernorm.weight, name=f"layer_{i}_post_attn_layernorm")
w_gate_proj = mpk.attach_input(torch_tensor=layer.mlp.gate_proj.weight, name=f"layer_{i}_gate_proj")
w_up_proj = mpk.attach_input(torch_tensor=layer.mlp.up_proj.weight, name=f"layer_{i}_up_proj")
w_gatedup = mpk.fuse_tensors(inputs=[w_gate_proj, w_up_proj], fused_dim=0, num_groups=1, name=f"layer_{i}_gatedup_proj")
mpk.rmsnorm_linear_layer(input=x, weight_norm=w_norm_mlp, weight_linear=w_gatedup, output=mlp_mid, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
w_down_proj = mpk.attach_input(torch_tensor=layer.mlp.down_proj.weight, name=f"layer_{i}_down_proj")
mpk.silu_mul_linear_with_residual_layer(input=mlp_mid, weight=w_down_proj, residual=residual_mlp, output=mlp_out, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = mlp_out
if world_size > 1:
mpk.allreduce_layer(input=mlp_out, buffer=allreduce_buf, output=mlp_final, grid_dim=(hidden_size // 64, 1, 1), block_dim=(128, 1, 1))
x = mlp_final
# Final layer
w_final_norm = mpk.attach_input(torch_tensor=model.model.norm.weight, name="model_norm_weight")
w_lm_head = mpk.attach_input(torch_tensor=lm_head_weight, name="lm_head")
mpk.rmsnorm_linear_layer(input=x, weight_norm=w_final_norm, weight_linear=w_lm_head, output=argmax_in, grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
# Argmax
mpk.argmax_partial_layer(input=argmax_in, output=(argmax_part_value, argmax_part_index), grid_dim=(96, 1, 1), block_dim=(128, 1, 1))
mpk.argmax_reduce_layer(input=(argmax_part_value, argmax_part_index), output=argmax_out, grid_dim=(1, 1, 1), block_dim=(128, 1, 1))
3.3.2 任務圖生成
此處對應任務圖生成:編譯器會將翻譯后的模型進一步轉(zhuǎn)換為細粒度任務圖。該任務圖屬于有向無環(huán)圖(DAG),圖中每個節(jié)點代表一項具體任務,節(jié)點間的邊則代表任務之間的依賴關(guān)系。這一步驟要求編譯器能夠準確識別并優(yōu)化任務間的依賴關(guān)系,為后續(xù)高效調(diào)度奠定基礎(chǔ)。
調(diào)用compile()方法生成最終的執(zhí)行圖。compile()函數(shù)內(nèi)會執(zhí)行:
- 生成任務圖。
- 創(chuàng)建CUDA代碼。
- 調(diào)用nvcc編譯器。
- 創(chuàng)建Python綁定模塊。
mpk.compile()
print("Mirage graph compiled.")
return mpk
3.3.3 runtime執(zhí)行
run_mirage_generation()函數(shù)是執(zhí)行引擎運行任務圖過程。
def run_mirage_generation(model, mpk, tokens, prompt_len, step_tensor, position_embeddings):
# 初始化CUDA事件用于計時(starter記錄開始,ender記錄結(jié)束)
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
# 創(chuàng)建CUDA流,用于管理異步計算任務的執(zhí)行順序
stream = torch.cuda.Stream()
# 預填充階段(處理輸入的prompt文本,生成初始上下文)
# 將步驟張量的值設(shè)為prompt長度減1,標記預填充階段的結(jié)束位置
step_tensor.fill_(prompt_len - 1)
# 從輸入 tokens 中截取前 prompt_len 個token作為初始輸入(即prompt部分)
input_ids = tokens[:, 0:prompt_len]
# 提取與prompt長度匹配的位置編碼(余弦部分)
cos_embeddings = position_embeddings[0][:, 0:prompt_len]
# 提取與prompt長度匹配的位置編碼(正弦部分)
sin_embeddings = position_embeddings[1][:, 0:prompt_len]
# 調(diào)用模型前向傳播,處理prompt并生成初始logits
logits = model.forward(
input_ids=input_ids, # 輸入的prompt token序列
position_embeddings=(cos_embeddings, sin_embeddings), # 對應的位置編碼
step=step_tensor, # 當前處理步驟標記
stream=stream # 使用指定的CUDA流進行計算
)
# 從logits中選取概率最大的token作為下一個生成的token(取最后一個位置的輸出)
next_token = logits.argmax(dim=-1)[0, -1]
# 將生成的第一個token寫入tokens張量的prompt_len位置,作為生成階段的起始
tokens[0, prompt_len] = next_token
# 等待CUDA流中的所有操作完成,確保預填充階段計算結(jié)果就緒
torch.cuda.synchronize()
# 為下一輪生成重新初始化持久化內(nèi)核(MPK)
# 收集元數(shù)據(jù)張量的指針地址,供內(nèi)核訪問
meta_tensors_ptr = [tensor.data_ptr() for tensor in mpk.meta_tensors]
# 獲取性能分析緩沖區(qū)的指針(若不存在則設(shè)為0)
profiler_buffer_ptr = (
mpk.profiler_tensor.data_ptr() if mpk.profiler_tensor is not None else 0
)
# 調(diào)用MPK的初始化函數(shù),配置內(nèi)核運行參數(shù)
mpk.init_func(
meta_tensors_ptr, # 元數(shù)據(jù)張量指針列表
profiler_buffer_ptr, # 性能分析緩沖區(qū)指針
mpk.mpi_rank, # 當前MPI進程的排名(分布式場景)
mpk.num_workers, # 工作單元(worker)的數(shù)量
mpk.num_local_schedulers, # 本地調(diào)度器的數(shù)量
mpk.num_remote_schedulers # 遠程調(diào)度器的數(shù)量(分布式場景)
)
# 生成階段(基于預填充的上下文,持續(xù)生成后續(xù)token)
# 將步驟張量的值設(shè)為prompt_len,標記生成階段的起始位置
step_tensor.fill_(prompt_len)
# 記錄生成階段開始時間
starter.record()
# 執(zhí)行持久化內(nèi)核,啟動生成過程
mpk()
# 記錄生成階段結(jié)束時間
ender.record()
# 等待CUDA操作完成,確保計時準確
torch.cuda.synchronize()
# 計算生成階段的運行時間(毫秒)
run_time = starter.elapsed_time(ender)
# 獲取生成結(jié)束時的位置(從步驟張量中提取具體數(shù)值)
end_pos = step_tensor[0].item()
# 計算實際生成的token長度(總長度減去prompt長度)
generated_len = end_pos - prompt_len
# 返回生成結(jié)束位置、運行時間和生成長度
return end_pos, run_time, generated_len
0xFF 參考
如何評價CMU將LLM轉(zhuǎn)化為巨型內(nèi)核的Mirage Persistent Kernel(MPK)工作?
Mirage: A Multi-Level Superoptimizer for Tensor Programs 簡記 塵伊光
OSDI2025論文筆記:Mirage: A Multi-Level Superoptimizer for Tensor Programs 畫餅充饑
Mirage: A Compiler for High-Performance Tensor Programs on GPUs
https://mirage-project.readthedocs.io/en/latest/mugraph.html
https://mirage-project.readthedocs.io/en/latest/transpiler.html
浙公網(wǎng)安備 33010602011771號