強(qiáng)化學(xué)習(xí) 動(dòng)作空間(離散/連續(xù))
1. 離散動(dòng)作空間的策略網(wǎng)絡(luò)
在離散空間中,動(dòng)作是可數(shù)的,例如:{左, 右, 上, 下} 或 {加速, 剎車}。
網(wǎng)絡(luò)架構(gòu)與處理方式
-
輸出層:Softmax
-
策略網(wǎng)絡(luò)的最后一層是一個(gè) Softmax 層。
-
假設(shè)有
N個(gè)可選動(dòng)作,網(wǎng)絡(luò)會(huì)輸出一個(gè)長(zhǎng)度為N的向量。 -
Softmax 函數(shù)確保這個(gè)向量的所有元素都在 (0, 1) 之間,且和為 1。這樣,每個(gè)元素就代表了選擇對(duì)應(yīng)動(dòng)作的概率。
-
-
策略表示
-
策略
π(a|s)直接由網(wǎng)絡(luò)輸出給出:π(a=i|s) = Softmax(Logits(s))[i]
-
-
動(dòng)作采樣
-
根據(jù)網(wǎng)絡(luò)輸出的概率分布,進(jìn)行分類采樣來(lái)選擇動(dòng)作。
-
在 Python 中,可以使用
np.random.choice或torch.distributions.Categorical。
-
import torch import torch.nn as nn import torch.nn.functional as F class DiscretePolicyNetwork(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(DiscretePolicyNetwork, self).__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) # output_dim = 動(dòng)作數(shù)量 def forward(self, state): x = F.relu(self.fc1(state)) logits = self.fc2(x) # 輸出 logits,未歸一化的概率 return logits def act(self, state): logits = self.forward(state) # 創(chuàng)建分類分布 action_probs = F.softmax(logits, dim=-1) dist = torch.distributions.Categorical(action_probs) # 采樣動(dòng)作 action = dist.sample() # 計(jì)算對(duì)數(shù)概率,用于策略梯度更新 log_prob = dist.log_prob(action) return action.detach().item(), log_prob # 假設(shè)有4個(gè)動(dòng)作 policy_net = DiscretePolicyNetwork(input_dim=8, hidden_dim=128, output_dim=4) state = torch.tensor([0.1, 0.5, -0.2, ...]) # 狀態(tài)向量 action, log_prob = policy_net.act(state) print(f"Sampled action: {action}")
2. 連續(xù)動(dòng)作空間的策略網(wǎng)絡(luò)
在連續(xù)空間中,動(dòng)作是實(shí)數(shù)向量,例如:方向盤轉(zhuǎn)角 [-1, 1],機(jī)器人關(guān)節(jié)扭矩 [τ?, τ?, ...]。
這里有兩種主要設(shè)計(jì)思路:
A. 隨機(jī)策略 - 輸出分布參數(shù)
這是最常用的方法,策略網(wǎng)絡(luò)輸出一個(gè)概率分布的參數(shù),動(dòng)作從這個(gè)分布中采樣。
-
輸出層:分布參數(shù)
-
最常用的是高斯分布。網(wǎng)絡(luò)為每個(gè)動(dòng)作維度輸出兩個(gè)值:
-
均值:通常使用
tanh作為激活函數(shù),將均值限制在[-1, 1]范圍內(nèi),或者不適用激活函數(shù)。 -
標(biāo)準(zhǔn)差:通常使用
softplus等函數(shù)確保其為正數(shù)。也可以是一個(gè)與狀態(tài)無(wú)關(guān)的可學(xué)習(xí)參數(shù)。
-
-
-
策略表示
-
策略
π(a|s)是一個(gè)概率密度函數(shù)。例如,對(duì)于高斯分布:a ~ N(μ(s), σ(s)2)
-
-
動(dòng)作采樣
-
使用網(wǎng)絡(luò)輸出的均值和標(biāo)準(zhǔn)差構(gòu)建一個(gè)高斯分布,然后從這個(gè)分布中采樣。
-
由于采樣操作不可導(dǎo),在訓(xùn)練時(shí)需要使用重參數(shù)化技巧。
-
class ContinuousPolicyNetwork(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(ContinuousPolicyNetwork, self).__init__() self.output_dim = output_dim # 動(dòng)作空間的維度 self.fc1 = nn.Linear(input_dim, hidden_dim) # 輸出均值 self.mean_head = nn.Linear(hidden_dim, output_dim) # 輸出對(duì)數(shù)標(biāo)準(zhǔn)差(更穩(wěn)定),通常作為一個(gè)獨(dú)立的層 self.log_std_head = nn.Linear(hidden_dim, output_dim) # 或者:self.log_std = nn.Parameter(torch.zeros(1, output_dim)) def forward(self, state): x = F.relu(self.fc1(state)) mean = torch.tanh(self.mean_head(x)) # 將均值限制在[-1,1] log_std = self.log_std_head(x) # 使用 clamp 將標(biāo)準(zhǔn)差限制在一個(gè)合理范圍內(nèi) log_std = torch.clamp(log_std, min=-20, max=2) std = torch.exp(log_std) return mean, std def act(self, state): mean, std = self.forward(state) # 創(chuàng)建多元高斯分布(假設(shè)各維度獨(dú)立) dist = torch.distributions.Normal(mean, std) # 重參數(shù)化技巧采樣 action = dist.rsample() # 計(jì)算對(duì)數(shù)概率(對(duì)于多維動(dòng)作,需要對(duì)數(shù)概率的和) log_prob = dist.log_prob(action).sum(dim=-1) # 如果需要將動(dòng)作限制在[-1,1],可以使用tanh,但需要修正對(duì)數(shù)概率 # action = torch.tanh(raw_action) # 更復(fù)雜的實(shí)現(xiàn)會(huì)處理tanh變換后的概率計(jì)算 return action.detach().numpy(), log_prob # 假設(shè)動(dòng)作是2維的(如:速度,方向) policy_net = ContinuousPolicyNetwork(input_dim=8, hidden_dim=128, output_dim=2) state = torch.tensor([0.1, 0.5, -0.2, ...]) action, log_prob = policy_net.act(state) print(f"Sampled continuous action: {action}")
torch.clamp 將輸入張量中的所有元素限制在一個(gè)指定的區(qū)間 [min, max] 內(nèi)。具體來(lái)說(shuō):
-
如果元素小于
min,則將其設(shè)置為min -
如果元素大于
max,則將其設(shè)置為max -
如果元素在
[min, max]范圍內(nèi),則保持不變
tanh函數(shù):

torch.distributions.Normal 表示一個(gè)一元高斯分布,由兩個(gè)參數(shù)定義:
-
loc: 分布的均值 -
scale: 分布的標(biāo)準(zhǔn)差
# 創(chuàng)建分布 mean = torch.tensor([0.0, 1.0]) std = torch.tensor([1.0, 0.5]) normal = dist.Normal(mean, std) # 1. sample() - 普通采樣 samples = normal.sample() print("Sample:", samples) # 輸出: tensor([-0.1234, 1.2345]) # 2. rsample() - 重參數(shù)化采樣(可微分) reparam_samples = normal.rsample() print("Reparameterized sample:", reparam_samples) # 輸出: tensor([0.5678, 0.8765]) # 3. sample() 批量采樣 batch_samples = normal.sample((3,)) # 采樣3次 print("Batch samples shape:", batch_samples.shape) # 輸出: torch.Size([3, 2])

浙公網(wǎng)安備 33010602011771號(hào)