多模態模型——QwenVL2.5的微調以及強化學習代碼操作
持續更新:https://www.big-yellow-j.top/posts/2025/08/29/QwenVLCode.html
從代碼角度去理解QwenVL2.5是如何處理,以及結合實際操作理解如何去對一個QwenVL2.5-3B進行SFT和強化學習處理。
簡單了解一下QwenVL2.5模型的整個處理過程,模型整體過程大致為:1、首先是通過模板化處理我的模型的輸入(image+text);2、將輸入轉化為編碼形式(比如文本tokenizer處理等);3、出入模型處理輸入然后模型輸出;4、解碼輸出內容。整體主要是上述4個過程,因此下面逐一了解一下模型到底在做什么。
內容較多對于強化學習部分之間看最后的總結部分即可:
1、trl框架下PPO代碼總結;
2、trl框架下DPO代碼總結;
3、trl框架下GRPO代碼總結
QwenVL的基本使用
1、模板化模型輸入
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
}
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
所謂模板化模型的輸入,很容易理解(通過processor.apply_chat_template把對話 messages 轉成模型能理解的 prompt,不過值得注意的是不同模型可能處理的方式不同),就是將我的內容“填充”到模板中模擬對話內容,比如說上面處理得到的一個簡單結果就是:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>
<|im_start|>assistant
一般在data_loader里面就會提前將我們的模型需要的輸入處理好,比如說我們定義如下的模板
def format_data(self, image, text, prompt):
# self.SYSTEM_MESSAGE = """You are a helpful assistant."""
return [
{
"role": "system",
"content": [{"type": "text", "text": self.SYSTEM_MESSAGE}],
},
{
"role": "user",
"content": [
{
"type": "image",
"image": image,
},
{
"type": "text",
"text": prompt
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": text}],
},
]
"""
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>This is a prompt<|im_end|>
<|im_start|>assistant
This is a text<|im_end|>
<|im_start|>assistant
"""
對于上面內容輸出理解,首先 <|im_start|>....<|im_end|>一般是一組“發言”的開始和結束標記,而后里面內容就是我們的文本/圖像內容,user/ assistant/ system 則是分別代表:用戶、模型、角色(告訴模型今天扮什么角色)。<|vision_start|>...<|vision_end|>:表示圖像輸入的占位符,告訴模型這里有一段視覺信息。<|image_pad|>:圖像實際的 embedding 會在這里替換(填充),不是文字,而是圖像編碼后的向量。值得注意的是 assistant后面的內容就是 模型需要輸出的文本內容。上面過程很容易理解,只不過需要注意如下問題,因為QwenVL2.5對于分辨率是存在處理(一般直接通過smart_resize處理,后續有介紹),因此如果涉及到目標識別,可能需要提前將坐標進行轉換避免分辨率不同導致bbox對應不上的問題
2、編碼模板輸入
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
編碼模板輸入就比較簡單,因為我的輸入都是文本/圖片,此過程就是需要將這些內容轉化為編碼形式(比如tokenizer處理等),處理方式如下:
- 1、process_vision_info:返回我的圖像/視頻輸出(都存儲在list中)
首先是過extract_vision_info從我上面的內容中提取出圖片/視頻([{'type': 'image', 'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'}])提取完畢之后就是交給處理圖片/視頻的函數進行處理
圖片處理過程(fetch_image)此過程也會比較簡單,首先去判斷類型(是Image.Image對象/圖片鏈接等)然后打開圖片,而后就是確定圖片分辨率尺寸,有兩種smart_resize處理方式,第一種是直接通過:resized_height 和 resized_width來確定改變,另外一種直接通過 min_pixels 和 max_pixels 來處理圖像尺寸。對于smart_rezie函數處理過程為:
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS)
# IMAGE_FACTOR= 28
if max(height, width) / min(height, width) > MAX_RATIO:
...
h_bar = max(factor, round_by_factor(height, factor)) # round(number / factor) * factor
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor) # 按比例縮小并向下取整 math.floor(number / factor) * factor
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor) # 按比例放大并向上取整 math.ceil(number / factor) * factor
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
上面3個小的子函數表示:計算factor倍數、向上取整計算倍數、向下取整計算倍數,對于smart_resize(去實現動態分辨率)函數:通過四舍五入的方式,重新設置圖片的 h 和 w 值,確保它們可以被28整除,這樣一來就得到了圖像的需要修改的尺寸了,比如說:
輸入: 一張 1000x500 的圖像
計算基礎尺寸:round(1000/28)=36, round(500/28)=18 → 1008x504
檢查像素數:1008504 = 508,032 > MAX_PIXELS(200,704)
計算縮放系數:beta = sqrt(1000500/200704) ≈ 1.58
最終尺寸:floor(1000/1.58)=632, floor(500/1.58)=316 → 616x308(28的倍數)
視頻處理過程(fetch_video)對于視頻處理和圖像處理相類似打開-->改變尺寸。只不過在打開過程中QwenLV2.5處理過程為:
def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False):
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
try:
video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
except Exception as e:
logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
...
對于VIDEO_READER_BACKENDS設計了3中不同范式:1、_read_video_decord;2、_read_video_torchvision;3、_read_video_torchcodec。
_read_video_decord
def _read_video_decord(
ele: dict,
) -> (torch.Tensor, float):
"""read video using decord.VideoReader
Args:
ele (dict): a dict contains the configuration of video.
support keys:
- video: the path of video. support "file://", "http://", "https://" and local path.
- video_start: the start time of video.
- video_end: the end time of video.
Returns:
torch.Tensor: the video tensor with shape (T, C, H, W).
"""
import decord
video_path = ele["video"]
st = time.time()
vr = decord.VideoReader(video_path)
total_frames, video_fps = len(vr), vr.get_avg_fps()
start_frame, end_frame, total_frames = calculate_video_frame_range(
ele,
total_frames,
video_fps,
) # 得到視頻的開始 結束 總結多少幀
nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist()
video = vr.get_batch(idx).asnumpy()
video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
...
sample_fps = nframes / max(total_frames, 1e-6) * video_fps
return video, sample_fps
對于其中的 calculate_video_frame_range函數處理過程也很簡單(直接去計算視頻開始、結束、總共多少幀),而后類似動態分辨率(smart_resize中成立相類似的)對于視頻會通過智能視頻幀數計算算法(smart_nframes),用于確定從視頻中提取多少幀作為模型輸入,處理過程為:第一種直接通過round_by_factor(ele["nframes"], FRAME_FACTOR)來得到幀數;第二種處理方式為(FPS_MIN_FRAMES = 4、FRAME_FACTOR = 2、FPS_MAX_FRAMES = 768、FPS = 2.0):
fps = ele.get("fps", FPS)
min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
nframes = total_frames / video_fps * fps
nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
nframes = floor_by_factor(nframes, FRAME_FACTOR)
"""
config = {"nframes": 24}
result = smart_nframes(config, total_frames=100, video_fps=30)
# 輸出:24(直接使用配置值)
config = {"fps": 10, "min_frames": 16, "max_frames": 32}
result = smart_nframes(config, total_frames=100, video_fps=30)
# 計算:100/30*10 ≈ 33.33 → 約束到32 → 對齊到32(FRAME_FACTOR=8的倍數)
"""
- 2、processor:去將圖片/文本進行編碼
其中對于文本編碼直接通過 self.tokenizer 來處理,而對于圖像直接通過 self.image_processor來處理。首先在 代碼中很容易看到使用的圖像/文本處理方式image_processor_class = "AutoImageProcessor" 對于文本處理方式 tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")。
對于圖片處理方式的 Qwen2VLImageProcessor(代碼)的處理思路:
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init(...):
...
def _preprocess(self, images, ...):
...
height, width = get_image_size(images[0], channel_dim=input_data_format)
resized_height, resized_width = height, width
processed_images = []
# Step-1
for image in images:
if do_resize:
resized_height, resized_width = smart_resize(
height,
width,
factor=self.patch_size * self.merge_size,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
image = resize(
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
)
if do_rescale:
image = self.rescale(image,...)
if do_normalize:
image = self.normalize(image,...)
# Step-2
patches = np.array(processed_images)
if data_format == ChannelDimension.LAST:
patches = patches.transpose(0, 3, 1, 2)
if patches.shape[0] % self.temporal_patch_size != 0:
# 視頻補幀處理
repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0)
patches = np.concatenate([patches, repeats], axis=0)
# 計算不同 patch 網格大小
channel = patches.shape[1]
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
patches = patches.reshape(
grid_t,
self.temporal_patch_size,
channel,
grid_h // self.merge_size,
self.merge_size,
self.patch_size,
grid_w // self.merge_size,
self.merge_size,
self.patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
)
return flatten_patches, (grid_t, grid_h, grid_w)
對于上面處理過程中,首先對于 _preprocess主要是對圖像進行一些預處理:1、do_resize:改變圖片大小(直接通過smrt_resize進行處理)2、do_rescale:像素縮減到0-1之間;3、do_normalize:對圖片進行歸一化處理(通道維度);而后直接對于預處理后的圖像直接進行切割處理為不同的patch輸入到Vit中。
回顧一下QwenVL2.5的圖片處理過程:首先是去對圖片進行改變尺寸(保證圖片最后可以整除patch_size)/縮放/歸一化。而后就是直接將圖片處理為vit能夠處理的“序列輸入”得到的維度為:[grid_t * grid_h * grid_w, channel * temporal_patch_size(2) * patch_size(14) * patch_size(14)]。
補充一:圖片輸入具體例子說明
假設默認參數為:patch_size= 14, temporal_patch_size= 2, merge_size= 2
圖像輸入為(通過process_vision_info提前處理之后的維度):(1092, 1568)
首先計算resized_height, resized_width = smart_resize得到 812 1176
首先計算:grid_t=1,grit_h=812//14=58,grid_w=1176//14=84那么計算得到為 4872另外一項為 1176也就是最后圖像處理得到的輸出為:(1*58*84, 14*14*2*3)=(4872,1176)
補充二:對于 smart_resize快速估算最后大小:
先 round 到 factor 的倍數
如果超出 max_pixels → 除以 sqrt(HW/max_pixels),floor → factor 倍數
如果小于 min_pixels → 乘以 sqrt(min_pixels/HW),ceil → factor 倍數
其實也就是:首先將圖像處理到為factor倍數的分辨率,而后去判斷和max_pixels和min_pixels之間大小,大于前者就縮小,小于前者就放大
最后通過一系列編碼之后得到輸出:
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
"""
input_ids: torch.Size([1, 1243])
attention_mask: torch.Size([1, 1243])
pixel_values: torch.Size([4872, 1176])
image_grid_thw: torch.Size([1, 3])
"""
3、模型輸入處理
generated_ids = model.generate(**inputs, max_new_tokens=128)
整體模型輸入處理,輸入模型也就是上面編碼模板輸入幾個部分,只不過主要就是如下幾個處理:首先是模型處理輸入 input_ids 以及我的圖像 pixel_values(inputs_embeds = self.model.embed_tokens(input_ids) 代碼),而后將輸入進行位置編碼處理(代碼),最后輸出模型結果(代碼),對于QwenVL2.5完整模型結構:
Qwen2_5_VLForConditionalGeneration(
(model): Qwen2_5_VLModel(
(visual): Qwen2_5_VisionTransformerPretrainedModel(
(patch_embed): Qwen2_5_VisionPatchEmbed(
(proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
)
(rotary_pos_emb): Qwen2_5_VisionRotaryEmbedding()
(blocks): ModuleList(
(0-31): 32 x Qwen2_5_VLVisionBlock(
(norm1): Qwen2RMSNorm((1280,), eps=1e-06)
(norm2): Qwen2RMSNorm((1280,), eps=1e-06)
(attn): Qwen2_5_VLVisionAttention(
(qkv): Linear(in_features=1280, out_features=3840, bias=True)
(proj): Linear(in_features=1280, out_features=1280, bias=True)
)
(mlp): Qwen2_5_VLMLP(
(gate_proj): Linear(in_features=1280, out_features=3420, bias=True)
(up_proj): Linear(in_features=1280, out_features=3420, bias=True)
(down_proj): Linear(in_features=3420, out_features=1280, bias=True)
(act_fn): SiLU()
)
)
)
(merger): Qwen2_5_VLPatchMerger(
(ln_q): Qwen2RMSNorm((1280,), eps=1e-06)
(mlp): Sequential(
(0): Linear(in_features=5120, out_features=5120, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=5120, out_features=2048, bias=True)
)
)
)
(language_model): Qwen2_5_VLTextModel(
(embed_tokens): Embedding(151936, 2048)
(layers): ModuleList(
(0-35): 36 x Qwen2_5_VLDecoderLayer(
(self_attn): Qwen2_5_VLAttention(
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
(k_proj): Linear(in_features=2048, out_features=256, bias=True)
(v_proj): Linear(in_features=2048, out_features=256, bias=True)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(rotary_emb): Qwen2_5_VLRotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=2048, out_features=11008, bias=False)
(up_proj): Linear(in_features=2048, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=2048, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
(post_attention_layernorm): Qwen2RMSNorm((2048,), eps=1e-06)
)
)
(norm): Qwen2RMSNorm((2048,), eps=1e-06)
(rotary_emb): Qwen2_5_VLRotaryEmbedding()
)
)
(lm_head): Linear(in_features=2048, out_features=151936, bias=False)
)
- 首先:對于視覺部分處理(
Qwen2_5_VisionTransformerPretrainedModel)
對于視覺模型主要需要處理的就是
pixel_values,假設輸入的pixel_values信息為:[4872, 1176],image_grid_thw為: [1, 84, 58](就是對應grid_t、grid_h、grid_w這三個數值)
主要包括如下幾個模塊:
1、Qwen2_5_VisionPatchEmbed:主要進行處理通過一個 Conv3d處理,處理過程也就是說首先將輸入的維度進行修改得到:view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size) --> (4872,1176)-->(4872,3,2,14,14)而后再去通過卷積處理得到 (4872,1280,1,1,1)最后得到:(4872,1280),也就對應著:(grid_t*grid_h*grid_w, hiddend_size);
2、Qwen2_5_VisionRotaryEmbedding;
3、Qwen2_5_VLVisionAttention:首先去劃分window_size這一步直接根據計算得到的:[grid_t, grid_h, grid_w]去劃分windows,比如說在上述例子中,得到的cu_seqlens = [0,64,128,...,4872],而后再去通過如下處理:
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
去劃分q、k、v(形狀都為:[1, 16, 4872, 80])然后計算注意力,而后通過Qwen2_5_VLPatchMerger將結果合并起來。
具體計算過程,首先是如何得到cu_seqlens,因為我們得到的gird_thw=(1, 84, 58)也就是說總共有84*58=4872個token去計算全局注意力,那么這就會導致計算注意力的消耗過大,因此可以先去切分成小的window然后小塊內部注意力計算。因此首先計算“塊”的大小:vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size得到結果為: 4(112/2/14)也就是說每塊大小為:4x4=16,但是不一定我的grid_h和grid_w可能整除4,因此就需要去計算填充數量 vit_merger_window_size - llm_grid_h % vit_merger_window_size 分別得到 4和2因此填充后的h和w為:88,60這樣一來計算得到window數量為:88//4 * 60//4=330每個窗口的tokens數量:16
4、圖像處理過程總結
總結上述圖像處理過程:對于任意輸入圖像首先通過smart_resize(首先將圖像改變到 factor的倍數,然后去判斷和min_pixels和max_pixels之間大小,然后進行擴大,縮小)進行處理保證都可以整除patch_size(14)然后丟到 processor中進行處理主要是對圖像歸一化、正則化、改變維度(還會通過smart_resize在處理一次),處理之后再去確定他的 grid_t, grid_h, grid_w(對于這3個參數確定:直接通過 第二次smart_resize處理之后的結果除 patch_size即可)也就是tokens數量,而后將圖像內容通過 conv3d處理得到:(grid_t* grid_h* grid_w, hidden_size),最后就是計算window_attention(首先確定widow_size索引,通過索引進行切分,最后計算注意力)
補充:對于window-attention可以用卷積的思路去理解,比如說我得到“圖像”:
(grid_t, grid_h, grid_w)我提前計算我的“卷積核”大小(vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size)為了保證我的 “圖像”可以被卷積核處理就需要做一部分填充,而后用這個“卷積核”去劃分成不同“小塊”在到這個小塊里面計算注意力。
5、位置編碼
QwenVL的微調過程
所有的代碼:https://github.com/shangxiaaabb/Docparse-QwenVL
補充一:節約顯存可以進行的操作
1、使用gradient_checkpointing:model.gradient_checkpointing_enable()
2、使用qlora進行優化
3、使用AdamW8bit而不是AdamW
4、使用xformers(model.enable_xformers_memory_efficient_attention()),不過需要注意的是 QwenVL2.5不支持使用xformers(除此之外安裝也比較復雜)
5、避免顯存碎片(不要過度的去評估模型),可以使用gc.collect() torch.cuda.empty_cache()去適當的減小緩存壓力,對于不需要的內容(中間值)直接通過del xx處理掉
SFT 處理
https://www.f22labs.com/blogs/complete-guide-to-fine-tuning-qwen2-5-vl-model/
SFT數據處理過程
首先假設數據(通過jsonl進行存儲)輸入格式為:
{"image":
"845c2f9b-0583-4127-82a6-47c4c1c3ceb7.jpg",
"prefix":
"QwenVL HTML",
"suffix":
"<body><h2 data-bbox=......"
}
構建data_loader只需要注意如下幾個流程即可:
首先構建我的輸入模板。這一步主要是將我的數據進行讀取,然后去構建成QwenVL2.5(或者其他大模型的對話形式),比如說:
def format_data(self, image, entry, text, prompt):
return [
{
"role": "system",
"content": [{"type": "text", "text": self.SYSTEM_MESSAGE}],
},
{
"role": "user",
"content": [
{
"type": "image",
"image": image,
},
{
"type": "text",
"text": (
"Must output the layout of the image strictly in HTML format. "
"Must follow the example below:\n"
"<h2 data-bbox='x1 y1 x2 y2'>Text</h2>\n"
"<p data-bbox='x1 y1 x2 y2'>Text</p>")
},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": text}],
},
]
然后就只需要將參數丟到這個函數里面就可以自動化的將數據處理好(補充一點,對于上面幾個參數,一般來說其中text就是我的模型需要輸出的label,而后其他的內容就是模型的輸入),其次就只需要將輸入進行編碼即可也就是說直接通過:
image_inputs, _ = process_vision_info(messages)
encoding = self.processor(
text=[text],
images= image_inputs,
return_tensors="pt",
padding= False,
truncation=True,
max_length= self.max_length
)
這樣就會的得到模型的輸入內容,一般來說得到的是:input_ids: 文本編碼內容(一般來說會直接將 input_ids進行復制作為我們的 labels,當然也可以直接對與輸入解析,只需要模型那部分作為labels),attention_mask,pixel_values: 圖片像素編碼結果image_grid_thw: 我的tokens數量(grid_t*grid_h*grid_w)。
不過上面處理過程只是針對一張圖片進行處理去構建對話信息,如果需要處理多組圖片同時進行輸入(比如說3張圖片進行排序,讓QwenVL輸出)那么處理過程只需要修改 content即可(在content里面指定多個圖片即可)
"content": [
{
"type": "image",
"image": "./tmp/7.png",
},
{
"type": "image",
"image": "./tmp/1.png",
},
{"type": "text", "text": "..."},
],
SFT模型處理
一般來說如果直接使用lora去對模型進行微調,處理也比較簡答:
target_modules = ['q_proj', 'v_proj']
lora_config = LoraConfig(
task_type= config.lora_task_type,
target_modules= target_modules,
r= config.lora_rank,
lora_alpha= config.lora_alpha,
lora_dropout= config.lora_dropout,
)
model = get_peft_model(model, lora_config)
這樣一來模型就會被lora“包裹”,微調過程也就是優化lora的參數,不過如果需要使用qlora(lora量化版本)再模型加載過程中需要使用參數 quantization_config:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
...
if model_name == 'Qwen/Qwen2.5-VL-3B-Instruct':
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype= torch.bfloat16,
cache_dir= config.cache_dir,
quantization_config= bnb_config if config.lora_type== 'qlora' else None,
)
對于模型訓練以及參數優化過程就比較簡單:
for step, batch in enumerate(train_loader):
outputs = model(**batch)
loss = outputs.loss
得到的所有的內容可以直接全部丟到model里面,他會自動計算loss值,對于outputs = model(**batch)模型返回得到結果為:
loss: Optional[torch.FloatTensor]:模型計算得到的loss(直接計算交叉熵損失得到),如果輸入內容中沒有labels(就是模型輸出那段文本)那么就不會去計算loss
logits: Optional[torch.FloatTensor]:模型輸出結果
past_key_values: Optional[list[torch.FloatTensor]]:Transformer 解碼器的 KV 緩存(每一層的注意力 key 和 value)
hidden_states: Optional[tuple[torch.FloatTensor]]:每一層的 hidden state (batch_size, seq_len, hidden_size)
attentions: Optional[tuple[torch.FloatTensor]]:每一層注意力權重 (batch_size, num_heads, seq_len, seq_len)
rope_deltas: Optional[torch.LongTensor]:旋轉位置編碼 RoPE(Rotary Position Embedding)的偏移量
RL 處理
強化學習框架很多,1、huggingface-trl: https://github.com/huggingface/trl;2、字節跳動-verl: https://github.com/volcengine/verl;3、OpenRLHF:https://github.com/OpenRLHF/OpenRLHF
強化學習處理過程(直接使用 trl(使用版本:0.22.1)庫,它里面提供了多種腳本)對于多模態/大語言模型使用RL中比較常見的的數據類型:一般就是拋出問題,而后給出選項讓模型進行選擇。此類數據集一般格式為:
{"images": [], "prompt": [], "chosen": [], "rejected": []}
# 當然這個 images 也可以替換為文本問題 "question"
比如說數據集:HuggingFaceH4/rlaif-v_formatted他的數據結構如下:

直接看trl中如何實現QwenVL-DPO過程代碼:
from trl import (
DPOConfig,
DPOTrainer,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
...
dataset = load_dataset(
script_args.dataset_name,
name=script_args.dataset_config,
streaming=script_args.dataset_streaming,
)
...
# ref_model 和 model 都是直接使用QwenVL
trainer = DPOTrainer(
model,
ref_model,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=processor,
peft_config=peft_config,
)
初次之外,RL就和SFT一樣需要讓模型去按照我的數據進行輸出,因此處理也就是直接logits=model(**model_inputs).logits得到模型最后輸出(見相當于每個詞的概率)
RL-DPO處理代碼
首先在代碼(DPOTrainer)主要是通過繼承 Trainer(代碼包裹好了各種處理過程比如數據加載模型評估等各項處理過程)直接看 DPOTrainer里面的 get_batch_loss_metrics(完整模型輸入然后輸出loss):
def get_batch_loss_metrics(self, model, batch, train_eval):
...
if ...:
...
else:
model_output = self.concatenated_forward(model, batch)
if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
ref_chosen_logps = batch["ref_chosen_logps"]
ref_rejected_logps = batch["ref_rejected_logps"]
else:
ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
losses = 0
chosen_rewards = 0
rejected_rewards = 0
for idx, loss_type in enumerate(self.loss_type):
_losses, _chosen_rewards, _rejected_rewards = self.dpo_loss(
model_output["chosen_logps"],
model_output["rejected_logps"],
ref_chosen_logps,
ref_rejected_logps,
loss_type,
model_output,
)
weight = self.loss_weights[idx] if self.loss_weights else 1.0
losses = losses + _losses * weight
chosen_rewards = chosen_rewards + _chosen_rewards * weight
rejected_rewards = rejected_rewards + _rejected_rewards * weight
return losses.mean(), ...
對于DPOTrainer里面data_loader處理過程為,首先對于 dataset會通過 processing_class(一般來說也就是對于文本直接使用 tokenizer,亦或者直接使用 AutoProcessor.from_pretrained(...))進行處理,也就是說會提前將數據processor處理(和SFT處理方式相同)那么就會得到 self.train_dataset,那么接下來就是直接去通過代碼(加載train_loader數據),其中處理方式為:ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) 對于 compute_ref_log_probs里面處理過程為:直接去通過 model/ref_model去處理:self.concatenated_forward(代碼)得到模型輸出: model_output,而后再去使用 self.dpo_loss去計算損失。
self.concatenated_forward處理過程 Github-代碼(實際解釋使用 trl:0.22.1版本代碼和github有差異)
def concatenated_forward(model, batch, is_ref_model):
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value)
prompt_input_ids = concatenated_batch["prompt_input_ids"] # 問題文本
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
completion_input_ids = concatenated_batch["completion_input_ids"] # 回答文本 同時拼接了chosen_input_ids 和 rejected_input_ids
completion_attention_mask = concatenated_batch["completion_attention_mask"]
if self.is_encoder_decoder:
labels = completion_input_ids
labels[completion_attention_mask == 0] = self.label_pad_token_id
outputs = model(
input_ids=prompt_input_ids,
attention_mask=prompt_attention_mask,
labels=labels, # we need the labels for the logits to be returned
**model_kwargs,
)
logits = outputs.logits
loss_mask = completion_attention_mask.bool()
else:
# Process-1
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
...
outputs = model(input_ids, **model_kwargs)
logits = outputs.logits
# Process-2
Process-1:首先是將文本和回答進行拼接,而后去判斷如果指定 max_length那么就去根據 truncation_mode(掐頭/去尾:保留序列末尾,移除開頭多余部分)去裁減輸入以及移除填充和限制計算范圍來優化內存和性能最后丟到模型中進行處理。
掐頭去尾過程
keep_start:保留序列開頭。先調用 flush_left(所有有效的token左移動去除中間padding)。然后截斷到 max_length([:, :self.max_length])。[0, 0, x, x, x, x] → flush_left后[x, x, x, x],若 max_length=3,則截斷為[x, x, x]
keep_end:保留序列末尾。先調用 flush_right(將所有有效token向右移動,前面填充padding)。截斷到最后 max_length 個 token([:, -self.max_length:])。再次調用 flush_left,確保左側無填充。[0, 0, x, x, x, x] → flush_right后[0, 0, x, x],截斷后[x, x],flush_left 后保持不變。
回顧一下self.concatenated_forward(模型處理)整個過程:首先是將chosen_input_ids 和 rejected_input_ids兩部分進行拼接(self.concatenated_inputs做的,于此同時對于其他內容也都會拼接成兩部分)作為我們模型的回答。而后丟到模型中進行處理(對于 is_encoder_decoder 可以直接給模型處理,如果不是那么就通過截斷裁剪等處理來節約存儲在由模型處理)得到 logits,去通過logits, label得到每個token的對數概率:all_logps,而后再去判斷是否進行優化策略: ipo 或者 ld_alpha(長度去敏化)去優化得到的 all_logps(對其直接切分就可以得到:chosen_logps 和 rejected_logps)
self.dpo_loss計算損失過程 Github-代碼(實際解釋使用 trl:0.22.1版本代碼和github有差異)
model_output = self.concatenated_forward(model, batch)
if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
# 直接使用數據里面的的結果
ref_chosen_logps = batch["ref_chosen_logps"]
ref_rejected_logps = batch["ref_rejected_logps"]
else:
# 相對于直接在用模型處理一下得到結果
ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
_losses, _chosen_rewards, _rejected_rewards = self.dpo_loss(
model_output["chosen_logps"],
model_output["rejected_logps"],
ref_chosen_logps,
ref_rejected_logps,
loss_type,
model_output,)
if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:直接使用數據里面的結果過程一樣的還是通過模型self.compute_ref_log_probs(batch)(這個還是調用了self.concatenated_forward)去得到chosen_logps 和 rejected_logps結果。
對于 dpo_loss 里面model_ 和 ref_ 這兩部分理論上是兩個不同的模型的輸出結果,但是如果沒有指定 ref_model 那么直接就都直接使用 model 即可
對于DPO的loss處理過程就比較簡單,在trl中提供3種計算方式:
1、Alpha散度計算

