強化學(xué)習(xí)系統(tǒng)性學(xué)習(xí)筆記(二):策略優(yōu)化的理論基礎(chǔ)與算法實現(xiàn)
策略優(yōu)化的理論基礎(chǔ)與算法實現(xiàn)
3.2 REINFORCE: 最早的策略梯度算法
在完成策略梯度定理的推導(dǎo)后,我們獲得了梯度的理論形式:
然而,這個期望本身仍然無法直接計算。我們面臨的根本問題是:軌跡空間是高維甚至連續(xù)無限的,無法枚舉所有可能的 \((s_0, a_0, s_1, a_1, \dots)\) 組合。策略優(yōu)化的實踐核心在于用有限采樣近似期望:與環(huán)境交互收集 \(N\) 條軌跡 \(\{\tau_1, \dots, \tau_N\}\),然后用經(jīng)驗平均估計梯度:
這就是 REINFORCE 算法(Williams, 1992)的核心思想。其訓(xùn)練流程為:
- 用當(dāng)前策略 \(\pi_\theta\) 采樣 \(N\) 條完整軌跡
- 對每條軌跡計算累積回報 \(G_t = \sum_{t'=t}^T r_{t'}\)(從時刻 \(t\) 到終止)
- 可選地引入固定 baseline \(b(s_t)\)(如所有軌跡的平均回報)
- 計算梯度并更新參數(shù):\(\theta \leftarrow \theta + \alpha \hat{g}\)
采樣帶來的根本挑戰(zhàn):方差問題
我們真正想要的是策略的平均性能,但只能通過有限采樣來估計。這引入了兩個核心要求:
- 無偏性(unbiased):采樣梯度的期望應(yīng)等于真實梯度
- 低方差(low variance):不同采樣批次的梯度應(yīng)相近
REINFORCE 滿足無偏性,但存在高方差問題。考慮一個簡單例子:
示例:訓(xùn)練語言模型回答醫(yī)療問題。
- Prompt: "如何緩解頭痛?"
- Response 1(軌跡1): "多喝水,適當(dāng)休息,必要時服用布洛芬。" → 獎勵 \(R_1 = 0.9\)
- Response 2(軌跡2): "頭痛可能由多種原因引起..." (啰嗦但正確) → 獎勵 \(R_2 = 0.6\)
- Response 3(軌跡3): "建議立即手術(shù)治療。" (錯誤) → 獎勵 \(R_3 = -0.8\)
即使這三條回復(fù)來自同一個策略,它們的回報差異巨大(\(0.9, 0.6, -0.8\))。用這些樣本計算的梯度會劇烈波動,導(dǎo)致:
- 需要大量軌跡(如 \(N=1000\))才能得到穩(wěn)定估計
- 訓(xùn)練過程緩慢且不穩(wěn)定
- 對于長對話(如 \(T=100\) 輪),方差會指數(shù)級增長
關(guān)鍵疑問:每次更新參數(shù)后策略就變了,那我是只用一條軌跡就更新嗎?
回答:不是。REINFORCE 的標(biāo)準(zhǔn)做法是:
- 用當(dāng)前策略 \(\pi_\theta\) 采樣 \(N\) 條軌跡(如 \(N=64\))
- 用這 \(N\) 條軌跡的平均梯度更新參數(shù)一次
- 更新后策略變?yōu)?\(\pi_{\theta'}\),之前的 \(N\) 條軌跡全部作廢
- 重新用 \(\pi_{\theta'}\) 采樣新的 \(N\) 條軌跡,重復(fù)上述過程
這就是 On-Policy 的含義:數(shù)據(jù)必須來自當(dāng)前策略,每次更新后舊數(shù)據(jù)失效,導(dǎo)致樣本效率極低。
3.3 Actor-Critic
REINFORCE 的高方差源于用 Monte Carlo 回報 \(G_t\)(需要完整軌跡)。如果能用一個學(xué)習(xí)出來的函數(shù)估計未來回報,就可以:
- 降低方差(函數(shù)估計比單次采樣穩(wěn)定)
- 支持單步更新(不需要等軌跡結(jié)束)
這就是 Actor-Critic 框架的核心思想:引入 Critic 網(wǎng)絡(luò) \(V_\phi(s)\) 估計狀態(tài)價值,用它構(gòu)造低方差的優(yōu)勢函數(shù)。
雙網(wǎng)絡(luò)架構(gòu)
系統(tǒng)維護兩個神經(jīng)網(wǎng)絡(luò):
- Actor \(\pi_\theta(a|s)\):策略網(wǎng)絡(luò),負責(zé)生成動作
- Critic \(V_\phi(s)\):價值網(wǎng)絡(luò),評估狀態(tài)的好壞
訓(xùn)練目標(biāo):
-
Critic 的更新:學(xué)習(xí)預(yù)測真實回報
\[\mathcal{L}_{\text{critic}} = \mathbb{E}\left[(V_\phi(s_t) - G_t)^2\right] \]其中 \(G_t\) 是實際觀察到的累積回報(監(jiān)督信號)。
-
Actor 的更新:用 Critic 估計的優(yōu)勢調(diào)整策略
\[\mathcal{L}_{\text{actor}} = -\mathbb{E}\left[\log \pi_\theta(a_t|s_t) \cdot A_t\right] \]其中優(yōu)勢函數(shù) \(A_t = G_t - V_\phi(s_t)\) 衡量動作相對于平均水平的好壞。
關(guān)鍵實現(xiàn)細節(jié):計算優(yōu)勢時必須阻斷梯度:
advantage = reward - value.detach() # ? 阻斷梯度回傳
這確保 Actor 的更新不會干擾 Critic 的學(xué)習(xí)目標(biāo)。
單步更新的進階:TD 誤差
在 Actor-Critic (AC) 框架中,我們可以使用 TD (Temporal Difference) 誤差 來替代傳統(tǒng)的 Monte Carlo 回報,從而實現(xiàn)單步更新。
TD 優(yōu)勢的定義如下:
與 Monte Carlo 方法對比:
-
Monte Carlo 優(yōu)勢 (\(A_t^{MC}\)):
- 公式:\(A_t^{MC} = G_t - V(s_t)\)
- 特點:需要運行完整個軌跡才能計算,是無偏估計,但通常具有很高的方差。
-
TD 優(yōu)勢 (\(A_t^{TD}\)):
- 公式:\(A_t^{TD} = \delta_t\)
- 特點:只需要一步(single-step transition)即可計算,方差較低,但是一個有偏估計(其準(zhǔn)確性依賴于價值函數(shù) \(V\) 的估計精度)。
3.4 GAE (Generalized Advantage Estimation) 的推導(dǎo)
1. 真實的優(yōu)勢函數(shù)
我們首先定義一個理論上“真實”的優(yōu)勢函數(shù),它使用實際的未來回報 \(G_t\):
我們的目標(biāo)是使用一系列的 TD 誤差 \(\delta\) 來構(gòu)造一個對這個“真優(yōu)勢”的良好估計。
2. 基于 Bellman 方程的展開
根據(jù) Bellman 遞推公式,任意時刻的回報 \(G_t\) 可以展開為:
將其代入真實優(yōu)勢的定義中:
為了引入 TD 誤差 \(\delta_t\),我們在上式中同時加上和減去 \(\gamma V(s_{t+1})\):
觀察上式,我們可以發(fā)現(xiàn):
- 第一個方括號內(nèi)的部分正好是 TD 誤差 \(\delta_t\)。
- 第二個方括號內(nèi)的部分是下一時刻的真實優(yōu)勢 \(A_{t+1}^{\text{true}}\)。
于是,我們得到了一個關(guān)于真實優(yōu)勢的遞歸關(guān)系:
3. 遞歸展開與關(guān)鍵結(jié)論
將上述遞歸關(guān)系不斷展開,可以得到:
關(guān)鍵結(jié)論:真實的優(yōu)勢函數(shù),等于所有未來 TD 誤差的折扣加權(quán)和。
這個結(jié)論非常直觀:
- \(\delta_t\) 代表當(dāng)前這一步?jīng)Q策帶來的“驚喜”或“估計誤差”。
- \(\delta_{t+1}, \delta_{t+2}, \dots\) 代表未來每一步的誤差。
- 折扣因子 \(\gamma\) 確保了越遙遠的未來,其誤差對當(dāng)前優(yōu)勢的影響越小。
GAE 的核心思想:偏差-方差的權(quán)衡
問題與動機
雖然上述展開式在理論上很完美,但在實踐中存在兩個問題:
- 依賴完整軌跡:它依然需要未來所有的 \(\delta\) 值,這意味著必須等到整個回合(episode)結(jié)束后才能計算,這本質(zhì)上是 Monte Carlo 風(fēng)格的估計,方差很大。
- 誤差累積:我們不希望使用過長的序列,因為未來的不確定性高,價值函數(shù)的估計誤差會不斷累積。
我們需要在“充分利用未來信息”和“抑制噪聲(降低方差)”之間找到一個平衡點。
引入 \(\lambda\):偏差-方差的平衡因子
GAE 的核心思想是引入一個衰減系數(shù) \(\lambda\) (通常取值在 0.9 到 0.99 之間),用它來控制未來 TD 誤差的權(quán)重。
GAE 的定義:
- \(\gamma\):環(huán)境的獎勵折扣因子,反映了任務(wù)本身對未來的重視程度。
- \(\lambda\):優(yōu)勢函數(shù)的折扣因子,是我們用來控制偏差-方差權(quán)衡的人為超參數(shù)。
- \(\delta\):每一步的 TD 誤差。
理解 \(\lambda\) 的作用
-
當(dāng) \(\lambda = 0\) 時:
\(A_t = \delta_t\)
這等價于傳統(tǒng)的 TD(0) 誤差,只考慮一步信息。這種方法偏差最大,但方差最小。 -
當(dāng) \(\lambda = 1\) 時:
\(A_t = \sum_{l=0}^{\infty} \gamma^l \delta_{t+l} = G_t - V(s_t)\)
這恢復(fù)了原始的展開式,等價于 Monte Carlo 方法。這種方法無偏,但方差最大。 -
當(dāng) \(\lambda \in (0,1)\) 時:
GAE 在 TD 和 Monte Carlo 之間進行插值。未來的 \(\delta\) 權(quán)重會以 \((\gamma\lambda)\) 的速率衰減,實現(xiàn)了在“看得多遠”與“抑制噪聲”之間的平滑過渡。
GAE 的計算與實現(xiàn)
上述求和公式可以轉(zhuǎn)化為一個高效的反向遞推形式,非常適合在代碼中實現(xiàn)。
GAE 遞推公式:
這個計算過程類似于循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)中的反向傳播,我們從軌跡的末端開始,反向遍歷計算每一時刻的優(yōu)勢值。
偽代碼示例:
advantages = torch.zeros_like(rewards)
gae = 0
# 從后往前遍歷時間步
for t in reversed(range(T)):
# 1. 計算當(dāng)前步的 TD 誤差 delta
delta = rewards[t] + gamma * values[t+1] - values[t]
# 2. 使用遞推公式計算 gae
gae = delta + gamma * lam * gae
# 3. 存儲當(dāng)前步的優(yōu)勢值
advantages[t] = gae
注意:
- 計算必須反向遍歷時間,因為 \(A_t\) 依賴于未來的 \(A_{t+1}\)。
values[t+1]是 Critic 網(wǎng)絡(luò)對下一狀態(tài)的價值預(yù)測。- 這個高效的計算方法是 PPO、A2C、A3C 等現(xiàn)代強化學(xué)習(xí)算法的標(biāo)準(zhǔn)組成部分。
GAE 與 n-step TD 的關(guān)系
GAE 還可以被看作是所有 n-step TD 優(yōu)勢估計 的指數(shù)加權(quán)平均:
其中,n-step 優(yōu)勢 \(A_t^{(n)}\) 的定義為:
總結(jié)來說:
- \(\lambda\) 決定了我們將多少不同長度(n-step)的 TD 估計綜合在一起。
- 較小的 \(\lambda\) 更側(cè)重于短期的、偏差較大的估計。
- 較大的 \(\lambda\) 更側(cè)重于長期的、方差較大的估計。
- 在實踐中,\(\lambda=0.95\) 通常是一個很好的經(jīng)驗?zāi)J值。
3.5 On-Policy 的困境與重要性采樣
樣本效率的致命弱點
前述所有算法(REINFORCE, AC, A2C/A3C)都是 On-Policy:梯度計算要求數(shù)據(jù)來自當(dāng)前策略 \(\pi_\theta\)。這導(dǎo)致:
- 每次更新后,\(\pi_\theta\) 改變,舊數(shù)據(jù)立即失效
- 對于 LLM,生成一次回復(fù)需要數(shù)秒,但只能用一次就丟棄
- 訓(xùn)練 100 萬步需要采樣 100 萬條新數(shù)據(jù)
量化對比(以 Qwen-7B 為例):
| 方法 | 單次采樣耗時 | 數(shù)據(jù)復(fù)用 | 訓(xùn)練 1000 步總耗時 |
|---|---|---|---|
| On-Policy | 3 秒 | 1 次 | 3000 秒 |
| Off-Policy(PPO) | 3 秒 | 4 次 | 750 秒 |
重要性采樣:Off-Policy 的數(shù)學(xué)工具
核心問題:能否用舊策略 \(\pi_{\text{old}}\) 的數(shù)據(jù)訓(xùn)練新策略 \(\pi_\theta\)?
數(shù)學(xué)原理(重要性采樣定理):對于任意函數(shù) \(f(x)\),
證明(簡單積分變換):
應(yīng)用到策略梯度:
原目標(biāo)是 \(\mathbb{E}_{a \sim \pi_\theta}[\nabla \log \pi_\theta \cdot A]\),但數(shù)據(jù)來自 \(\pi_{\text{old}}\),引入比率修正:
進一步簡化(利用 \(\nabla \log \pi = \pi^{-1} \nabla \pi\)),可將目標(biāo)函數(shù)寫為:
醫(yī)療問答示例:
- 舊策略生成:"多喝水,休息"(概率 \(\pi_{\text{old}} = 0.3\))
- 新策略評估該回復(fù):\(\pi_\theta = 0.5\)(更傾向此回答)
- 優(yōu)勢 \(A = 0.8\)(好回答)
- 修正后的梯度貢獻:\(\frac{0.5}{0.3} \times 0.8 = 1.33\)
關(guān)鍵挑戰(zhàn):如果比率 \(r = \frac{\pi_\theta}{\pi_{\text{old}}}\) 過大(如 10),說明新舊策略差異巨大,重要性采樣失效,梯度估計方差爆炸。需要約束策略更新幅度。
3.6 TRPO: 信賴域約束下的策略優(yōu)化
優(yōu)化目標(biāo)的理論保證
TRPO(Schulman et al., 2015)的核心思想:在限制策略變化的前提下最大化性能提升。
優(yōu)化問題:
KL 散度約束衡量兩個分布的差異:
直觀理解:
- 目標(biāo)函數(shù):最大化性能(用舊數(shù)據(jù)評估新策略)
- 約束條件:KL 散度 \(\leq \delta\)(如 0.01),確保新策略不偏離太遠
醫(yī)療問答示例:
- 舊策略分布:P("多喝水")=0.3, P("休息")=0.4, P("吃藥")=0.3
- 新策略分布:P("多喝水")=0.5, P("休息")=0.35, P("吃藥")=0.15
計算 KL 散度:
如果 \(\delta=0.05\),則該更新違反約束,需要縮小更新步長。
實現(xiàn)方法:二階優(yōu)化
TRPO 用共軛梯度法求解帶約束的優(yōu)化問題,需要計算 Hessian 矩陣(目標(biāo)函數(shù)的二階導(dǎo)數(shù))。雖然理論保證強(單調(diào)改進),但計算復(fù)雜度高,實現(xiàn)困難,調(diào)參敏感。
3.7 PPO
PPO(Schulman et al., 2017)用一階優(yōu)化 + 巧妙的目標(biāo)函數(shù)設(shè)計達到 TRPO 的效果,成為深度 RL 和 RLHF 的標(biāo)準(zhǔn)算法。
3.7.1 PPO-Clip: 用裁剪替代 KL 約束
核心思想:不顯式約束 KL 散度,而是直接限制比率 \(r_t = \frac{\pi_\theta(a|s)}{\pi_{\text{old}}(a|s)}\) 的變化范圍。
目標(biāo)函數(shù):
其中 \(\text{clip}(r, 1-\epsilon, 1+\epsilon)\) 將 \(r\) 限制在 \([1-\epsilon, 1+\epsilon]\)(通常 \(\epsilon=0.2\))。
逐項分析:
情況 1: 優(yōu)勢 \(A_t > 0\)(好動作,希望增加概率)
- 如果 \(r_t < 1+\epsilon\):正常梯度,繼續(xù)增加 \(\pi_\theta(a|s)\)
- 如果 \(r_t > 1+\epsilon\):被裁剪為 \(1+\epsilon\),停止增加(防止過度優(yōu)化)
情況 2: 優(yōu)勢 \(A_t < 0\)(壞動作,希望減少概率)
- 如果 \(r_t > 1-\epsilon\):正常梯度,繼續(xù)減少 \(\pi_\theta(a|s)\)
- 如果 \(r_t < 1-\epsilon\):被裁剪為 \(1-\epsilon\),停止減少(防止過度懲罰)
醫(yī)療問答示例(具體計算):
- Prompt: "如何緩解頭痛?"
- Response: "多喝水,適當(dāng)休息"
- 舊策略: \(\pi_{\text{old}}(response|prompt) = 0.01\)(log prob = -4.6)
- 新策略: \(\pi_{\theta}(response|prompt) = 0.03\)(log prob = -3.5)
- 優(yōu)勢: \(A = 0.8\)(好回答)
- 比率: \(r = \frac{0.03}{0.01} = 3.0\)
PPO 處理(設(shè) \(\epsilon=0.2\)):
原始項: r * A = 3.0 * 0.8 = 2.4
裁剪項: clip(3.0, 0.8, 1.2) * A = 1.2 * 0.8 = 0.96
最終: min(2.4, 0.96) = 0.96 ← 被裁剪!
解讀:雖然新策略概率增加了 3 倍,但 PPO 只允許增加到 1.2 倍的幅度,防止策略突變。
3.7.2 PPO-KL: 自適應(yīng)懲罰
另一種變體直接在目標(biāo)中加入 KL 懲罰:
自適應(yīng) \(\beta\):
- 如果 \(\text{KL} > 1.5 \times \text{target}\):增大 \(\beta\)(加強懲罰)
- 如果 \(\text{KL} < 0.5 \times \text{target}\):減小 \(\beta\)(放松約束)
實踐中 PPO-Clip 更常用,因為無需調(diào)節(jié) \(\beta\)。
3.7.3 PPO-Clip 完整訓(xùn)練流程
關(guān)鍵特性:數(shù)據(jù)復(fù)用 \(K\) 次(\(K=4 \sim 10\))
for iteration in range(總迭代次數(shù)):
# 1. 采樣階段(執(zhí)行 1 次)
用當(dāng)前策略 π_θ 采樣 N 條軌跡
記錄 old_log_probs = log π_θ(a|s) # 保存!
# 2. 計算優(yōu)勢(用 GAE)
用 Critic 估計 V(s)
計算 advantages = GAE(rewards, values)
# 3. 多輪 mini-batch 更新(數(shù)據(jù)復(fù)用 K 次)
for epoch in range(K): # K=4
for batch in minibatch(trajectories):
# 重新計算新策略概率
new_log_probs = log π_θ(a|s) # 策略已更新!
# 計算比率
ratio = exp(new_log_probs - old_log_probs)
# PPO-Clip loss
loss_clip = -min(ratio * A, clip(ratio, 1-ε, 1+ε) * A)
# 價值函數(shù) loss
loss_vf = (V(s) - returns)2
# 總損失
loss = loss_clip + c_vf * loss_vf
# 梯度更新
optimizer.step()
關(guān)鍵點:
old_log_probs在 \(K\) 輪更新中保持不變(來自采樣時的策略)new_log_probs每次都重新計算(因為參數(shù)在變)- 數(shù)據(jù)復(fù)用 4 次后,重新采樣新數(shù)據(jù)
參加參數(shù)設(shè)置


浙公網(wǎng)安備 33010602011771號