
      Rebirth: Learning Neural Network Algorithms from Scratch, Part 8: GPU Training Practice with Large Datasets and Complex Models

      Introduction

      In the previous installment we implemented a basic SRCNN super-resolution model and learned how to run training in the background. This installment scales the experiments up: we introduce larger datasets, implement a more complex network architecture, and optimize the GPU training strategy to tackle more challenging image-reconstruction tasks. Through this practice we will dig into the key techniques and engineering details of large-scale deep-learning experiments.

      Acquiring and Processing Large Datasets

      Large datasets suitable for super-resolution

      To improve the model's generalization, we can use the following large datasets:

      1. **DIV2K extended set**: 1,000 high-resolution training images and 100 validation images (2K resolution)
      2. **Flickr2K**: 2,650 high-resolution natural images from Flickr (4K and above)
      3. **CelebA-HQ**: 30,000 high-quality face images (1024x1024 resolution)
      4. **ImageNet**: a million-scale general-purpose image dataset (useful for pre-training)
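      To get a rough feel for the data volume these sets imply, a quick back-of-envelope sketch (the image dimensions below are illustrative assumptions, not measurements; real DIV2K images vary in size):

```python
# Back-of-envelope patch arithmetic for 4x super-resolution.
scale_factor = 4
hr_patch = 128                       # HR training patch size used later in this post
lr_patch = hr_patch // scale_factor  # matching LR patch size

# Non-overlapping 128x128 HR patches in one assumed 2040x1356 2K image
w, h = 2040, 1356
patches_per_image = (w // hr_patch) * (h // hr_patch)

print(lr_patch)           # 32
print(patches_per_image)  # 15 * 10 = 150
```

      At 150 patches per image, the 1,000 DIV2K training images alone yield on the order of 150,000 distinct training crops before augmentation.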

      Code implementation

      git clone https://gitee.com/cmx1998/py-torch-learning.git
      cd py-torch-learning/codes/esrgan-project
      

      Automatic download and extraction

      import os
      import wget
      import zipfile
      import tarfile
      from tqdm import tqdm
      
      # Dataset download configuration
      DATASETS = {
          "DIV2K": {
              "train": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
              "valid": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
          },
          "Flickr2K": {
              "url": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar"
          }
      }
      
      def download_dataset(url, save_dir, filename=None):
          """Download a dataset with a progress bar."""
          # Create the target directory if needed
          os.makedirs(save_dir, exist_ok=True)
          
          # Derive the filename from the URL when none is given
          if not filename:
              filename = url.split("/")[-1]
          save_path = os.path.join(save_dir, filename)
          
          # Skip the download if the file already exists
          if os.path.exists(save_path):
              print(f"File {filename} already exists, skipping download")
              return save_path
      
          print(f"Downloading {filename}...")
          # Use tqdm for the progress bar to keep the output compact
          with tqdm(total=100, desc=f"Downloading {filename}", unit="%") as pbar:
              def progress_bar(current, total, width=80):
                  progress = current / total * 100
                  pbar.n = int(progress)
                  pbar.update(0)      # refresh the bar without emitting new output lines
                  
              wget.download(url, save_path, bar=progress_bar)
          print(f"\n{filename} downloaded")
          return save_path
      
      def extract_archive(file_path, extract_dir):
          """Extract a dataset archive."""
          # Create the extraction directory if needed
          os.makedirs(extract_dir, exist_ok=True)
          filename = os.path.basename(file_path)
          # Marker file recording that this archive was already extracted
          extract_flag = os.path.join(extract_dir, f".{filename}.extracted")
          
          # Skip extraction if the marker file exists
          if os.path.exists(extract_flag):
              print(f"File {filename} already extracted, skipping!")
              return
          
          # Perform the extraction
          try:
              # Dispatch on the archive suffix
              if file_path.endswith(".zip"):
                  # Handle zip archives
                  with zipfile.ZipFile(file_path, 'r') as zip_ref:
                      # Show extraction progress
                      for file in tqdm(zip_ref.namelist(), desc="Extracting"):
                          zip_ref.extract(file, extract_dir)
              elif file_path.endswith(".tar") or file_path.endswith(".tar.gz"):
                  # Handle tar archives ('r' auto-detects compression)
                  with tarfile.open(file_path, 'r') as tar_ref:
                      # Show extraction progress
                      members = tar_ref.getmembers()
                      for member in tqdm(members, desc="Extracting"):
                          tar_ref.extract(member, extract_dir)
                          
              # Write the marker file on success
              with open(extract_flag, 'w') as f:
                  f.write("Extracted successfully")
              print(f"File {os.path.basename(file_path)} extracted")
          except Exception as e:
              print(f"Extraction failed: {e}")
              # Remove the marker file on failure to avoid false positives
              if os.path.exists(extract_flag):
                  os.remove(extract_flag)
      
      def prepare_large_datasets(base_dir):
          """Prepare all large datasets."""
          # Download DIV2K
          div2k_dir = os.path.join(base_dir, "DIV2K")
          for split, url in DATASETS["DIV2K"].items():
              file_path = download_dataset(url, div2k_dir)
              extract_archive(file_path, os.path.join(div2k_dir, split))
          
          # Download Flickr2K
          flickr_dir = os.path.join(base_dir, "Flickr2K")
          flickr_url = DATASETS["Flickr2K"]["url"]
          file_path = download_dataset(flickr_url, flickr_dir)
          extract_archive(file_path, flickr_dir)
          
          print("All datasets are ready")
      

      Efficient data-loading strategy

      For large datasets, the data-loading pipeline must be optimized so the GPU stays busy:

      from torch.utils.data import Dataset, DataLoader, ConcatDataset
      
      class CombinedDataset(Dataset):
          """Wrapper that combines several datasets."""
          def __init__(self, dataset_paths, scale_factor=4, patch_size=128, train=True, 
                       augment=True, cache_in_memory=False):
              self.datasets = []
              for path in dataset_paths:
                  # Validate each path
                  if os.path.exists(path):
                      dataset = SuperResolutionDataset(
                          path, 
                          scale_factor=scale_factor,
                          patch_size=patch_size,
                          train=train,
                          augment=augment,
                          cache_in_memory=cache_in_memory
                      )
                      self.datasets.append(dataset)
                  else:
                      print(f"Warning: dataset path does not exist: {path}")
              
              if not self.datasets:
                  raise ValueError("No valid dataset paths")
              self.combined = ConcatDataset(self.datasets)
              
          def __len__(self):
              return len(self.combined)
          
          def __getitem__(self, idx):
              return self.combined[idx]
      
      # Optimized data loaders (args comes from parse_args)
      def create_optimized_dataloaders(batch_size, num_workers=8, pin_memory=True):
          # Combine several large datasets
          dataset_paths = [
              os.path.join(args.dataset_path, "DIV2K"),
              os.path.join(args.dataset_path, "Flickr2K")
          ]
          
          train_dataset = CombinedDataset(
              dataset_paths,
              scale_factor=args.scale_factor,
              patch_size=args.patch_size,
              train=True
          )
          
          val_dataset = SuperResolutionDataset(
              os.path.join(args.dataset_path, "DIV2K"),
              train=False,
              scale_factor=args.scale_factor,
              patch_size=args.patch_size
          )
          
          # Use prefetching and multiple worker processes for speed
          train_loader = DataLoader(
              train_dataset,
              batch_size=batch_size,
              shuffle=True,
              num_workers=num_workers,
              pin_memory=pin_memory,
              prefetch_factor=2,  # prefetch upcoming batches
              persistent_workers=True  # keep worker processes alive between epochs
          )
          
          val_loader = DataLoader(
              val_dataset,
              batch_size=batch_size,
              shuffle=False,
              num_workers=num_workers,
              pin_memory=pin_memory
          )
          
          return train_loader, val_loader
      

      Implementing a More Complex Model: ESRGAN

      Compared with SRCNN, ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks) produces high-resolution images with much richer detail. We implement its core building blocks:

      class ResidualDenseBlock(nn.Module):
          """Residual dense block, the core component of ESRGAN."""
          def __init__(self, nf=64, gc=32, bias=True):
              super(ResidualDenseBlock, self).__init__()
              self.conv1 = nn.Conv2d(nf + 0 * gc, gc, 3, 1, 1, bias=bias)
              self.conv2 = nn.Conv2d(nf + 1 * gc, gc, 3, 1, 1, bias=bias)
              self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
              self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
              self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
              self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
              
              # Initialize the weights
              for m in self.modules():
                  if isinstance(m, nn.Conv2d):
                      nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      
          def forward(self, x):
              x1 = self.lrelu(self.conv1(x))
              x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
              x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
              x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
              x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
              # Residual connection with residual scaling
              return x5 * 0.2 + x
      
      class RRDB(nn.Module):
          """Residual-in-Residual Dense Block."""
          def __init__(self, nf, gc=32):
              super(RRDB, self).__init__()
              self.rdb1 = ResidualDenseBlock(nf, gc)
              self.rdb2 = ResidualDenseBlock(nf, gc)
              self.rdb3 = ResidualDenseBlock(nf, gc)
      
          def forward(self, x):
              out = self.rdb1(x)
              out = self.rdb2(out)
              out = self.rdb3(out)
              # Residual connection with residual scaling
              return out * 0.2 + x
      
      class RRDBNet(nn.Module):
          """Base module of the ESRGAN generator (the RRDB network)."""
          def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4):
              super(RRDBNet, self).__init__()
              self.scale = scale
              # Overall structure: first conv + RRDB blocks + upsampling + output conv
              self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
              self.body = self._make_rrdb_blocks(num_feat, num_block, num_grow_ch)
              self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
              self.upsampler = self._make_upsampler(num_feat, scale)
              self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
      
          def _make_rrdb_blocks(self, num_feat, num_block, num_grow_ch):
              blocks = []
              for _ in range(num_block):
                  blocks.append(RRDB(num_feat, num_grow_ch))
              return nn.Sequential(*blocks)
      
          def _make_upsampler(self, num_feat, scale):
              # Build the upsampling module: one PixelShuffle stage per factor of 2
              upsampler = []
              for _ in range(int(torch.log2(torch.tensor(scale)))):
                  upsampler.append(nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True))
                  upsampler.append(nn.PixelShuffle(2))
              return nn.Sequential(*upsampler)
      
          def forward(self, x):
              # Forward pass
              feat = self.conv_first(x)
              body_feat = self.conv_body(self.body(feat))
              feat = feat + body_feat
              out = self.conv_last(self.upsampler(feat))
              return out
      
      # ESRGAN generator (inherits from RRDBNet to keep the interface consistent)
      class ESRGAN(RRDBNet):
          """ESRGAN generator."""
          def __init__(self, scale_factor=4, num_block=23, num_grow_ch=32, **kwargs):
              super(ESRGAN, self).__init__(
                  scale=scale_factor,
                  num_block=num_block,
                  num_grow_ch=num_grow_ch, **kwargs
              )
              self.scale_factor = scale_factor
              self.conv_first = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
              # Keep the hyperparameters as instance attributes for later use
              self.num_rrdb_blocks = num_block  # number of RRDB blocks
              self.num_grow_ch = num_grow_ch    # growth channels
              # Build the RRDB trunk from the instance attributes
              self.RRDB_trunk = self._make_rrdb_blocks(64, self.num_rrdb_blocks, self.num_grow_ch)
              self.trunk_conv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
              self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
              self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
              self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
              
              # Add upsampling layers for the requested scale factor
              self.upsampler = self._make_upsampler(64, scale_factor)     # reuse the parent's upsampler builder
      
          def forward(self, x):
              fea = self.conv_first(x)
              trunk = self.trunk_conv(self.RRDB_trunk(fea))
              fea = fea + trunk
              # Upsample to the high-resolution size first
              fea = self.upsampler(fea)
              
              fea = self.lrelu(self.HRconv(fea))
              out = self.conv_last(fea)
              return out
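
      The dense connectivity in ResidualDenseBlock means conv_i sees the block input plus all earlier intermediate features, so its input width is nf + (i-1)*gc. A quick sanity check of those widths for the defaults nf=64, gc=32:

```python
# Input channel widths of conv1..conv5 in a ResidualDenseBlock
# with nf=64 feature channels and gc=32 growth channels.
nf, gc = 64, 32
in_channels = [nf + i * gc for i in range(5)]
print(in_channels)  # [64, 96, 128, 160, 192]
```

      The last convolution maps the 192-channel concatenation back down to nf channels so the residual addition `x5 * 0.2 + x` is shape-compatible.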
      

      Adversarial Training Strategy

      ESRGAN uses a GAN loss, so we need both a generator and a discriminator:

      # Discriminator definition
      class Discriminator(nn.Module):
          def __init__(self, num_in_ch=3, num_feat=64, skip_connection=True):
              super(Discriminator, self).__init__()
              self.skip_connection = skip_connection
              
              self.features = nn.Sequential(
                  # Layer 1: 3 input channels (RGB image) -> 64 channels
                  nn.Conv2d(num_in_ch, num_feat, 3, 1, 1),
                  nn.LeakyReLU(0.2, True),
                  
                  # Layer 2: 64 -> 64 channels, stride 2 (downsampling)
                  nn.Conv2d(num_feat, num_feat, 3, 2, 1),
                  nn.BatchNorm2d(num_feat),
                  nn.LeakyReLU(0.2, True),
                  
                  # Layer 3: 64 -> 128 channels
                  nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1),     # 64 -> 128
                  nn.BatchNorm2d(num_feat * 2),
                  nn.LeakyReLU(0.2, True),
                  
                  # Layer 4: 128 -> 128 channels, stride 2
                  nn.Conv2d(num_feat * 2, num_feat * 2, 3, 2, 1), # 128 -> 128
                  nn.BatchNorm2d(num_feat * 2),
                  nn.LeakyReLU(0.2, True),
                  
                  # Layer 5: 128 -> 256 channels
                  nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1), # 128 -> 256
                  nn.BatchNorm2d(num_feat * 4),
                  nn.LeakyReLU(0.2, True),
                  
                  # Layer 6: 256 -> 256 channels, stride 2
                  nn.Conv2d(num_feat * 4, num_feat * 4, 3, 2, 1), # 256 -> 256
                  nn.BatchNorm2d(num_feat * 4),
                  nn.LeakyReLU(0.2, True),
                  
                  # Layer 7: 256 -> 512 channels
                  nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1), # 256 -> 512
                  nn.BatchNorm2d(num_feat * 8),
                  nn.LeakyReLU(0.2, True),
                  
                  # Layer 8: 512 -> 512 channels, stride 2
                  nn.Conv2d(num_feat * 8, num_feat * 8, 3, 2, 1), # 512 -> 512
                  nn.BatchNorm2d(num_feat * 8),
                  nn.LeakyReLU(0.2, True),
              )
              
              # Classification head (real vs. fake)
              self.classifier = nn.Sequential(
                  nn.AdaptiveAvgPool2d(1),
                  nn.Conv2d(num_feat * 8, num_feat * 16, 1, 1, 0),
                  nn.LeakyReLU(0.2, True),
                  nn.Conv2d(num_feat * 16, 1, 1, 1, 0),
              )
      
          def forward(self, x):
              x = self.features(x)
              x = self.classifier(x)
              return x
      
      # Mixed loss: perceptual content loss built on a pretrained VGG
      import torchvision
      from torchvision import transforms
      
      class ContentLoss(nn.Module):
          def __init__(self, device):
              super(ContentLoss, self).__init__()
              # Use a pretrained VGG as the feature extractor
              vgg = torchvision.models.vgg19(pretrained=True).features[:35].eval()
              for param in vgg.parameters():
                  param.requires_grad = False
              self.vgg = vgg.to(device)
              self.criterion = nn.L1Loss()
              self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                   std=[0.229, 0.224, 0.225])
      
          def forward(self, sr, hr):
              # Normalize inputs to match VGG's training statistics
              sr_norm = self.normalize(sr)
              hr_norm = self.normalize(hr)
              # Extract features
              sr_feat = self.vgg(sr_norm)
              hr_feat = self.vgg(hr_norm)
              return self.criterion(sr_feat, hr_feat)
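
      The training code below also calls a `GANLoss(gan_type='vanilla')` helper that the post never defines (in practice it usually comes from a toolbox such as BasicSR). A minimal sketch of what such a helper might look like, assuming the vanilla variant is plain BCE-with-logits against all-real or all-fake targets:

```python
import torch
import torch.nn as nn

class GANLoss(nn.Module):
    """Minimal vanilla GAN loss sketch: BCE-with-logits against constant targets."""
    def __init__(self, gan_type='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        assert gan_type == 'vanilla', "only the vanilla variant is sketched here"
        self.real_label = real_label
        self.fake_label = fake_label
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, pred, target_is_real):
        # Build a constant target tensor shaped like the prediction
        value = self.real_label if target_is_real else self.fake_label
        target = torch.full_like(pred, value)
        return self.loss(pred, target)
```

      It is called as `gan_criterion(fake_pred, True)` in the generator update and `gan_criterion(fake_pred, False)` on the discriminator's fake branch.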
      

      GPU Training Optimization Tricks

      Mixed-precision training

      import logging
      
      import torch
      import torch.optim as optim
      from torch.amp import GradScaler, autocast
      from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
      
      def train():
          # Parse the command-line arguments
          args = parse_args()
          
          # Initialize logging
          train_logger = setup_logger(
              logger_name="train",
              log_file="train.log",
              log_dir="logs/train",
              level=logging.DEBUG  # debug level for more verbose output
          )
          
          # Load the datasets
          if args.download_datasets:
              train_logger.info("Starting automatic dataset download...")
              prepare_large_datasets(args.dataset_path)  # download into the given path
              train_logger.info("Dataset download finished")
      
          train_loader, val_loader = create_optimized_dataloaders(
              batch_size=args.batch_size,
              num_workers=args.num_workers,
              pin_memory=args.pin_memory
          )
          train_logger.info(f"Datasets loaded - train: {len(train_loader.dataset)} samples, val: {len(val_loader.dataset)} samples")
          
          # Initialize TensorBoard
          tb_writer = init_tensorboard(os.path.join(args.log_dir, "tensorboard"))
          
          # Select the device
          device = torch.device(args.device)
          train_logger.info(f"Using device: {device}")
          
          # Initialize the models
          generator = ESRGAN(scale_factor=args.scale_factor).to(device)
          discriminator = Discriminator().to(device)
          
          # Initialize the loss functions
          content_criterion = ContentLoss(device)
          gan_criterion = GANLoss(gan_type='vanilla').to(device)
          
          # Initialize the optimizers
          g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.9, 0.999))
          d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.9, 0.999))
          
          # Learning-rate schedulers
          g_scheduler = CosineAnnealingWarmRestarts(g_optimizer, T_0=100, T_mult=2)
          d_scheduler = CosineAnnealingWarmRestarts(d_optimizer, T_0=100, T_mult=2)
          
          # Mixed-precision training
          scaler = GradScaler('cuda', enabled=args.use_amp)       # device type is explicit, though 'cuda' is the default
          
          # Resume training if a checkpoint is given
          start_epoch = 0
          if args.resume:
              if os.path.isfile(args.resume):
                  checkpoint = torch.load(args.resume, map_location=device)
                  generator.load_state_dict(checkpoint['generator_state_dict'])
                  discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
                  g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
                  d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
                  start_epoch = checkpoint['epoch'] + 1
                  train_logger.info(f"Resumed from checkpoint: {args.resume}, starting at epoch {start_epoch}")
              else:
                  train_logger.warning(f"Checkpoint file not found: {args.resume}, training from scratch")
          
          # Training loop
          train_logger.info("Starting training...")
          for epoch in range(start_epoch, args.epochs):
              generator.train()
              discriminator.train()
              
              total_g_loss = 0.0
              total_d_loss = 0.0
              
              # Progress bar
              pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
              
              for batch_idx, (lr_imgs, hr_imgs) in enumerate(pbar):
                  lr_imgs = lr_imgs.to(device)
                  hr_imgs = hr_imgs.to(device)
                  
                  # ---------------------
                  #  Train the generator
                  # ---------------------
                  grad_accum_steps = 4
                  
                  with autocast('cuda', enabled=args.use_amp):    # compute the losses inside the autocast context
                      # Generate the super-resolution images (already upscaled to the HR size)
                      sr_imgs = generator(lr_imgs)
                      
                      # Generator losses; SR and HR shapes must match
                      assert sr_imgs.shape == hr_imgs.shape, "SR and HR shapes do not match!"
                      content_loss = content_criterion(sr_imgs, hr_imgs)
                      fake_pred = discriminator(sr_imgs)
                      gan_loss = gan_criterion(fake_pred, True)
                      
                      # Total generator loss (content loss weighted higher)
                      g_loss = content_loss * 0.01 + gan_loss * 0.005
                      
                  # Gradient accumulation: backpropagate a scaled loss every batch,
                  # but only step the optimizer every grad_accum_steps batches.
                  # (scaler.update() runs once per iteration, after the discriminator step.)
                  scaled_loss = g_loss / grad_accum_steps
                  scaler.scale(scaled_loss).backward()
      
                  if (batch_idx + 1) % grad_accum_steps == 0:
                      scaler.step(g_optimizer)
                      g_optimizer.zero_grad()              # clear gradients once the accumulation window ends
                  
                  # ---------------------
                  #  Train the discriminator
                  # ---------------------
                  d_optimizer.zero_grad()
                  
                  # Note: the device type must be given explicitly
                  with autocast('cuda', enabled=args.use_amp):
                      # Loss on real images
                      real_pred = discriminator(hr_imgs)
                      real_loss = gan_criterion(real_pred, True)
                      
                      # Loss on generated images
                      fake_pred = discriminator(sr_imgs.detach())  # detach so the generator is not updated
                      fake_loss = gan_criterion(fake_pred, False)
                      
                      # Total discriminator loss
                      d_loss = (real_loss + fake_loss) * 0.5
                  
                  # Backpropagation and optimization
                  scaler.scale(d_loss).backward()
                  scaler.step(d_optimizer)
                  scaler.update()
                  
                  # Accumulate the losses
                  total_g_loss += g_loss.item()
                  total_d_loss += d_loss.item()
                  
                  # Logging
                  if batch_idx % args.log_interval == 0:
                      avg_g_loss = total_g_loss / (batch_idx + 1)
                      avg_d_loss = total_d_loss / (batch_idx + 1)
                      pbar.set_postfix({"G Loss": f"{avg_g_loss:.4f}", "D Loss": f"{avg_d_loss:.4f}"})
                      
                      # Write to TensorBoard
                      global_step = epoch * len(train_loader) + batch_idx
                      tb_writer.add_scalar('Loss/Generator', g_loss.item(), global_step)
                      tb_writer.add_scalar('Loss/Discriminator', d_loss.item(), global_step)
              
              # Update the learning rates at the end of each epoch
              g_scheduler.step()
              d_scheduler.step()
              
              # Average losses for the epoch
              avg_g_loss_epoch = total_g_loss / len(train_loader)
              avg_d_loss_epoch = total_d_loss / len(train_loader)
              train_logger.info(f"Epoch {epoch+1} - G Loss: {avg_g_loss_epoch:.4f}, D Loss: {avg_d_loss_epoch:.4f}")
              
              # Save a checkpoint
              if (epoch + 1) % args.save_freq == 0:
                  save_checkpoint(
                      epoch + 1,
                      generator,
                      discriminator,
                      g_optimizer,
                      d_optimizer,
                      args.checkpoint_dir,
                      train_logger
                  )
              
              # Validation
              if (epoch + 1) % args.val_interval == 0:
                  generator.eval()
                  val_loss = 0.0
                  
                  with torch.no_grad():
                      for lr_imgs, hr_imgs in val_loader:
                          lr_imgs = lr_imgs.to(device)
                          hr_imgs = hr_imgs.to(device)
                          
                          sr_imgs = generator(lr_imgs)
                          loss = content_criterion(sr_imgs, hr_imgs)
                          val_loss += loss.item()
                  
                  avg_val_loss = val_loss / len(val_loader)
                  train_logger.info(f"Validation loss: {avg_val_loss:.4f}")
                  tb_writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
                  
                  # Log example images from the last validation batch
                  tb_writer.add_images('LR Images', lr_imgs[:4], epoch)
                  tb_writer.add_images('HR Images', hr_imgs[:4], epoch)
                  tb_writer.add_images('SR Images', sr_imgs[:4], epoch, dataformats='NCHW')
          
          # Training finished
          train_logger.info("Training complete!")
          tb_writer.close()
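
      The `save_checkpoint` helper called above is not shown in this post. A minimal sketch of what it might look like: the argument order matches the call site, and the dictionary keys match what the resume branch loads; everything else is an assumption:

```python
import os
import torch

def save_checkpoint(epoch, generator, discriminator, g_optimizer, d_optimizer,
                    checkpoint_dir, logger=None):
    """Save model and optimizer state so training can be resumed later (sketch)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
    torch.save({
        'epoch': epoch,
        # Keys match what the resume branch in train() expects
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
    }, path)
    if logger is not None:
        logger.info(f"Checkpoint saved to {path}")
    return path
```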
      

      Gradient accumulation and learning-rate scheduling

      def main():
          args = parse_args()
          device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          
          # Set the random seeds for reproducibility
          torch.manual_seed(args.seed)
          if torch.cuda.is_available():
              torch.cuda.manual_seed_all(args.seed)
              
          # Run the training loop
          train()
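
      The accumulation logic in train() divides each loss by grad_accum_steps and only steps the optimizer every few batches; for mean-reduced losses this is mathematically equivalent to one larger batch. A small self-contained check of that equivalence (plain SGD, no AMP, illustrative only):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()  # mean reduction

# Gradient of one full batch of 8
model.zero_grad()
loss_fn(model(data), target).backward()
full_grad = model.weight.grad.clone()

# Same gradient accumulated over 4 micro-batches of 2,
# each micro-loss divided by the number of accumulation steps
model.zero_grad()
accum_steps = 4
for i in range(accum_steps):
    mb_x, mb_y = data[2 * i:2 * i + 2], target[2 * i:2 * i + 2]
    (loss_fn(model(mb_x), mb_y) / accum_steps).backward()

assert torch.allclose(full_grad, model.weight.grad, atol=1e-6)
```

      This is why dividing by grad_accum_steps before backward() matters: without it, the accumulated gradient would be accum_steps times too large.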
      

      Extended run script

      An enhanced run script for large experiments:

      #!/bin/bash
      # run_esrgan_large.sh
      echo "Launching the ESRGAN large-scale training job..."
      
      # Set the working directory (adjust as needed)
      cd /home/vscode/workspace/py-torch-learning/codes/esrgan-project
      
      # Create directories
      mkdir -p data checkpoints logs
      
      # Install dependencies
      pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
      
      # Record the start time
      start_time=$(date +%s)
      echo "Experiment start time: $(date)"
      
      # Check GPU status
      nvidia-smi
      
      # Compute the log file name once so every later reference matches
      log_file="training_log_esrgan_$(date +%Y%m%d_%H%M%S).txt"
      
      # Launch training
      nohup python3 -u main.py \
          --epochs 100 \
          --batch_size 8 \
          --dataset_path ./data \
          --checkpoint_dir ./checkpoints \
          --download_datasets \
          > "$log_file" 2>&1 &
      
      # Record the process ID and log file
      echo "Training started in the background, PID: $!"
      echo "Log file: $log_file"
      
      # Monitor GPU usage (log every 5 minutes)
      while true; do
          echo "GPU check: $(date)" >> "$log_file"
          nvidia-smi >> "$log_file" 2>&1
          sleep 300  # 5 minutes
      done &
      

      Experiment Monitoring and Analysis

      Visualize the training process with TensorBoard:

      from torch.utils.tensorboard import SummaryWriter
      
      def init_tensorboard(log_dir):
          """Initialize TensorBoard."""
          writer = SummaryWriter(log_dir=log_dir)
          return writer
      
      def log_to_tensorboard(writer, epoch, train_metrics, val_metrics, images):
          """Write training metrics and images to TensorBoard."""
          # Scalar metrics
          writer.add_scalar('Loss/Generator', train_metrics['gen_loss'], epoch)
          writer.add_scalar('Loss/Discriminator', train_metrics['dis_loss'], epoch)
          writer.add_scalar('PSNR/Train', train_metrics['psnr'], epoch)
          writer.add_scalar('PSNR/Validation', val_metrics['psnr'], epoch)
          writer.add_scalar('LearningRate/Generator', 
                           train_metrics['gen_lr'], epoch)
          
          # Example images (every 10 epochs)
          if epoch % 10 == 0:
              lr_img, sr_img, hr_img = images
              writer.add_image('Input/LowResolution', lr_img, epoch)
              writer.add_image('Output/SuperResolution', sr_img, epoch)
              writer.add_image('Target/HighResolution', hr_img, epoch)
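
      The metrics dict above carries a 'psnr' entry that is never computed in this post. PSNR is 10 * log10(MAX^2 / MSE); a minimal pure-Python sketch for images normalized to [0, 1] (the helper name is ours):

```python
import math

def psnr(sr_pixels, hr_pixels, max_val=1.0):
    """Peak signal-to-noise ratio between two flat pixel sequences in [0, max_val]."""
    mse = sum((s - h) ** 2 for s, h in zip(sr_pixels, hr_pixels)) / len(sr_pixels)
    if mse == 0:
        return float('inf')  # identical images
    return 10 * math.log10(max_val ** 2 / mse)

# An MSE of 0.01 on [0, 1] images corresponds to 20 dB
print(round(psnr([0.1, 0.3], [0.2, 0.4]), 2))  # 20.0
```

      In a real training loop you would compute this per batch on the SR and HR tensors (e.g. with a tensor MSE) rather than on Python lists.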
      

      Output Summary

      1. Core output: model checkpoints
        Contents: generator and discriminator weights, optimizer state, and the training epoch, saved during training.
        Path: set by the --checkpoint_dir parameter in config.py; the default is:
        ./checkpoints/
        Filenames have the form checkpoint_epoch_{epoch}.pth (e.g. checkpoint_epoch_10.pth).
        Trigger: saved every --save_freq epochs (default 10), adjustable from the command line.
      2. Logs
        Contents: loss values, validation metrics, and key operational messages (download/extraction progress, model loading, etc.).
        Paths:
        Text log: set by --log_dir in config.py; defaults to ./logs/train/train.log.
        Terminal output: when launched via run_esrgan_large.sh, stdout/stderr are redirected to training_log_esrgan_<timestamp>.txt next to the script.
      3. TensorBoard visualizations
        Contents: training/validation loss curves and generated super-resolution images (LR input, HR ground truth, SR prediction side by side).
        Path: defaults to ./logs/tensorboard/ (set by init_tensorboard in main.py, based on --log_dir).
        Viewing: run tensorboard --logdir=./logs/tensorboard and open the local port in a browser.
        Note: the tensorboard and protobuf versions must be compatible.
      4. Super-resolution result images (optional)
        Contents: SR images produced during validation or inference.
        Path: set by --result_dir in config.py; defaults to ./results/ (the directory is created in the code; saving images can be added to the inference logic).

      Summary and Next Steps

      In this installment we accomplished:

      1. **Large-dataset management**: automatic download, extraction, and combination of several large datasets, with an optimized data-loading pipeline
      2. **Complex model construction**: an ESRGAN model built from residual dense blocks, which produces far richer detail than SRCNN
      3. **Advanced training strategies**: mixed-precision training, gradient accumulation, and cosine-annealing learning-rate scheduling to improve GPU utilization
      4. **A complete monitoring stack**: log files, GPU monitoring, and TensorBoard visualization for full experiment tracking

      Directions to explore next:

      • Try larger models (e.g. RCAN, SwinIR)
      • Introduce perceptual losses and improved GAN variants (e.g. Relativistic GAN)
      • Implement model and data parallelism to train on multiple GPUs
      • Explore model compression and acceleration for real-time super-resolution
      • Tackle video super-resolution, accounting for temporal consistency

      In the next installment we will apply the more recent vision Transformer models to super-resolution to push reconstruction quality further.

      posted on 2025-09-25 21:21 by cmxcxd  read (15)  comments (0)