2、KL散度計算

3、JS散度計算

在計算得到不同方式得到的結果:logits然后再去根據不同 loss_type去做處理(比如說:loss_type == "sigmoid" 處理過程為:losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)- F.logsigmoid(-self.beta * logits) * self.label_smoothing))
RL-DPO處理過程總結
首先對于我們的數據集(假設為3元組:[問題, 接受回答, 拒絕回答])首先就是去通過 processor(比如Qwen2.5vl可以直接 load)去編碼我的所有內容(這一步和SFT過程相似),而后就是去通過self.concatenated_forward這個函數將我們的3元組進行拼接得到:[問題,問題], [接受回答, 拒絕回答]而后得到模型的輸入為:[問題+接受回答, 問題+拒絕回答],將輸入直接交給的模型(由于見內容直接拼接起來,可能會優化模型的輸入/出長度過長導致爆顯存,因此輸入之前會由一些裁剪處理操作)去得到輸出:logits,而后通過logits, label得到每個token的對數概率:all_logps,(通過對all_logps進行拆分)就可以得到接受回答的值(chosen_logps),以及拒絕回答的值(rejected_logps),最后在得到這兩部分值之后就是直接去計算loss。
對于loss計算過程(假設為KL散度):\(\mathrm{loss}=-\frac{1}{N}\sum_{i=1}^{N}\log\sigma\left(\beta\cdot((\log\pi_{\theta}(y_{w}|x)-\log\pi_{\theta}(y_{l}|x))-(\log\pi_{\mathrm{ref}}(y_{w}|x)-\log\pi_{\mathrm{ref}}(y_{l}|x)))\right)\)。對于里面兩項相減過程代碼:
chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device)
rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device)
反思:如果需要手搓一個DPO訓練過程代碼(需要借鑒concatenated_forward代碼來輔助實現)
RL-GRPO處理代碼
官方實現代碼,對于DPO過程很容易發現一點在GRPO中直接不要ref_model 只是用一個model不過設計了一個reward_function。
- 數據處理過程
以官方代碼為例(訓練一個具有思考過程的多模態模型),在數據處理層面使用類似如下數據集

