先訓練G:
先不計算D的梯度: 判別器輸入類型為(源域,0)或者(目標域,1),輸出圖片為真實圖片(源域)的概率值
for param in model_D.parameters(): # model_D = nn.ModuleList([FCDiscriminator...]) 判別器是一個全卷積網絡,其實就是一個二分類,輸出一個條件概率,即輸入樣本屬于源域或者目標域的概率
param.requires_grad = False 判別損失 Ld 是一個二分類交叉熵損失,判斷輸入屬于源域還是目標域
怎么才算訓練好判別器:判別器能對真圖打高分,對假圖打低分
輸入圖片:
images.size: torch.Size([1, 3, 512, 1024])
labels.size: torch.Size([1, 512, 1024])
源域圖片S 的輸出分割特征圖:
feat_source: ([1, 2048, 65, 129])
pred_source: ([1, 19, 65, 129])
輸出特征圖接一個上采樣后 pred_source 大小變成: ([1, 19, 512, 1024])
計算交叉熵損失:
loss_seg = seg_loss(pred_source, labels)
計算梯度值,并反傳梯度值: (只是計算,不更新)
loss_seg.backward()
目標域圖片T的大小、特征圖大小 和上面的源域S一樣,不同的是,經過分割網絡時,得到一個加權的特征圖(注:加權后的特征圖大小不變)
和S一樣,得到特征圖后,接一個上采樣:
pred_target = interp_target(pred_target)
先損失清零
loss_adv = 0
然后計算判別損失值,即對倒數第二層的T域特征圖打分
D_out = model_D[0](feat_target) (判別器D[0]輸入通道為2048,輸出通道為1)
再用上面的判別損失值來計算對抗損失,即用bce_loss(即均方差MSELoss())來計算D_out和source_label的分布差
loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device)) # source_label=0
先對最后一層的T域特征圖打分:特征圖先變成概率圖(用softmax()),然后對概率圖打分
D_out = model_D[1](F.softmax(pred_target, dim=1)) (判別器D[1]輸入19,輸出1)
然后計算對抗損失:
loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
loss_adv = loss_adv * 0.01
計算梯度值,并將梯度反傳:
loss_adv.backward()
更新模型參數:
optimizer.step()
再訓練D: