<output id="qn6qe"></output>

    1. <output id="qn6qe"><tt id="qn6qe"></tt></output>
    2. <strike id="qn6qe"></strike>

      亚洲 日本 欧洲 欧美 视频,日韩中文字幕有码av,一本一道av中文字幕无码,国产线播放免费人成视频播放,人妻少妇偷人无码视频,日夜啪啪一区二区三区,国产尤物精品自在拍视频首页,久热这里只有精品12

      生成式對抗網(wǎng)絡(luò)GAN

      【 李宏毅機(jī)器學(xué)習(xí)】生成式對抗網(wǎng)絡(luò)GAN

        在傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)任務(wù)中,我們通常把一個網(wǎng)絡(luò)當(dāng)作一個函數(shù)f(x),給定輸入x,網(wǎng)絡(luò)就會輸出一個對應(yīng)的結(jié)果 y。比如圖像分類任務(wù)中,輸入是一張圖片,輸出是一個分類標(biāo)簽。這是一種 判別式模型(Discriminative Model),它學(xué)的是輸入和輸出之間的映射關(guān)系。但在生成式模型(Generative Model) 中,輸入會增加一個隨機(jī)分布中sample出來的z,網(wǎng)絡(luò)輸入x和z,輸出y是可以從中采樣的復(fù)雜分布。

      網(wǎng)絡(luò)通常會采用兩種方式來處理x和z(1)拼接(Concatenate):直接把 兩個向量拼接在一起,變成一個更長的向量,輸入到神經(jīng)網(wǎng)絡(luò)中。(2)相加(Element-wise Add):如果x和 z維度相同,可以直接相加作為輸入。我們的目標(biāo)不再是“判斷”某個輸入屬于哪一類,而是希望模型能夠“生成”數(shù)據(jù)——比如生成看起來真實(shí)的圖片、音頻,甚至文本。換句話說,我們希望網(wǎng)絡(luò)本身就是一個“生成器”,可以從某種潛在的隨機(jī)性中創(chuàng)造出無限多樣的輸出。

      1 GAN(Generative Adversarial Network)

      1.1 GAN 的基本概念和工作原理

      生成對抗網(wǎng)絡(luò)(GAN)是由 Ian Goodfellow 等人在 2014 年提出的一種生成式模型。與傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)不同,GAN 由兩個相互競爭的網(wǎng)絡(luò)組成:生成器(Generator)判別器(Discriminator)

      以生成二次元人臉為例:

      生成器(Generator):把x拿掉,生成器只輸入隨機(jī)噪聲 z ,這種叫做unconditional generation。假設(shè)z是從normal distribution中sample出來的向量,這個向量一般是low-dim的向量,維度是自定義的,Generator輸入z后產(chǎn)生一個64x64x3的向量,整理后可以得到一張二次元人臉的圖像。

      判別器(Discriminator):對應(yīng)的判別器輸入一張圖片(可能是來自真實(shí)數(shù)據(jù)集,也可能是來自生成器),輸出一個數(shù)字,數(shù)字越大表示輸入的圖像越像真實(shí)的二次元的人臉。

      1.2 GAN 的訓(xùn)練機(jī)制

      • 初始化生成器和判別器參數(shù)
      • 在每個訓(xùn)練迭代中
        • 固定住生成器,更新判別器。具體的隨機(jī)采樣一些向量z,輸入到生成器,得到一些生成的圖像,然后從真實(shí)數(shù)據(jù)中采樣一些二次元人臉,訓(xùn)練判別器分別兩者之間的差異。比如用二分類器,或者邏輯回歸。

              

        •  固定住判別器,更新生成器。具體的把兩個網(wǎng)絡(luò)接起來變成一個大網(wǎng)絡(luò),其中判別器的參數(shù)是固定的,訓(xùn)練生成器,使得分?jǐn)?shù)越大越好。

                    

      GAN的訓(xùn)練目標(biāo)是通過對抗訓(xùn)練,生成器和判別器在博弈中不斷提高自己的能力:

      • 生成器的目標(biāo):生成足夠真實(shí)的數(shù)據(jù),使得判別器無法區(qū)分它們是來自真實(shí)數(shù)據(jù)還是生成的假數(shù)據(jù)。

      假設(shè)生成的數(shù)據(jù)分布為PG,真實(shí)數(shù)據(jù)分布為Pdata,生成器的目標(biāo)就是

      Divergence是PG和Pdata之間的某種距離,距離越小兩個分布就越相近。盡管不知道PG和Pdata是什么樣的分布,但是可以從中sample。

      • 判別器的目標(biāo):盡可能準(zhǔn)確地區(qū)分真實(shí)數(shù)據(jù)和生成數(shù)據(jù),減少自己的分類錯誤。

      分別從PG和Pdata中sample,Discriminator要學(xué)會給PG和中sample到的樣本打1,給Pdata中sample到的樣本打0,要做的是一個二分類任務(wù)。判別器的總 BCE 損失(對所有樣本)為

         損失是要最小化,等價于最大化目標(biāo)函數(shù)V:

       

         這個maxV和JS divergence有關(guān),因此生成器的目標(biāo)就是

      之前的訓(xùn)練步驟就是在解這個minmax問題,為什么訓(xùn)練步驟可以解這個函數(shù),見論文推導(dǎo)。設(shè)計(jì)不同的objective function解minmax問題,就對應(yīng)不同的divergence(參考)。

      1.3 WGAN:基于 Wasserstein 距離的改進(jìn)方法

      在大多數(shù)情況下,PG和Pdata之間是沒有重疊的,兩個都是高維空間的低維manifold,相當(dāng)于二維空間的兩條直線。即使兩個分布是有重疊的,但是sample的點(diǎn)不夠多,重疊的范圍也非常小。當(dāng)兩個分布沒有重疊時,JS divergence是存在問題的。 JS divergence是非度量的,不能為生成器提供一個明確的目標(biāo)方向來指導(dǎo)其優(yōu)化。比如下面的情況,PG和Pdata越來越接近,但是JS始終為log2,知道兩者重合才會變成0。

      為了改善 GAN 的訓(xùn)練穩(wěn)定性,研究者提出了Wasserstein GAN(WGAN)。它的核心思想是:

      替換原本的 JS 散度為 Wasserstein 距離(又叫 Earth Mover’s Distance)。

      你可以把它理解為:把一堆土(生成的數(shù)據(jù)分布)搬到另一堆土堆(真實(shí)數(shù)據(jù)分布)需要花費(fèi)的最小“搬運(yùn)成本”。這是一種更光滑、梯度更穩(wěn)定的距離度量方式。

      計(jì)算Wasserstein 距離就是解下面的公式

      注意:

      • WGAN 中的判別器不再是“真假分類器”,而是一個 Critic,它的輸出是任意實(shí)數(shù),用于度量樣本的“真實(shí)性”得分。

      • 為了滿足理論條件,Critic 的梯度需要是1-Lipschitz的。這個限制是讓Discriminator變得平滑,如果沒有這個限制,D會給生成的x負(fù)無窮,給真實(shí)的x正無窮,訓(xùn)練會無法收斂,max始終是無窮大。

      ??Q: 如何讓Discriminator變得比較平滑?

      最早用權(quán)重裁剪,限制權(quán)重在-c到c之間,后續(xù)改進(jìn)為 WGAN-GP(加入梯度懲罰項(xiàng)),訓(xùn)練更穩(wěn)定,另一個常用、效果非常不錯的方式是:Spectral Normalization,歸一化權(quán)重的最大奇異值,不需要像 WGAN-GP 那樣計(jì)算復(fù)雜的梯度懲罰項(xiàng),也避免了 weight clipping 帶來的訓(xùn)練困難。

      1.4 GAN 面臨的挑戰(zhàn)

      盡管已經(jīng)有了如WGAN這樣的改進(jìn)方法,GAN在訓(xùn)練過程中仍然面臨諸多挑戰(zhàn)。這主要源于其對抗性訓(xùn)練機(jī)制的本質(zhì)特性:生成器(Generator)與判別器(Discriminator)在訓(xùn)練過程中是相互博弈、彼此促進(jìn)的關(guān)系。模型的優(yōu)化是一個動態(tài)博弈過程,如果其中一方(例如判別器)訓(xùn)練不足或性能不穩(wěn)定,就會導(dǎo)致另一方(如生成器)無法獲得有效的反饋信號,從而影響整體訓(xùn)練效果。這種相互依賴使得GAN的訓(xùn)練過程高度不穩(wěn)定,調(diào)試和收斂都較為困難。目前主流的生成模型,除了GAN,還有變分自編碼器(VAE)流模型(Flow-based Models)。GAN在圖像生成質(zhì)量上往往優(yōu)于VAE和Flow模型,但它的訓(xùn)練更不穩(wěn)定、可解釋性更弱。

      ??Q: 為什么GAN難用于文本生成?

      答:因?yàn)槲谋臼?/strong>離散的,而GAN的訓(xùn)練依賴于梯度的反向傳播。在文本生成中,生成器通常輸出一個概率分布,再通過 argmax采樣 選擇一個詞。這種“選詞”的過程是非可導(dǎo)的,梯度無法穿過這個離散選擇,從而導(dǎo)致生成器無法優(yōu)化。在多個值相等或者接近的時候,梯度無法確定往哪個方向優(yōu)化。比如在 argmax 選詞時,即便概率稍微改變,只要最大值沒變,輸出結(jié)果也不會變,因此梯度是 0 或未定義的,無法有效訓(xùn)練生成器。

      ??Q: 那CNN里也有max pooling,為什么沒問題?

      答:CNN 中的 max pooling 雖然也是 max 操作,但它出現(xiàn)在網(wǎng)絡(luò)的中間層,反向傳播時我們可以將梯度傳給最大值的位置,其他位置設(shè)為 0(這叫次梯度 subgradient)。這種近似梯度在實(shí)踐中效果不錯,因此不會影響 CNN 的訓(xùn)練。

      模式坍縮 (Mode Collapse)

      Mode Collapse 是指生成器只學(xué)會生成真實(shí)數(shù)據(jù)中的一小部分模式,導(dǎo)致輸出缺乏多樣性。換句話說,雖然真實(shí)數(shù)據(jù)有很多種可能,但生成器反復(fù)生成的是同一種或幾種“看起來不錯”的樣本。
      例如:你訓(xùn)練 GAN 生成手寫數(shù)字,但生成器最終只會生成“數(shù)字 3”,而忽略了其他數(shù)字。這樣雖然圖片質(zhì)量可能還可以,但多樣性完全丟失。這是因?yàn)樯善髦粚W⒂凇膀_過判別器”的目標(biāo),而不是完整地復(fù)現(xiàn)數(shù)據(jù)分布。一旦找到某個容易成功的樣本類型,就會反復(fù)生成,從而陷入局部最優(yōu)。

      模式遺漏(Mode Dropping)

      Mode Dropping 指的是生成器完全忽略了真實(shí)數(shù)據(jù)中的某些模式,即使這些模式在訓(xùn)練數(shù)據(jù)中是存在的,生成器卻沒有學(xué)會去生成它們。假設(shè)你訓(xùn)練一個 GAN 模型用于生成真實(shí)人臉,訓(xùn)練數(shù)據(jù)中包含了不同年齡、性別、膚色、發(fā)型的人臉圖像。但訓(xùn)練后的生成器:只會生成某一膚色的面孔。盡管這些樣本在真實(shí)數(shù)據(jù)中是存在的,生成器卻忽略了這些“模式”。這就屬于典型的 Mode Dropping —— 生成結(jié)果看起來多樣,但其實(shí)缺失了某些重要的群體特征

      1.5 GAN評估指標(biāo)

      GAN 在訓(xùn)練中容易出現(xiàn)模式坍縮(Mode Collapse)模式遺漏(Mode Dropping)等問題,即生成器生成的樣本看起來質(zhì)量不錯,但實(shí)則重復(fù)或覆蓋不全。這時我們就需要一些定量評估指標(biāo)來判斷兩個關(guān)鍵問題:(1) 生成圖像質(zhì)量好不好?(2) 生成圖像夠不夠多樣?不像分類器有準(zhǔn)確率指標(biāo),GAN 的 Generator 沒有明確的評價指標(biāo),但有一些常用方法可以參考:

      Inception Score(IS)

      • 用一個預(yù)訓(xùn)練的分類器(如 Inception-v3)去分類生成的圖像
      • 一張圖片丟到CNN去分類,結(jié)果分布越集中,quality越高
      • 一堆圖片的平均分布,越平均diversity越大
      • good quality,large diversity→large IS,說明生成樣本質(zhì)量越好、類別多樣。

      在二次元人臉生成中,分類器的輸出可能都是人臉,diversity小,不適合這個場景。

      Frechet Inception Distane(FID)

      • 將生成圖像和真實(shí)圖像分別輸入一個預(yù)訓(xùn)練好的 CNN(通常是 Inception v3);

      • 在網(wǎng)絡(luò)的某一層(通常是 softmax 前的一層)提取特征向量(即使輸入圖像都是人臉,這些向量也會有所不同,因?yàn)樗鼈儾蹲降氖歉邔拥恼Z義信息(比如臉的姿態(tài)、表情、風(fēng)格等)

      • 假設(shè)這兩組特征向量分別服從一個多維高斯分布,F(xiàn)ID 會分別計(jì)算這兩組特征向量的 均值(μ)協(xié)方差矩陣(Σ),然后用 Frechet 距離來衡量它們之間的差異

      如果 FID 值很小,說明生成的圖像和真實(shí)圖像非常接近,但這并不意味著生成器生成的樣本多樣化。生成器可能只是記住了真實(shí)數(shù)據(jù)的某些特征,導(dǎo)致它只能生成“很相似”的樣本,而失去了多樣性,比如學(xué)會了對真實(shí)圖像進(jìn)行翻轉(zhuǎn)。真實(shí)特征分布可能遠(yuǎn)非高斯,所以在做次元人臉生成中主要是用FID和人眼就去看。

      2 Conditional GAN(條件生成對抗網(wǎng)絡(luò))

      傳統(tǒng)GAN中,生成器是無條件的——它只接收隨機(jī)噪聲z作為輸入。而 Conditional GAN 則引入條件信息x,例如類別標(biāo)簽、文本描述、圖像等,引導(dǎo)生成器生成“符合條件”的樣本。

      應(yīng)用: 文本生成圖像

      需要收集一些圖片和對應(yīng)的標(biāo)注, 輸入x是一段文字red eyes,可以用rnn或者transformer encoder把它變成一段向量。期望輸入red eyes,generator就輸出一個紅眼睛的圖片,每次的輸出都是不一樣的紅眼睛,取決于sample到不一樣的z。

      ??Q: 如何訓(xùn)練Conditional GAN?

      如果像之前的GAN一樣,判別器只判斷圖像是否是真實(shí)的,生成器就不用在意輸入x,只要產(chǎn)生清晰的圖像就可以。在 Conditional GAN 中,我們的目標(biāo)不只是生成“看起來真實(shí)的圖像”,更重要的是:圖像還要和輸入的條件匹配。需要準(zhǔn)備文字和圖像成對的資料(positive),以及文字和機(jī)器產(chǎn)生出來的圖片(negative),還需要把文字和圖像亂配作為(negative)。

      應(yīng)用: 圖像生成圖像 (image translation or pix2pix)

      比如輸入黑白的圖片讓生成器著色,或者對圖像去霧。輸入一張圖片生成一張圖片,可以用supervised的方法,由于同樣的輸入可能對應(yīng)到不一樣的輸出,機(jī)器學(xué)到把所有的可能平均起來,所以產(chǎn)生的圖片會很模糊。如果用GAN的方法,再加入判別器,判別器輸入生成器生成的圖像和condition,然后輸出分?jǐn)?shù)。GAN方法產(chǎn)生的圖像比較清楚,但是可能會產(chǎn)生輸入沒有的東西,比如下面的房子左上角有奇怪的東西。當(dāng)GAN和supervised同時使用,效果會比較好,生成器在訓(xùn)練的時候一方面要騙過判別器,但又要使得產(chǎn)生的圖片和目標(biāo)越接近越好。

      3 CycleGAN:無監(jiān)督圖像到圖像的轉(zhuǎn)換

        在之前的unconditional generation中,輸入是一個簡單分布,輸出是一個復(fù)雜分布,現(xiàn)在稍微轉(zhuǎn)換下,輸入是x domain圖片的分布,輸出是y domain圖片的分布。假設(shè)我們將真實(shí)人臉轉(zhuǎn)換為二次元人臉,像之前訓(xùn)練GAN一樣,從x domain中sample一張圖片,輸入到Generator,再用學(xué)過ydomain的discriminator去給圖像打分,這是有問題的。Generator會無視輸入的圖片,只產(chǎn)生一張像y domain的二次元圖片就可以,這個圖片和輸入的真實(shí)人臉沒有。如何強(qiáng)化輸入和輸出的關(guān)系呢?之前的conditional gan也講過類似的問題,但是現(xiàn)在沒有成對的數(shù)據(jù)去訓(xùn)練discriminator學(xué)習(xí)輸入和輸出的關(guān)系。

        CycleGAN(由 Jun-Yan Zhu 等人在 2017 年提出,旨在解決沒有成對數(shù)據(jù)的圖像轉(zhuǎn)換問題。傳統(tǒng)的圖像到圖像的轉(zhuǎn)換任務(wù)(比如圖像翻譯、風(fēng)格轉(zhuǎn)換)通常需要輸入和目標(biāo)圖像一一對應(yīng)的成對數(shù)據(jù)。然而,在實(shí)際應(yīng)用中,往往沒有這樣的配對數(shù)據(jù)集,這就使得傳統(tǒng)的 GAN 方法難以應(yīng)用。CycleGAN 通過設(shè)計(jì)一對生成器和判別器,解決了這一問題。它的核心思想是 循環(huán)一致性(Cycle Consistency),即:如果將一張圖像從一個域映射到另一個域,然后再映射回來,應(yīng)該得到原始的圖像。這樣的映射關(guān)系可以確保生成的圖像既能保持域間的轉(zhuǎn)換效果,又不失原始圖像的特征。以真實(shí)人臉轉(zhuǎn)換為二次元人臉為例,在訓(xùn)練的時候增加一個額外的目標(biāo),輸入在經(jīng)過x到y(tǒng) domain的轉(zhuǎn)換后再經(jīng)過y到xdomain的轉(zhuǎn)換,兩次轉(zhuǎn)換后和原來的向量越接近越好,同時也可以做另一個方向上的訓(xùn)練。

      4 作業(yè)HW6

      鏈接給出了PDF和code:李宏毅2021&2022機(jī)器學(xué)習(xí)

      代碼:

      import random
      
      import torch
      import numpy as np
      import os
      import glob
      
      import torch.nn as nn
      import torch.nn.functional as F
      import torchvision
      import torchvision.transforms as transforms
      from torch import optim
      from torch.autograd import Variable
      from torch.utils.data import Dataset, DataLoader
      import matplotlib.pyplot as plt
      from tqdm import tqdm
      
      
      def same_seeds(seed):
          # Python built-in random module
          random.seed(seed)
          # Numpy
          np.random.seed(seed)
          # Torch
          torch.manual_seed(seed)
          if torch.cuda.is_available():
              torch.cuda.manual_seed(seed)
              torch.cuda.manual_seed_all(seed)
          torch.backends.cudnn.benchmark = False
          torch.backends.cudnn.deterministic = True
      
      
      same_seeds(2021)
      
      
      class CrypkoDataset(Dataset):
          def __init__(self, fnames, transform):
              self.transform = transform
              self.fnames = fnames
              self.num_samples = len(self.fnames)
      
          def __getitem__(self, idx):
              fname = self.fnames[idx]
              # 1. Load the image
              img = torchvision.io.read_image(fname)
              # 2. Resize and normalize the images using torchvision.
              img = self.transform(img)
              return img
      
          def __len__(self):
              return self.num_samples
      
      
      def get_dataset(root):
          fnames = glob.glob(os.path.join(root, '*'))
          # 1. Resize the image to (64, 64)
          # 2. Linearly map [0, 1] to [-1, 1]
          compose = [
              transforms.ToPILImage(),
              transforms.Resize((64, 64)),
              transforms.ToTensor(),
              transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
          ]
          transform = transforms.Compose(compose)
          dataset = CrypkoDataset(fnames, transform)
          return dataset
      
      
      def weights_init(m):
          classname = m.__class__.__name__
          if classname.find('Conv') != -1:
              m.weight.data.normal_(0.0, 0.02)
          elif classname.find('BatchNorm') != -1:
              m.weight.data.normal_(1.0, 0.02)
              m.bias.data.fill_(0)
      
      
      class Generator(nn.Module):
          """
          Input shape: (N, in_dim)
          Output shape: (N, 3, 64, 64)
          """
      
          def __init__(self, in_dim, dim=64):
              super(Generator, self).__init__()
      
              def dconv_bn_relu(in_dim, out_dim):
                  return nn.Sequential(
                      nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                         padding=2, output_padding=1, bias=False),
                      nn.BatchNorm2d(out_dim),
                      nn.ReLU()
                  )
      
              self.l1 = nn.Sequential(
                  nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
                  nn.BatchNorm1d(dim * 8 * 4 * 4),
                  nn.ReLU()
              )
              self.l2_5 = nn.Sequential(
                  dconv_bn_relu(dim * 8, dim * 4),
                  dconv_bn_relu(dim * 4, dim * 2),
                  dconv_bn_relu(dim * 2, dim),
                  nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
                  nn.Tanh()
              )
              self.apply(weights_init)
      
          def forward(self, x):
              y = self.l1(x)
              y = y.view(y.size(0), -1, 4, 4)
              y = self.l2_5(y)
              return y
      
      
      class Discriminator(nn.Module):
          """
          Input shape: (N, 3, 64, 64)
          Output shape: (N, )
          """
      
          def __init__(self, in_dim, dim=64, use_sigmoid=True):
              super(Discriminator, self).__init__()
      
              def conv_bn_lrelu(in_dim, out_dim):
                  return nn.Sequential(
                      nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                      nn.BatchNorm2d(out_dim),
                      nn.LeakyReLU(0.2),
                  )
      
              """ Medium: Remove the last sigmoid layer for WGAN. """
              layers = [
                  nn.Conv2d(in_dim, dim, 5, 2, 2),
                  nn.LeakyReLU(0.2),
                  conv_bn_lrelu(dim, dim * 2),
                  conv_bn_lrelu(dim * 2, dim * 4),
                  conv_bn_lrelu(dim * 4, dim * 8),
                  nn.Conv2d(dim * 8, 1, 4),
              ]
      
              if use_sigmoid:
                  layers.append(nn.Sigmoid())
      
              self.ls = nn.Sequential(*layers)
      
              self.apply(weights_init)
      
          def forward(self, x):
              y = self.ls(x)
              y = y.view(-1)
              return y
      
      
      def train(baseline="Simple", show_img=True):
          # Training hyperparameters
          batch_size = 64
          z_sample = Variable(torch.randn(100, z_dim)).cuda()
          lr = 1e-4
      
          if baseline == "Simple":
              n_epoch = 50  # 50
              n_critic = 1  # 訓(xùn)練 1 次判別器,再訓(xùn)練 1 次生成器
          elif baseline == "Medium":
              """ Medium: WGAN, 50 epoch, n_critic=5, clip_value=0.01 """
              n_epoch = 50
              n_critic = 5  # 先訓(xùn)練 5 次判別器,再訓(xùn)練 1 次生成器
              clip_value = 0.01
      
          # Model
          G = Generator(in_dim=z_dim).cuda()
          if baseline == "Simple":
              D = Discriminator(3).cuda()
          elif baseline == "Medium":
              D = Discriminator(3, use_sigmoid=False).cuda()
          G.train()
          D.train()
      
          # Loss
          criterion = nn.BCELoss()
      
          # Optimizer
          if baseline == "Simple":
              opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
              opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
          elif baseline == "Medium":
              """ Medium: Use RMSprop for WGAN. """
              opt_D = torch.optim.RMSprop(D.parameters(), lr=lr)
              opt_G = torch.optim.RMSprop(G.parameters(), lr=lr)
      
          # DataLoader
          dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
      
          steps = 0
          for e, epoch in enumerate(range(n_epoch)):
              progress_bar = tqdm(dataloader)
              for i, data in enumerate(progress_bar):
                  imgs = data
                  imgs = imgs.cuda()
      
                  bs = imgs.size(0)
      
                  # ============================================
                  #  Train D
                  # ============================================
                  z = Variable(torch.randn(bs, z_dim)).cuda()
                  r_imgs = Variable(imgs).cuda()
                  f_imgs = G(z)
      
                  if baseline == "Simple":
                      # Label
                      r_label = torch.ones((bs)).cuda()
                      f_label = torch.zeros((bs)).cuda()
      
                      # Model forwarding
                      r_logit = D(r_imgs.detach())
                      f_logit = D(f_imgs.detach())
      
                      # Compute the loss for the discriminator.
                      r_loss = criterion(r_logit, r_label)
                      f_loss = criterion(f_logit, f_label)
                      loss_D = (r_loss + f_loss) / 2
                  elif baseline == "Medium":
                      # WGAN Loss
                      loss_D = -torch.mean(D(r_imgs)) + torch.mean(D(f_imgs))
      
                  # Model backwarding
                  D.zero_grad()
                  loss_D.backward()
      
                  # Update the discriminator.
                  opt_D.step()
      
                  if baseline == "Medium":
                      """ Medium: Clip weights of discriminator. """
                      for p in D.parameters():
                         p.data.clamp_(-clip_value, clip_value)
      
                  # ============================================
                  #  Train G
                  # ============================================
                  if steps % n_critic == 0:
                      # Generate some fake images.
                      z = Variable(torch.randn(bs, z_dim)).cuda()
                      f_imgs = G(z)
      
                      # Model forwarding
                      f_logit = D(f_imgs)
      
                      if baseline == "Simple":
                          # Compute the loss for the generator.
                          loss_G = criterion(f_logit, r_label)
                      elif baseline == "Medium":
                          # WGAN Loss
                          loss_G = -torch.mean(D(f_imgs))
      
                      # Model backwarding
                      G.zero_grad()
                      loss_G.backward()
      
                      # Update the generator.
                      opt_G.step()
      
                  steps += 1
      
                  # Set the info of the progress bar
                  #   Note that the value of the GAN loss is not directly related to
                  #   the quality of the generated images.
                  progress_bar.set_postfix({
                      'Loss_D': round(loss_D.item(), 4),
                      'Loss_G': round(loss_G.item(), 4),
                      'Epoch': e + 1,
                      'Step': steps,
                  })
      
              G.eval()
              f_imgs_sample = (G(z_sample).data + 1) / 2.0
              filename = os.path.join(log_dir, f'Epoch_{epoch + 1:03d}.jpg')
              torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
              print(f' | Save some samples to {filename}.')
      
              # Show generated images in the jupyter notebook.
              if show_img:
                  grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
                  plt.figure(figsize=(10, 10))
                  plt.imshow(grid_img.permute(1, 2, 0))
                  plt.show()
              G.train()
      
              if (e + 1) % 5 == 0 or e == 0:
                  # Save the checkpoints.
                  torch.save(G.state_dict(), os.path.join(ckpt_dir, 'G.pth'))
                  torch.save(D.state_dict(), os.path.join(ckpt_dir, 'D.pth'))
      
      
      def inference():
          G = Generator(z_dim)
          G.load_state_dict(torch.load(os.path.join(ckpt_dir, 'G.pth')))
          G.eval()
          G.cuda()
          # Generate 1000 images and make a grid to save them.
          n_output = 1000
          z_sample = Variable(torch.randn(n_output, z_dim)).cuda()
          imgs_sample = (G(z_sample).data + 1) / 2.0
          log_dir = os.path.join('logs')
          filename = os.path.join(log_dir, 'result.jpg')
          torchvision.utils.save_image(imgs_sample, filename, nrow=10)
      
          # Show 30 of the images.
          grid_img = torchvision.utils.make_grid(imgs_sample[:30].cpu(), nrow=10)
          plt.figure(figsize=(10, 10))
          plt.imshow(grid_img.permute(1, 2, 0))
          plt.show()
      
      
      if __name__ == '__main__':
          dataset = get_dataset('faces')
      
          # 注意,這些數(shù)值的范圍是 [-1, 1],所以顯示比較暗
          # images = [dataset[i] for i in range(16)]
          # grid_img = torchvision.utils.make_grid(images, nrow=4)
          # plt.figure(figsize=(10, 10))
          # plt.imshow(grid_img.permute(1, 2, 0))
          # plt.show()
      
          # 我們需要將它們轉(zhuǎn)換到有效的范圍 [0, 1],才能正確顯示。
          # images = [(dataset[i] + 1) / 2 for i in range(16)]
          # grid_img = torchvision.utils.make_grid(images, nrow=4)
          # plt.figure(figsize=(10, 10))
          # plt.imshow(grid_img.permute(1, 2, 0))
          # plt.show()
      
          z_dim = 100
          log_dir = os.path.join('logs')
          ckpt_dir = os.path.join('checkpoints')
          os.makedirs(log_dir, exist_ok=True)
          os.makedirs(ckpt_dir, exist_ok=True)
      
          train(baseline="Medium", show_img=False)
      
          inference()
      View Code

      兩個指標(biāo):

      • FID(Frechet Inception Distance):衡量生成圖片和真實(shí)圖片的差異,越低越好。

      • AFD(Attribute FID Distance):衡量生成圖像屬性的多樣性或質(zhì)量,越高越好。

      項(xiàng)目評分標(biāo)準(zhǔn)分?jǐn)?shù)
      ? 代碼部分 Code 提交完整、可運(yùn)行的代碼 4 分
      ? 簡單基準(zhǔn) Simple FID ≤ 30000 且 AFD ≥ 0.00 2 分
      ? 中等基準(zhǔn) Medium FID ≤ 11800 且 AFD ≥ 0.43 2 分
      ? 強(qiáng)基準(zhǔn) Strong FID ≤ 9300 且 AFD ≥ 0.53 1 分
      ? 最強(qiáng)基準(zhǔn) Boss FID ≤ 8200 且 AFD ≥ 0.68 1 分
      ?? 額外加分 Bonus 擊敗 Boss 基準(zhǔn) + 提交 < 100 字的英文 PDF 報(bào)告 0.5 分
       代碼給出的例子是DCGAN,用WGAN可以達(dá)到medium分?jǐn)?shù),要做的修改主要是
      • 從判別器中移除最后的 sigmoid 層。
      • 計(jì)算損失時不取對數(shù)(log)。
      • 將判別器的權(quán)重裁剪到一個常數(shù)范圍內(nèi)。
      • 使用 RMSProp 或 SGD 作為優(yōu)化器。

      用Spectral Normalization GAN (SNGAN)可以達(dá)到strong分?jǐn)?shù),主要是在判別器的每一層的權(quán)重進(jìn)行譜歸一化(Spectral Normalization)。

       

      訓(xùn)練50epoch結(jié)果

       

      訓(xùn)練50epoch結(jié)果

       

      posted @ 2025-04-21 09:34  灣仔碼農(nóng)  閱讀(426)  評論(0)    收藏  舉報(bào)
      主站蜘蛛池模板: 麻豆成人久久精品二区三| 俄罗斯美女真人性做爰| 成人免费亚洲av在线| 国产一区二区日韩在线| xxxx丰满少妇高潮| 久久精品高清一区二区三区| 欧美人与动牲交精品| 中文字幕久区久久中文字幕| 国产一区二区日韩在线| 亚洲欧洲精品一区二区| 中文字幕亚洲精品乱码| 国产亚洲精品成人aa片新蒲金| 国产va在线观看免费| 亚欧美闷骚院| 自拍视频亚洲精品在线| 两个人的视频www免费| 中文字幕亚洲制服在线看 | 伊人中文在线最新版天堂| 白丝乳交内射一二三区| 全黄h全肉边做边吃奶视频 | 国产精品多p对白交换绿帽| 国产成人8x视频网站入口| 国产精品国产三级国产午| 天堂av最新版中文在线| 无码人妻精品一区二区三区蜜桃| 成人午夜大片免费看爽爽爽| 国产人妻人伦精品婷婷| 狠狠色狠狠色综合久久蜜芽| 国产在线拍揄自揄拍无码视频| 精品日本乱一区二区三区| 动漫av网站免费观看| 在线观看国产午夜福利片| 日韩秘 无码一区二区三区| 精品视频一区二区三区不卡 | 秋霞人妻无码中文字幕| 亚洲第一区二区国产精品| 日韩大片高清播放器| 久久精品丝袜高跟鞋| 日韩人妻精品中文字幕专区| 日区中文字幕一区二区| 一区二区三区四区亚洲自拍 |