以為需要設計一個“輸出”思考過程的模型因此設計設計具有“思考”過程的prompt,最后輸入模型數據格式為:
# 原始文本
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=147x86 at 0x7FF65C5776D0>,
'original_answer': ...,
'original_question': ...,
'problem': ...
'prompt': [{'content': 'system-content',
'role': 'system'},
{'content': 'user-content',
'role': 'user'}],
'solution': "<think>...</think>'
'<answer>...</answer>'}
# 初步處理后文本
{'The prompt Text: '
'<|im_start|>system\n systen-content <|im_end|>\n'
'<|im_start|>user\n user-content <|im_end|>\n'
'<|im_start|>assistant\n'}
# 模型最后得到的輸出
output = {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": advantages,
"num_items_in_batch": num_items_in_batch,
}
不過在得到類似上面數據集之后,不是直接丟到模型里面進行處理,在DPOTrainer中首先會去由_prepare_inputs(代碼)函數進行處理,對于測試直接通過函數 self._generate_and_score_completions(...)處理,對于訓練數據集
_generate_and_score_completions:
第一步、格式化數據。(對于多模態/只有文本)這個過程主要是爭對我上面數據中的prompt直接通過模板進行處理得到prompts_text,而后就是直接再去通過processing_claa(直接調用QwenVL的processor)處理得到prompt_inputs,而后就是如果self.max_prompt_length那么就會去對多模態(文字 + 圖像)輸入時,對prompt_inputs["input_ids"]還原文本然后去除類似<pad>和一些重復/錯誤的<image>得到干凈的prompts_text。
第二步、生成回答。在trl中使用了3種生成方式:1、直接用模型生成;2、使用vllm方式生成;3、使用use_transformers_paged方式。對于生成(直接通過模型)過程而言就比較簡單直接將prompt_inputs["input_ids"]和prompt_inputs["attention_mask"]丟到模型里面得到prompt_completion_ids再去將 prompt內容和回答截取出來得到prompt_ids和completion_ids
第三步、計算獎勵值。這個過程就比較簡單,直接將模型的回答進行解碼再去通過獎勵函數計算回答的獎勵值,而后歸一化成優勢函數(advantages),按 group(一次生成多個樣本)算均值,計算每個樣本的 相對優勢(比如說兩個回答打分為 [0.8, 0.5]那么減去 group 內均值,假設為[+0.15, -0.15])
最后、返回輸出。
在最后返回的輸出中old_per_token_logps和ref_per_token_logps處理直接通過函數_get_per_token_logps_and_entropies(就相當于把 第二步得到的prompt_completion_ids在交給模型里面去計算每個token的概率)
- 獎勵函數設計
GRPO沒有使用ref_model轉而使用獎勵函數,對于獎勵函數設計:think_format_reward, accuracy_reward。對于accuracy_reward很容易理解代碼就是直接對比模型輸出和答案之間是否正確(通過parse [from math_verify import LatexExtractionConfig, parse, verify] 去解析最后輸出打答案然后對比兩者之間是否正確)。對于think_format_reward:這個更加直接,直接去判斷輸出是不是有 <think>...</think> 包裹(有=1,無/缺失=0)
當然不一定要使用自定義的(這么粗糙的)在DPOTrainer中對于self.reward_funcs(代碼)也可以直接去加載訓練好的模型 AutoModelForSequenceClassification.from_pretrained(...)
- 模型處理過程
直接去看loss計算過程:
def compute_loss(self, model, inputs, return_outputs, num_items_in_batch):
...
if self.use_liger_loss:
unwrapped_model = self.accelerator.unwrap_model(model)
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
else:
return self._compute_loss(model, inputs)
其中使用了兩種loss處理過程:_forward_redirection 以及 _compute_loss。
self._compute_loss處理過程(Github-代碼)(實際解釋使用 trl:0.22.1版本代碼和github有差異)
首先是將輸入問題和回答拼接起來,然后直接丟到self._get_per_token_logps_and_entropies(直接將數據丟到模型中,而后去截取模型輸出中“真正回答”的內容)中進行處理得到per_token_logps(每個token的概率),entropies(每個token的信息熵),而后就是通過高熵去過濾token只在高熵位置計算 loss,而后就是計算KL散度(torch.exp(inputs["ref_per_token_logps"] - per_token_logps) - (inputs["ref_per_token_logps"] - per_token_logps) - 1)),避免新策略漂移太遠
self._get_per_token_logps_and_entropies處理過程(Github-代碼)(實際解釋使用 trl:0.22.1版本代碼和github有差異)
其處理過程比較簡單,直接將所有的數據都處理成模型輸入(GRPO不想DPO那樣需要將3元組進行拆開拼接)如:input_ids、pixel_values等然后直接logits = model(**model_inputs).logits在得到模型的輸出之后后續就是對輸出做一些截斷處理(如只需要模型回答部分的輸出logits[:, -logits_to_keep:, :])而后去計算logits / self.temperature(通過溫度系數來確定輸出內容多樣化)最后再去通過:logps = selective_log_softmax(logits, completion_ids)(selective_log_softmax只去計算completion_ids部分的log_softmax值)就可以得到最后的值。
RL-GRPO處理過程總結

