LLM 場景下的強化學習技術掃盲
1. 強化學習基礎:行業黑話
想象你正在和一個剛訓練好的語言模型聊天。你問:“今天過得怎么樣?”
模型可能回:“還行?!?也可能回:“我是個 AI,沒有感情。”
人類覺得前者更自然、更友好——這就是偏好反饋。強化學習(RL)在 LLM 中的核心任務,就是讓模型學會生成“人類更喜歡”的回復。
為了做到這一點,我們需要一套語言來描述這個過程。下面我們以 LLM 場景為基礎介紹幾個 RL 的“行業黑話”。
1.1 基本概念
-
時刻 \(t\) :就是對話的第幾步。比如:
- \(t=0\):用戶輸入 “今天過得怎么樣?” → 這是初始狀態 \(s_0\)
- \(t=1\):模型輸出第一個詞 “今” → 動作 \(a_0 = \text{“今”}\)
- \(t=2\):模型輸出第二個詞 “天” → 動作 \(a_1 = \text{“天”}\)
- … 直到生成完整回復,比如 “今天過得不錯!”
-
在 LLM 中,狀態 \(s_t\) 通常就是到第 \(t\) 步為止已生成的 token 序列(包括用戶輸入和模型已輸出的部分)
-
動作 \(a_t\) 就是模型在第 \(t\) 步選擇的下一個 token。
-
獎勵 \(r_t\):這是人類(或獎勵模型)對模型行為的真實反饋信號。比如:
- 如果模型最終生成了“今天過得不錯!”,人類覺得回答的不錯,打 5 分 → 這個分數會折算成一個最終獎勵 \(r_T\)(通常只在序列結束時給,即最后一個 token)
- 中間步驟一般沒有即時獎勵(\(r_t = 0\) for \(t < T\))
-
價值 \(v\):獎勵 \(r\) 是真實的、來自外部的信號(比如人類打分),相對應的,價值(value)是對未來獎勵的估計——因為模型不能預知未來,只能靠猜。
1.2. 價值(Value):對未來獎勵的“預判”
既然模型不能看到未來,它就需要一個“預判能力”:我現在處在某個狀態,未來大概能拿多少分?
這就引出了兩個核心函數:
1) 狀態價值函數 \(V(s_t)\)
它表達的是:在當前已生成的對話上下文 \(s_t\)(比如用戶剛問完 “今天過得怎么樣?”,而模型還沒開始回答,或已輸出“今”),模型按照當前策略繼續生成后續內容,平均能獲得多少人類打分。
- \(\pi(a|s_t)\) 是模型在狀態 \(s_t\) 下選擇下一個詞 \(a\) 的概率(例如在“今天過得怎么樣?”之后,選“今”還是“還”);
- \(Q(s_t, a)\) 表示如果此刻選了某個具體詞 \(a\),最終能拿到的預期總分;
- 把所有可能的下一個詞按模型當前的偏好加權平均,就得到了該狀態的整體“預期得分”——也就是 \(V(s_t)\)。
舉個例子:當模型已經輸出 “今天過得”,它會評估:“按我現在的風格繼續回答,人類大概率會覺得自然,可能打 4 分”,于是 \(V(s_t) \approx 4\)。
2) 動作價值函數 \(Q(s_t, a)\)
它表達的是:如果我現在處于狀態 \(s_t\)(比如上下文是“今天過得”),并選擇動作 \(a\)(比如生成“不”),那么我能獲得當前的真實獎勵 \(r_t\)(通常是 0,因為回復還沒結束),再加上未來所有狀態價值的折扣和。
對應到 LLM 應用場景就表示:
“如果我現在在‘今天過得’后面接‘不’,形成‘今天過得不’,那接下來我大概率會說‘錯!’,組成一句完整、積極的回復,最終人類可能會打 5 分?!?/p>
其中:
- \(r_t\) 是真實發生的獎勵,但在 LLM 生成過程中,只有完整回復結束后才有非零值(例如人類打分 \(r_T = 5\));在中間步驟(如生成“今”“天”時),\(r_t = 0\);
- \(V(s_{t+1}), V(s_{t+2}), \dots\) 是模型自己估計的未來價值(比如生成“不”之后,預估“今天過得不錯!”能拿 4.9 分);
- \(\gamma \in [0,1]\) 是折扣因子(如 0.95),表示“未來的分不如現在的分值錢”——越靠后的 token 對當前決策的影響越小。
雖然中間每一步的 \(r_t = 0\),但 \(Q(s_t, a)\) 依然非常關鍵:它通過 \(V(s_{t+1})\) 等未來價值,把對最終人類反饋的預判傳遞回當前決策。這正是 LLM 在生成每個詞時具備“前瞻能力”的來源——它不是隨機選詞,而是基于“這樣說人類會不會喜歡”的長期預期來做選擇。
為什么估計的價值函數 Q 里包含真實的 \(r_t\)?
因為 RL 的目標是用真實獎勵來校準價值估計。模型通過不斷對比“預測的未來得分”和“實際拿到的獎勵”,來修正自己的 \(V\) 和 \(Q\) 函數。
2. PPO:RLHF 的“老大哥”
PPO(Proximal Policy Optimization)是傳統 RLHF(基于人類反饋的強化學習)流程中的核心算法,是 openai 在 2016年左右提出來的。原來 closeAI 的成功在那個時候就開始蓄力了。PPO的目標很直接:讓語言模型生成更受人類歡迎的回復。
PPO 中的幾個關鍵角色
| 模型 | 是否訓練 | 輸入 | 輸出 | 輸出維度說明 |
|---|---|---|---|---|
| Policy Model \(\pi_\theta\) | ? 是 | prompt \(x\)(token IDs,長度 \(L_x\)) | 生成回復 \(y = (a_1,\dots,a_T)\),以及每個 token 的 log-prob \(\log \pi_\theta(a_t | s_t)\) | \(y\): \([T]\) logprobs: \([T]\) |
| Reference Model \(\pi_{\text{ref}}\) | ? 凍結 | 同上 \(x\) | 同上 log-prob \(\log \pi_{\text{ref}}(a_t | s_t)\) | \([T]\) |
| Critic Model \(V_\psi\) | ? 是 | 狀態序列 \(s_t = x \oplus y_{\le t}\)(token IDs,長度 \(L_x + t\)) | 價值估計 \(V_\psi(s_t)\) | 標量(或 \([1]\)),對每個 \(t=0,\dots,T\) 輸出一個值 → 總輸出 \([T+1]\) |
| Reward Model \(r_\phi\) | ? 凍結 | \((x, y)\)(完整 prompt + response) | 標量獎勵 \(R = r_\phi(x, y)\) | 標量(或 \([1]\)) |
注:\(\oplus\) 表示 token 拼接;\(T\) 是生成回復的長度(可變,但訓練時通常 padding 到固定長度)。
PPO 的兩階段訓練流程
PPO通過分階段解耦“數據生成”和“策略學習”,在保證訓練穩定性的同時,讓模型逐步學會生成更符合人類偏好的回復。整個流程分為如下兩個階段:
階段 1:采樣與反饋(Sample + Label)
? 目標
用當前策略模型生成一批回復,并利用凍結的獎勵模型打分,再結合當前評論家模型估計價值,最終為每個 token 動作計算出優勢(Advantage) 和回報(Return),作為后續訓練的監督信號。
?? 關鍵點:此階段不更新任何模型參數,只是“收集數據”。Policy 和 Critic 在采樣時使用的是當前最新參數,但輸出會被 detach(視為常數),作為“舊策略”和“舊評論家”的快照。
?? 參與模型與接口
| 模型 | 是否更新 | 輸入 | 輸出 | 輸出維度 |
|---|---|---|---|---|
| Policy Model \(\pi_\theta\) | ?(采樣時不更新) | prompt \(x \in \mathbb{Z}^{L_x}\) | 生成回復 \(y \in \mathbb{Z}^T\) 及每個 token 的 log-prob \(\log \pi_\theta(a_t | s_t)\) |
\(y\): \([T]\) logprobs: \([T]\) |
| Critic Model \(V_\psi\) | ?(采樣時不更新) | 狀態 \(s_t = x \oplus y_{\le t} \in \mathbb{Z}^{L_x + t}\) | 價值估計 \(V_\psi(s_t) \in \mathbb{R}\) | 對 \(t=0,\dots,T\) 輸出 \([T+1]\) 個標量 |
| Reward Model \(r_\phi\) | ?(始終凍結) | \((x, y)\) | 標量獎勵 \(R = r_\phi(x, y)\) | \([1]\) |
注:\(L_x\) 是 prompt 長度,\(T\) 是生成回復長度(實際中常 padding 到固定 max_len)。
?? 核心計算邏輯
-
生成軌跡:對每個 prompt \(x\),用當前策略生成完整回復 \(y = (a_1, ..., a_T)\),形成狀態序列:
\[s_0 = x,\quad s_1 = x \oplus a_1,\quad \dots,\quad s_T = x \oplus y \] -
獲取最終獎勵:調用凍結的 Reward Model:
\[R = r_\phi(x, y) \](中間步驟無獎勵,即 \(r_t = 0\) for \(t < T\))
-
計算回報(Return):
\[\hat{R}_t = \sum_{k=t}^{T} \gamma^{k-t} r_k = \gamma^{T - t} R \]因為只有最后一步有獎勵?;貓笮蛄?\(\hat{R}_0, \hat{R}_1, ..., \hat{R}_T\) 構成目標值。
-
計算優勢(Advantage):
\[A_t = \hat{R}_t - V_\psi(s_t), \quad t = 0, 1, ..., T-1 \]表示:在狀態 \(s_t\) 下執行動作 \(a_t\),比“平均水平”好多少。
-
保存“舊”值:將當前策略的 log-prob 和評論家的 value detach,作為階段 2 的基準(即“old policy”和“old critic”)。
?? 偽代碼(階段 1)
trajectories = []
for x in prompts: # x: [L_x]
# 1. 用當前策略生成回復 y 和 log-prob
y, logprobs = policy_model.generate_with_logprobs(x) # y: [T], logprobs: [T]
# 2. 構建狀態序列 s_0 ... s_T
states = [torch.cat([x, y[:t]]) for t in range(len(y) + 1)] # len = T+1
# 3. 用當前評論家估計每個狀態的價值
values = torch.stack([critic_model(s) for s in states]) # [T+1]
# 4. 獎勵模型打分(僅最終獎勵)
R = reward_model(x, y) # scalar
# 5. 計算回報:R_t = γ^{T?t} * R
T_len = len(y)
returns = torch.zeros(T_len + 1)
returns[T_len] = R
for t in reversed(range(T_len)):
returns[t] = gamma * returns[t + 1]
# 6. 計算優勢:A_t = R_t - V(s_t),僅對 t=0..T-1 有效
advantages = returns[:-1] - values[:-1] # [T]
# 7. 保存“舊”值(detach 阻斷梯度)
trajectories.append({
'x': x,
'y': y,
'logprobs_old': logprobs.detach(), # [T]
'values_old': values.detach(), # [T+1]
'advantages': advantages.detach(), # [T]
'returns': returns.detach() # [T+1]
})
? 此階段結束時,我們得到一個固定的數據集,后續訓練將在此數據上多次迭代。
階段 2:策略與評論家更新(Policy & Critic Learning)
? 目標
利用階段 1 收集的固定軌跡數據,更新策略模型(Policy)和評論家模型(Critic),使得:
- 策略更傾向于選擇高優勢的動作;
- 評論家更準確地預測未來回報;
- 同時通過 PPO-clip 和 KL 正則防止策略突變或偏離合理語言分布。
?? 參與模型與接口
| 模型 | 是否更新 | 作用 |
|---|---|---|
| Policy Model \(\pi_\theta\) | ? | 被優化的主模型 |
| Critic Model \(V_\psi\) | ? | 被優化的價值估計器 |
| Reference Model \(\pi_{\text{ref}}\) | ?(始終凍結) | 提供 KL 正則基準(通常是 SFT 后的初始模型) |
| Reward Model | ? | 不參與此階段 |
?? 核心計算邏輯
-
策略損失(PPO-Clip)
定義概率比:\[r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\text{old}}(a_t | s_t)} = \exp\left( \log \pi_\theta(a_t|s_t) - \log \pi_{\text{old}}(a_t|s_t) \right) \]PPO 損失為:
\[\mathcal{L}^{\text{PPO}} = \mathbb{E}_t \left[ \min\left( r_t(\theta) A_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] \]- 若 \(A_t > 0\):鼓勵增加動作概率,但最多增加 \((1+\epsilon)\) 倍;
- 若 \(A_t < 0\):鼓勵減少概率,但最多減少到 \((1-\epsilon)\) 倍。
-
KL 散度正則(防止語言退化)
\[\mathcal{L}^{\text{KL}} = \beta \cdot \mathbb{E}_t \left[ \log \pi_\theta(a_t|s_t) - \log \pi_{\text{ref}}(a_t|s_t) \right] \]- \(\pi_{\text{ref}}\) 是凍結的 SFT 模型;
- \(\beta\) 控制正則強度(如 0.01~0.1)。
-
評論家損失(Value MSE)
\[\mathcal{L}^{\text{value}} = \mathbb{E}_t \left[ \left( V_\psi(s_t) - \hat{R}_t \right)^2 \right] \]- 目標是讓評論家準確預測階段 1 計算出的回報 \(\hat{R}_t\)。
-
總損失:
\[\mathcal{L}_{\text{total}} = -\mathcal{L}^{\text{PPO}} + \beta \cdot \text{KL} + c_1 \cdot \mathcal{L}^{\text{value}} \]
?? 偽代碼(階段 2)
for epoch in range(K): # K=2~4,對同一數據集多輪優化
for traj in trajectories:
x, y = traj['x'], traj['y'] # x: [L_x], y: [T]
logprobs_old = traj['logprobs_old'] # [T]
advantages = traj['advantages'] # [T]
returns = traj['returns'] # [T+1]
# --- 1. 策略損失 ---
logprobs_curr = policy_model.get_logprobs(x, y) # [T]
ratio = torch.exp(logprobs_curr - logprobs_old) # [T]
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
ppo_loss = -torch.mean(torch.min(surr1, surr2))
# KL 正則(ref_model 凍結)
with torch.no_grad():
logprobs_ref = ref_model.get_logprobs(x, y) # [T]
kl_loss = torch.mean(logprobs_curr - logprobs_ref)
policy_loss = ppo_loss + beta * kl_loss
# --- 2. 評論家損失 ---
states = [torch.cat([x, y[:t]]) for t in range(len(y) + 1)]
values_pred = torch.stack([critic_model(s) for s in states]) # [T+1]
value_loss = F.mse_loss(values_pred, returns)
# --- 3. 優化 ---
total_loss = policy_loss + c1 * value_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
?? 總結:PPO 的設計哲學
- 階段 1 是“探索”:用當前策略生成多樣回復,用外部信號(RM)和內部估計(Critic)打標簽;
- 階段 2 是“學習”:在固定數據上保守更新,通過 clip 和 KL 防止“學歪”;
- Reference Model 是安全網:確保語言依然流暢、合理;
- 整個流程可迭代:每輪 PPO 后,策略更強,下一輪采樣質量更高。
這種“采樣-學習”交替的模式,正是 PPO 能在 LLM 對齊中兼顧效果、穩定性和安全性的關鍵。
3. DPO:繞過 RL 的“聰明辦法”
DPO(Direct Preference Optimization)發現:其實不需要顯式訓練 Reward Model + PPO,可以直接從人類偏好數據中優化策略。
DPO 的核心洞察
人類偏好數據是成對的:\((x, y_w, y_l)\),其中:
- \(x\):用戶輸入(prompt)
- \(y_w\):人類偏好的回復(win)
- \(y_l\):較差的回復(lose)
DPO 證明:最大化人類偏好等價于最小化下面這個損失:
這個公式到底在算什么?
- \(\pi_\theta(y|x)\):當前訓練模型在 prompt \(x\) 下生成完整回復 \(y\) 的概率
→ 實際計算時,是把 \(y\) 拆成 token 序列,求 \(\prod_t \pi_\theta(y_t | x, y_{<t})\) - \(\pi_{\text{ref}}(y|x)\):參考模型(SFT 模型)生成 \(y\) 的概率
- \(\beta\):溫度參數,控制優化強度(越大越激進)
通俗理解:DPO 希望模型對“好回復”的相對概率(相比參考模型)比“壞回復”更高。
DPO 偽代碼
for batch in preference_data:
x, y_w, y_l = batch
# 計算當前模型和參考模型對兩個回復的 log 概率
logp_w = policy_model.log_prob(x, y_w)
logp_l = policy_model.log_prob(x, y_l)
ref_logp_w = ref_model.log_prob(x, y_w)
ref_logp_l = ref_model.log_prob(x, y_l)
# 計算 logits 差
logits = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
# 二分類損失:希望 logits 越大越好
loss = -F.logsigmoid(logits).mean()
optimizer.step(loss)
DPO 本質是一個帶參考模型的對比學習(contrastive learning),完全不需要 RL 循環,所以訓練快、穩定。
4. GRPO:在 PPO 和 DPO 之間找平衡
PPO vs DPO:各自的痛
| 方法 | 優點 | 缺點 |
|---|---|---|
| PPO | 支持 online learning(邊生成邊學),樣本利用率高;可結合多種獎勵(如安全性、事實性) | 需要訓練 4 個模型(Policy, Critic, RM, Reference),流程復雜;RM 質量直接影響效果 |
| DPO | 訓練簡單,只需 2 個模型(Policy + Reference);效果接近 PPO | 完全依賴離線(offline)偏好數據;容易過擬合(尤其數據少時);無法引入動態獎勵 |
GRPO:群體相對優化
GRPO(Group Relative Policy Optimization)的思路是:
既然人類經常面對多個選項做判斷(比如從 4 個回復中選最好的 2 個),那就直接建模這種“群體偏好”。
GRPO 的做法
- 對每個 prompt \(x\),用當前策略生成 \(K\) 個回復(比如 \(K=4\))
- 根據 Reward Model(或人類)將這些回復分成“好組”和“壞組”
- 優化目標:拉大組間差異,縮小組內差異
GRPO 損失函數(簡化版)
這其實是一個帶參考模型的 softmax 分類損失:希望“好回復”的歸一化概率更高。
GRPO 偽代碼
for x in prompts:
# 1. 生成 K 個回復
responses = [policy_model.generate(x) for _ in range(K)]
# 2. 用 RM 打分并分組(比如 top-2 為 good)
scores = [reward_model.score(x, y) for y in responses]
good_mask = get_top_k_mask(scores, k=2)
# 3. 計算每個回復的 log ratio
ratios = []
for y in responses:
logp = policy_model.log_prob(x, y)
ref_logp = ref_model.log_prob(x, y)
ratios.append(beta * (logp - ref_logp))
# 4. softmax 分類損失
logits = torch.stack(ratios)
loss = F.cross_entropy(logits.unsqueeze(0), target=good_mask)
optimizer.step(loss)
GRPO 的優勢:
- 保留了 PPO 的 online 生成能力(自己造數據)
- 像 DPO 一樣只優化策略模型,無需 Critic
- 對 RM 的依賴比 PPO 弱(只需排序,不要求絕對分數準確)
總結:選哪個?
| 方法 | 模型數量 | 是否需要 RM | 是否 RL | 適合場景 |
|---|---|---|---|---|
| PPO | 4(Policy, Critic, RM, Ref) | ? | ? | 高質量對齊,多目標獎勵 |
| DPO | 2(Policy, Ref) | ? | ? | 快速迭代,偏好數據充足 |
| GRPO | 3(Policy, RM, Ref) | ?(弱依賴) | ??(類 RL) | 平衡效率與效果,支持 online 學習 |
強化學習在 LLM 中,早已不是“必須用 PPO”的時代。DPO 讓對齊變得像 SFT 一樣簡單,GRPO 則試圖把 PPO 的靈活性和 DPO 的簡潔性結合起來。
技術在進化,我們的工具箱也在變豐富。選對方法,比盲目堆資源更重要。

浙公網安備 33010602011771號