A Concise Walkthrough of the PLE Model

Basic setup
- 2 tasks: CTR and CVR
- 1 PLE layer (num_levels = 1)
- 2 task-specific experts per task (specific_expert_num = 2)
- 1 shared expert (shared_expert_num = 1)
- The input embedding is a concatenated vector of shape [batch_size, 64]
Let's walk through how the data flows at each step inside this single layer.
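For the sketches that follow, assume a dummy embedding tensor with the shapes above (the batch size B = 256 and the tensor x are illustrative assumptions, not part of the original setup):

import torch

B = 256                  # assumed batch size, for illustration only
x = torch.randn(B, 64)   # concatenated input embedding, shape [B, 64]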
Step 1: Prepare the inputs
ple_inputs = [x_ctr, x_cvr, x_shared]
- x_ctr = the CTR input = the raw embedding vector [B, 64]
- x_cvr = the CVR input = the same
- x_shared = the shared input = the same
Note: these three vectors are identical at the first layer, but they diverge in later layers.
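A minimal sketch of this step, reusing the dummy tensor x from above: at the first layer every branch simply receives the same embedding.

# Layer 1: all three branches see the same raw embedding.
# In deeper layers each branch instead receives its own fused output.
x_ctr, x_cvr, x_shared = x, x, x
ple_inputs = [x_ctr, x_cvr, x_shared]   # each entry is [B, 64]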
Step 2: Task-specific experts and the shared expert
Experts for each task:
Each task has 2 specific experts, each fed with that task's own input:
- CTR's two experts → input x_ctr → output [B, 64]
- CVR's two experts → input x_cvr → output [B, 64]
Shared experts:
There is only 1 shared expert; its input is x_shared and its output is [B, 64] (see the sketch below).
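A minimal sketch of this step, reusing the Expert module defined in the PyTorch implementation further below; expert_dim = 64 is an assumption so the outputs keep the [B, 64] shape.

expert_1_ctr_net = Expert(input_dim=64, expert_dim=64)   # CTR-specific expert 1
expert_2_ctr_net = Expert(input_dim=64, expert_dim=64)   # CTR-specific expert 2
shared_expert_net = Expert(input_dim=64, expert_dim=64)  # the single shared expert

expert_1_ctr = expert_1_ctr_net(x_ctr)        # [B, 64]
expert_2_ctr = expert_2_ctr_net(x_ctr)        # [B, 64]
expert_shared = shared_expert_net(x_shared)   # [B, 64]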
Step 3: The gate network
Let's look at how the gate for the CTR task works.
What does the CTR gate do?
- Input: x_ctr → shape [B, 64]
- Pass it through a small DNN: the output becomes [B, H] (e.g. H = 32)
- Apply a linear transform + softmax: the output is [B, 3], i.e. the weights over 3 experts:
  - expert_1_ctr
  - expert_2_ctr
  - expert_shared
In Keras-style pseudocode (DNN(...) stands for the small hidden MLP):
gate_input = DNN(...)(x_ctr) # [B, 32]
gate_weights = Dense(3, activation='softmax')(gate_input) # [B, 3]
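The same gate in PyTorch, as a minimal sketch: the hidden size H = 32 and the explicit two-layer structure are assumptions for illustration; the Gate class in the implementation below collapses this into a single Linear + Softmax.

import torch.nn as nn

ctr_gate = nn.Sequential(
    nn.Linear(64, 32),    # small DNN: [B, 64] -> [B, 32]
    nn.ReLU(),
    nn.Linear(32, 3),     # linear transform: one logit per expert
    nn.Softmax(dim=-1),   # [B, 3] weights over expert_1_ctr, expert_2_ctr, expert_shared
)
gate_weights = ctr_gate(x_ctr)   # [B, 3]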
Step 4: Gate × Experts
Stack all the expert outputs:
expert_outputs = tf.stack([expert_1_ctr, expert_2_ctr, expert_shared], axis=1) # [B, 3, 64]
Reshape the gate weights:
gate_weights = tf.expand_dims(gate_weights, -1) # [B, 3, 1]
Element-wise multiply and sum for the weighted combination:
fused_output = tf.reduce_sum(expert_outputs * gate_weights, axis=1) # [B, 64]
This is the feature the CTR task extracts at this layer: a dynamic combination of its own experts and the shared expert.
The CVR task works in exactly the same way, except that it starts from x_cvr and builds its own gate and expert fusion (a PyTorch version of the fusion is sketched below).
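The same fusion written in PyTorch, continuing the tensors from the sketches above (a sketch, not library code):

expert_outputs = torch.stack([expert_1_ctr, expert_2_ctr, expert_shared], dim=1)  # [B, 3, 64]
weights = gate_weights.unsqueeze(-1)                                              # [B, 3, 1]
fused_output_ctr = torch.sum(expert_outputs * weights, dim=1)                     # [B, 64]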
The next layer (if there is one):
These outputs (fused_output_ctr, fused_output_cvr, fused_output_shared) then become the inputs to the next layer, and the same mechanism repeats.
Every layer builds its own:
- expert networks (kept separate per task)
- gates (conditioned on that layer's inputs)
which achieves progressive, layer-by-layer refinement of the representations.
The essence of the gate:
The gate is a small DNN whose input is the current task's embedding and whose output is a softmax weight vector over all of that task's experts.
It decides which experts this task should listen to right now.
A PyTorch implementation of PLE
import torch
import torch.nn as nn


class Expert(nn.Module):
    def __init__(self, input_dim, expert_dim):
        super(Expert, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_dim, expert_dim),
            nn.ReLU(),
            nn.BatchNorm1d(expert_dim),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        return self.layer(x)


class Gate(nn.Module):
    def __init__(self, input_dim, n_experts):
        super(Gate, self).__init__()
        self.gate = nn.Sequential(
            nn.Linear(input_dim, n_experts),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        weights = self.gate(x)        # [B, n_experts]
        return weights.unsqueeze(-1)  # [B, n_experts, 1]


class PLELayer(nn.Module):
    def __init__(self, input_dim, expert_dim, n_tasks, n_task_experts, n_shared_experts):
        super(PLELayer, self).__init__()
        self.n_tasks = n_tasks
        # Task-specific experts: n_task_experts per task
        self.task_experts = nn.ModuleList([
            nn.ModuleList([Expert(input_dim, expert_dim) for _ in range(n_task_experts)])
            for _ in range(n_tasks)
        ])
        # Shared experts, visible to every task
        self.shared_experts = nn.ModuleList([
            Expert(input_dim, expert_dim) for _ in range(n_shared_experts)
        ])
        # One gate per task: weights over (own experts + shared experts)
        self.task_gates = nn.ModuleList([
            Gate(input_dim, n_task_experts + n_shared_experts)
            for _ in range(n_tasks)
        ])
        # Shared gate: weights over all experts, used to build the next layer's shared input
        self.shared_gate = Gate(input_dim, n_tasks * n_task_experts + n_shared_experts)

    def forward(self, task_inputs, shared_input):
        # Compute expert outputs
        task_outputs = []
        for i in range(self.n_tasks):
            task_outputs.append([expert(task_inputs[i]) for expert in self.task_experts[i]])
        shared_outputs = [expert(shared_input) for expert in self.shared_experts]

        # Task-specific gate outputs
        next_task_inputs = []
        for i in range(self.n_tasks):
            all_expert_outputs = task_outputs[i] + shared_outputs
            stacked = torch.stack(all_expert_outputs, dim=1)  # [B, n_experts, D]
            weights = self.task_gates[i](task_inputs[i])      # [B, n_experts, 1]
            fused = torch.sum(stacked * weights, dim=1)       # [B, D]
            next_task_inputs.append(fused)

        # Shared gate output (for next layer's shared input)
        flat_all_experts = sum(task_outputs, []) + shared_outputs
        stacked_shared = torch.stack(flat_all_experts, dim=1)
        shared_weights = self.shared_gate(shared_input)
        next_shared_input = torch.sum(stacked_shared * shared_weights, dim=1)  # [B, D]

        return next_task_inputs, next_shared_input


class PLE(nn.Module):
    # Handles the per-layer input dimensions correctly across multiple layers
    def __init__(self, input_dim, expert_dim, n_tasks=3, n_layers=2,
                 n_task_experts=2, n_shared_experts=1):
        super(PLE, self).__init__()
        self.n_tasks = n_tasks
        self.ple_layers = nn.ModuleList()
        # Set the correct input dimension for each layer
        for layer_idx in range(n_layers):
            if layer_idx == 0:
                # First layer: use the raw input dimension
                current_input_dim = input_dim
            else:
                # Later layers: the expert output dimension becomes the input dimension
                current_input_dim = expert_dim
            self.ple_layers.append(
                PLELayer(
                    input_dim=current_input_dim,  # input dimension set per layer
                    expert_dim=expert_dim,
                    n_tasks=n_tasks,
                    n_task_experts=n_task_experts,
                    n_shared_experts=n_shared_experts
                )
            )

    def forward(self, x):
        # Initial input: shared across all tasks and the shared experts
        task_inputs = [x for _ in range(self.n_tasks)]
        shared_input = x
        for layer in self.ple_layers:
            task_inputs, shared_input = layer(task_inputs, shared_input)
        return task_inputs  # final task-specific vectors [task1_repr, task2_repr, task3_repr]
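A minimal usage sketch under the assumptions of this post (64-dim input, 64-dim experts, 2 tasks, 1 layer). The per-task towers that would normally sit on top of the returned representations, e.g. a small MLP plus sigmoid for CTR and CVR, are not part of the code above.

model = PLE(input_dim=64, expert_dim=64, n_tasks=2, n_layers=1,
            n_task_experts=2, n_shared_experts=1)
x = torch.randn(256, 64)        # [B, 64] concatenated embedding
ctr_repr, cvr_repr = model(x)   # each task gets its own [B, 64] representation
print(ctr_repr.shape, cvr_repr.shape)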