對于上面loss計算公式中主要就是如下幾個值需要關注:1、advantage值;2、KL散度值。
因此簡單總結一些GRPO代碼處理過程[1],首先,對于數據處理,這塊內容比較簡單直接 模板化、編碼內容即可,因為GRPO是“一個問題拋出多組回答然后評估回答”,因此在數據處理過程中通過模型生成回答 prompt_completion_ids=model.generate(...)而后需要做的就是將生成內容進行拆分得到prompt_ids和 completion_ids(得到這一部分值之后就只需要在去還原成text文本然后再去通過reward函數去計算reward值以及計算最后需要的 advantage值),除此之外還會去通過model和model_ref分別計算回答中每個token的logits值:old_per_token_logps 和 ref_per_token_logps
這個過程直接通過函數 _get_per_token_logps_and_entropies處理,他的處理思路簡單直接將 model需要的內容再丟到model里面得到每個token的logits然后再去計算softmax值
最后得到一個完整的output如下:
output = {
"prompt_ids": prompt_ids, # 問題token
"prompt_mask": prompt_mask,
"completion_ids": completion_ids, # 問題的回答token
"completion_mask": completion_mask,
"advantages": advantages,
"num_items_in_batch": num_items_in_batch,
"old_per_token_logps": old_per_token_logps
"importance_sampling_ratio": importance_sampling_ratio
"ref_per_token_logps": ref_per_token_logps
"pixel_values": prompt_inputs["pixel_values"]
"image_grid_thw": prompt_inputs["image_grid_thw"]
"pixel_attention_mask": prompt_inputs["pixel_attention_mask"]
"image_sizes": prompt_inputs["image_sizes"]
}
而后,對于loss計算過程首先將上面output中的 問題+回答進行組合再丟到_get_per_token_logps_and_entropies中得到每個token概率以及熵的值:per_token_logps,entropies,而后就是:1、選擇出高熵值的token(entropy_mask);2、計算KL散度(torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1);3、重要性采樣權重:比較當前 log 概率和舊策略(per_token_logps - old_per_token_logps),得到 importance weight,做 clipping 限制。構造兩個候選 loss(不裁剪和裁剪),取最小值,形成 per_token_loss再去乘上 entropy_mask和加上 KL 懲罰項就可以得到最后的loss值。
RL-PPO處理代碼
借用huggingface中對于PPO過程描述圖:

對于代碼使用,相比較GRPO和DPO要簡單很多(不過在使用模型上,DPO和PPO都需要加載model和ref_model而GRPO只需要加載一個model),按照上面的處理過程:
首先計算rollout輸出,直接通過加載的模型然后模型對于“問題”去得到“回答”query_responses(完整的模型生成內容:prompt+模型的回答),logitss,接下來(代碼)去計算model和ref_model中每個token的log概率值(這個過程和GRPO處理是一樣的,將問題+回答拼接起來而后丟到模型中計算每個token的log概率值)最后分別得到模型的輸出結果:logprob response(截取model回答內容) 和 ref_logprob。后面部分(代碼)就是直接根據 response(model的回答) 以及 query(就是我們的問題)去計算reward的值scores。
接下來處理過程:1、處理 EOS 缺失懲罰:將socres中如果生成內容不含結束標記就從scores中減去數值;2、計算kl以及最后的rewards值,對kl直接首先通過mask去掩蓋部分logprobs(ref_logprobs)然后直接通過 kl = -(ref_logprobs - logprobs) if args.kl_estimator == "k1" else ((ref_logprobs - logprobs).exp() - 1) - logr得到kl值;3、計算advantage值(代碼)
最后就是迭代優化模型參數(代碼)這個過程(對采樣得到的一批序列數據做多輪(num_ppo_epochs)小批次更新,通過 ratio = πθ/π_old 和裁剪(clip)來構造策略損失,同時對價值函數做裁剪的 value loss)主要是進行如下處理流程:首先是直接將最上面得到的query_responses中選擇部分例子丟到模型中去計算每一個token的logits( new_logprobs = selective_log_softmax(logits, mb_responses) ) 而后計算策略損失值(pg_loss)以及vf_loss
回顧一下,對于加載的llm在使用generate時一般返回如下4個值:
sequences:生成的 token ids(跟默認返回一樣);
scores:每一步的 logits(如果 output_scores=True)
attentions:注意力矩陣(如果 output_attentions=True)
hidden_states:隱藏層表示(如果 output_hidden_states=True)
一般而言使用到的主要是上面兩項,對于第一項sequences一般得到的完整的回答(prompt+模型生成的內容),所以一般會有一個截取處理(只需要記錄inputs['input_ids'].shape[1]然后去截取即可);對于第二項scores一般得到的是通常是logits(需要去通過softmax計算才能得到token概率);因此在GRPO和PPO中為了得到每一個token的log概率值,logprob = selective_log_softmax(logits, response)直接通過這種方式去計算來節約顯存。
除此之外也有直接通過model(**model_inputs)這樣處理一般得到的是
RL-PPO處理過程總結
第一階段:首先是對于問題(query)通過丟到模型batch_generation中處理得到query_responses(完整問題+模型回答) 和logitss(每個token對應的概率),進一步將其得到回答token的概率值logprob(selective_log_softmax)同樣的處理過程通過policy_model將query_response(從 query_responses挑選的)輸入到模型進行處理同樣的處理得到ref_logprob,最后就是通過reward_model去計算(torch.cat((query, postprocessed_response), 1))得到獎勵值。
第二階段:kl值:直接計算ref_logprobs - logprobs(也就是計算上面階段的ref_logprob和 logprob之間差值);rewards值:直接copy計算的kl結果然后再序列的結尾補充上scores;advantage值:根據 reward 和 value,用 GAE 算 advantage。GAE計算過程:\(\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)\) 和\(A_t = \delta_t + \gamma \lambda A_{t+1}\)最后計算advantages + values也就是 \(R_t=A_t+V(s_t)\)
第三階段:進行迭代優化模型參數過程,優化過程首先是直接將小批次的query_responses 輸入到模型中計算得到output, vpred_temp然后就是老操縱得到每個token的logits值new_logprobs,然后計算去計算vf_loss:計算loss1(torch.square(vpred - mb_return))和loss2(torch.square(vpredclipped - mb_return))的最大值。pg_loss:計算loss1(-mb_advantage * ratio)和loss2(-mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange))的最大值然后取mean。最后得到loss為pg_loss + args.vf_coef * vf_loss
vpred、vpredclipped、mb_return分別通過從vpred_temp選擇回答token、對vpred進行clamp裁剪、advantages + values
RL算法對比
對比一下GRPO和DPO的處理過程
DPO純數據驅動過程,數據驅動:訓練時需要標注好的偏好對:\([q, y^+], [q, y^-]\)。計算流程:1. 輸入同一個問題 \(q\),分別拼接上正樣本回答 \(y^+\) 和負樣本回答 \(y^-\)。2. 用當前模型和參考模型分別計算 \(\log \pi_\theta(y^+|q), \log \pi_\theta(y^-|q), \log \pi_{\text{ref}}(y^+|q), \log \pi_{\text{ref}}(y^-|q)\)。3. 基于這 4 個 log-prob,直接計算一個 logistic 回歸式的 loss,強制模型在正樣本上比分數更高,在負樣本上比分數更低。
GRPO生成驅動過程,生成驅動:訓練時只給定問題 prompt,模型自己 roll-out 多個回答。計算流程:1. 對每個問題生成 \(G\) 個回答。2. 通過獎勵函數(或打分器)給每個回答打分 \(r_i\)。3. 組內歸一化獎勵 → 得到 advantage 值 \(A_i\)(比組內平均好/差多少)。4. 用參考模型計算 ref_per_token_logps(使用ref_model生成沒有的話直接用model代替ref_model)。5. 用舊策略(凍結一幀的當前模型)得到 old_per_token_logps(直接通過model生成)。6. 用當前模型得到 per_token_logps。7. 計算重要性比率和 KL 散度(使用per_token_logps和ref_per_token_logps計算)近似,再套 PPO 風格的剪切目標(使用old_per_token_logps和per_token_logp) → 最終 loss。
對于DPO、GRPO、PPO中KL計算差異
\(KL(p||q)=\sum_x p(x)\log\frac{p(x)}{q(x)}=H(p,q)-H(q)\),交叉熵-熵
計算交叉熵的目的在于約束新策略不要偏離參考策略太多,類似的對于交叉熵損失(\(H(p,q)=-\sum_x p(x)\log q(x)\))兩者之間差異是交叉熵是讓“q去擬合p”,而KL則是度量“q和p之間距離”
1、DPO中計算KL:在model_ref以及model分別輸入“3元組”數據之后會去計算不同token的概率值,也就是model和ref都會生成 reject和choose的概率值,然后去計算:\(\mathrm{loss}=-\frac{1}{N}\sum_{i=1}^{N}\log\sigma\left(\beta\cdot((\log\pi_{\theta}(y_{w}|x)-\log\pi_{\theta}(y_{l}|x))-(\log\pi_{\mathrm{ref}}(y_{w}|x)-\log\pi_{\mathrm{ref}}(y_{l}|x)))\right)\) 的sigmoid 損失優化相對偏好
2、GRPO中計算KL:通過model_ref對于問題Q以及模型生成的多組回答進而可以得到每組回答的token概率:ref_per_token_logps 而后我又通過model去生成多組回答以及token概率:per_token_logps接下來就是直接他們之間KL散度:

3、PPO中計算KL:通過model得到回答中的每一個token的概率logprobs,同樣的再去通過model_rf也去計算每一個token的概率ref_logprobs然后去計算KL

DPO:通過“偏好差值”間接引入 KL 約束,偏重于 對比學習。
GRPO:顯式計算 生成候選組的 token 級 KL,作為正則項,保證模型不偏離參考策略。
PPO:基于當前策略與參考策略(或舊策略)的 KL,常作為 正則或 early stopping 信號
對于GRPO以及PPO中優勢值計算過程
GRPO優勢值計算過程:對于給出多組回答直接通過獎勵函數去計算每組回答的獎勵值而后去上計算:\(A_i = \frac{r_i- mean(r)}{std(r)}\)
PPO優勢值計算過程:一般直接通過廣義優勢估計方法GAE來計算優勢值,首先通過獎勵函數評估模型輸出(問題+回答),而后計算GAE
對比DPO、GRPO、PPO中loss計算差異
DPO的loss計算:
GRPO的loss計算:
PPO的loss計算:


浙公網安備 33010602011771號