重生之從零開始的神經網絡算法學習之路——第八篇 大型數據集與復雜模型的GPU訓練實踐
引言
在前一篇中,我們實現了基礎的SRCNN超分辨率模型并掌握了后臺訓練技巧。本篇將進一步拓展實驗規模:引入更大規模的數據集、實現更復雜的網絡結構,并優化GPU訓練策略,以應對更具挑戰性的圖像重建任務。通過這些實踐,我們將深入理解大規模深度學習實驗的關鍵技術和工程細節。
大型數據集的獲取與處理
適合超分辨率任務的大型數據集
為了提升模型泛化能力,我們可以使用以下大型數據集:
1. **DIV2K擴展集**:包含1000張高分辨率訓練圖像和100張驗證圖像(2K分辨率)
2. **Flickr2K**:2650張來自Flickr的高分辨率自然圖像(4K及以上)
3. **CelebA-HQ**:30,000張高質量人臉圖像(1024x1024分辨率)
4. **ImageNet**:百萬級通用圖像數據集(可用于預訓練)
代碼實現
git clone https://gitee.com/cmx1998/py-torch-learning.git
cd py-torch-learning/codes/esrgan-project
自動下載與解壓實現
import os
import wget
import zipfile
import tarfile
from tqdm import tqdm
# Dataset download configuration: dataset name -> URL(s) of its HR archives.
# DIV2K ships one zip per split; Flickr2K is a single tar archive.
DATASETS = {
    "DIV2K": {
        "train": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
        "valid": "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip"
    },
    "Flickr2K": {
        "url": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar"
    }
}
def download_dataset(url, save_dir, filename=None):
    """Download a dataset archive with a tqdm progress bar.

    Args:
        url: Direct download URL of the archive.
        save_dir: Directory to store the file in (created if missing).
        filename: Optional target name; defaults to the last URL segment.

    Returns:
        str: Path of the downloaded (or already existing) file.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Derive the file name from the URL when none was given.
    if not filename:
        filename = url.split("/")[-1]
    save_path = os.path.join(save_dir, filename)
    # Skip the download if the file is already present.
    if os.path.exists(save_path):
        # BUGFIX: the original f-strings had lost their {filename} placeholder
        # and printed the literal text "(unknown)".
        print(f"文件 {filename} 已存在,跳過下載")
        return save_path
    print(f"開始下載 {filename}...")
    # Drive a tqdm bar from wget's callback instead of wget's own noisy output.
    with tqdm(total=100, desc=f"下載 {filename}", unit="%") as pbar:
        def progress_bar(current, total, width=80):
            # wget reports absolute byte counts; convert to a percentage and
            # redraw in place (refresh) rather than emitting new lines.
            pbar.n = int(current / total * 100)
            pbar.refresh()
        wget.download(url, save_path, bar=progress_bar)
    print(f"\n{filename} 下載完成")
    return save_path
def extract_archive(file_path, extract_dir):
    """Extract a .zip / .tar / .tar.gz archive into extract_dir.

    A hidden marker file (".<archive name>.extracted") records a completed
    extraction so repeated calls become no-ops.
    """
    os.makedirs(extract_dir, exist_ok=True)
    filename = os.path.basename(file_path)
    # Marker that flags a previously completed extraction.
    # BUGFIX: the original marker name had lost its {filename} placeholder,
    # so every archive shared the same "(unknown)" marker.
    extract_flag = os.path.join(extract_dir, f".{filename}.extracted")
    if os.path.exists(extract_flag):
        print(f"文件 {filename} 已解壓,跳過解壓!")
        return
    try:
        if file_path.endswith(".zip"):
            # Zip archives: extract member by member to show progress.
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                for file in tqdm(zip_ref.namelist(), desc="解壓中"):
                    zip_ref.extract(file, extract_dir)
        elif file_path.endswith((".tar", ".tar.gz")):
            # NOTE(review): tarfile.extract on untrusted archives is unsafe
            # without Python 3.12's `filter=` argument — confirm sources.
            with tarfile.open(file_path, 'r') as tar_ref:
                for member in tqdm(tar_ref.getmembers(), desc="解壓中"):
                    tar_ref.extract(member, extract_dir)
        else:
            # BUGFIX: the original fell through and wrote the success marker
            # even when the format was unsupported and nothing was extracted.
            print(f"不支持的壓縮格式:{filename}")
            return
        # Mark success only after the archive was actually extracted.
        with open(extract_flag, 'w') as f:
            f.write("Extracted successfully")
        print(f"文件 {os.path.basename(file_path)} 解壓完成")
    except Exception as e:
        print(f"解壓失敗:{e}")
        # Remove a possibly half-written marker so a retry is not skipped.
        if os.path.exists(extract_flag):
            os.remove(extract_flag)
def prepare_large_datasets(base_dir):
    """Download and extract every configured large dataset under base_dir."""
    # DIV2K: one archive per split (train / valid), each extracted into its
    # own sub-directory.
    div2k_dir = os.path.join(base_dir, "DIV2K")
    for split, url in DATASETS["DIV2K"].items():
        archive = download_dataset(url, div2k_dir)
        extract_archive(archive, os.path.join(div2k_dir, split))
    # Flickr2K: a single tar archive extracted in place.
    flickr_dir = os.path.join(base_dir, "Flickr2K")
    archive = download_dataset(DATASETS["Flickr2K"]["url"], flickr_dir)
    extract_archive(archive, flickr_dir)
    print("所有數據集準備完成")
高效數據加載策略
對于大型數據集,需要優化數據加載流程以充分利用GPU:
from torch.utils.data import ConcatDataset
class CombinedDataset(Dataset):
    """Dataset wrapper that concatenates several SuperResolutionDataset
    instances (one per root directory) into a single indexable dataset."""

    def __init__(self, dataset_paths, scale_factor=4, patch_size=128, train=True,
                 augment=True, cache_in_memory=False):
        self.datasets = []
        for root in dataset_paths:
            # A missing directory only produces a warning; valid ones are kept.
            if not os.path.exists(root):
                print(f"警告: 數據集路徑不存在: {root}")
                continue
            self.datasets.append(SuperResolutionDataset(
                root,
                scale_factor=scale_factor,
                patch_size=patch_size,
                train=train,
                augment=augment,
                cache_in_memory=cache_in_memory
            ))
        if not self.datasets:
            raise ValueError("沒有有效的數據集路徑")
        # Delegate index bookkeeping across datasets to torch's ConcatDataset.
        self.combined = ConcatDataset(self.datasets)

    def __len__(self):
        return len(self.combined)

    def __getitem__(self, idx):
        return self.combined[idx]
# Optimized data loaders
def create_optimized_dataloaders(batch_size, num_workers=8, pin_memory=True):
    """Build training/validation DataLoaders over the combined large datasets.

    NOTE(review): this reads a module-level ``args`` object (dataset_path,
    scale_factor, patch_size) that is not a parameter — confirm it is defined
    (e.g. by parse_args()) before this function is called.
    """
    # Combine several large dataset roots for training.
    dataset_paths = [
        os.path.join(args.dataset_path, "DIV2K"),
        os.path.join(args.dataset_path, "Flickr2K")
    ]
    train_dataset = CombinedDataset(
        dataset_paths,
        scale_factor=args.scale_factor,
        patch_size=args.patch_size,
        train=True
    )
    # Validation uses DIV2K only, in eval (train=False) mode.
    val_dataset = SuperResolutionDataset(
        os.path.join(args.dataset_path, "DIV2K"),
        train=False,
        scale_factor=args.scale_factor,
        patch_size=args.patch_size
    )
    # Multi-process loading with prefetch to keep the GPU fed.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=2,        # each worker pre-loads upcoming batches
        persistent_workers=True   # keep worker processes alive across epochs
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    return train_loader, val_loader
復雜模型實現:ESRGAN
相比SRCNN,ESRGAN(Enhanced Super-Resolution Generative Adversarial Networks)能生成更富細節的高分辨率圖像。我們實現其核心結構:
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block (RDB), the core component of ESRGAN.

    Five 3x3 convolutions with dense connections: each conv sees the block
    input plus all previous conv outputs.  The final output is scaled by 0.2
    and added back to the input (local residual learning).

    Args:
        nf: Number of input/output feature channels.
        gc: Growth channels produced by each dense convolution.
        bias: Whether the convolutions use a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf + 0 * gc, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + 1 * gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # BUGFIX: Kaiming init now matches the actual activation
        # (LeakyReLU with slope 0.2); the original passed nonlinearity='relu',
        # which mis-estimates the gain for these layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        a=0.2, nonlinearity='leaky_relu')

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Residual scaling (beta = 0.2) stabilizes training.
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """Residual-in-Residual Dense Block: three chained RDBs wrapped in an
    outer scaled residual connection."""

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf, gc)
        self.rdb2 = ResidualDenseBlock(nf, gc)
        self.rdb3 = ResidualDenseBlock(nf, gc)

    def forward(self, x):
        # Pass through the three dense blocks in sequence.
        out = self.rdb3(self.rdb2(self.rdb1(x)))
        # Outer residual, scaled by 0.2 like the inner blocks.
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """Backbone generator of ESRGAN: shallow conv + RRDB trunk with a global
    residual + pixel-shuffle upsampling + output conv.

    Args:
        num_in_ch: Input image channels.
        num_out_ch: Output image channels.
        num_feat: Feature channels throughout the trunk.
        num_block: Number of RRDB blocks in the trunk.
        num_grow_ch: Growth channels inside each RDB.
        scale: Upscaling factor; must be a power of two (each upsampling
            stage doubles height and width via PixelShuffle(2)).
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                 num_grow_ch=32, scale=4):
        super(RRDBNet, self).__init__()
        self.scale = scale
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.body = self._make_rrdb_blocks(num_feat, num_block, num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.upsampler = self._make_upsampler(num_feat, scale)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)

    def _make_rrdb_blocks(self, num_feat, num_block, num_grow_ch):
        # Trunk: a plain sequence of RRDB blocks.
        return nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)])

    def _make_upsampler(self, num_feat, scale):
        # One stage per factor of 2: a conv that quadruples the channels
        # followed by PixelShuffle(2).
        # BUGFIX: the original used int(log2(scale)), which silently truncated
        # non-power-of-two scales (e.g. scale=3 produced a 2x upsampler).
        # Fail loudly instead of building the wrong network.
        stages, remaining = [], scale
        while remaining > 1 and remaining % 2 == 0:
            stages.append(nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True))
            stages.append(nn.PixelShuffle(2))
            remaining //= 2
        if remaining != 1:
            raise ValueError(f"scale must be a power of 2, got {scale}")
        return nn.Sequential(*stages)

    def forward(self, x):
        feat = self.conv_first(x)
        # Global residual around the RRDB trunk.
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        out = self.conv_last(self.upsampler(feat))
        return out
# ESRGAN generator (inherits RRDBNet to keep the interface consistent)
class ESRGAN(RRDBNet):
    """ESRGAN generator.

    Rebuilds its own trunk and head layers on top of RRDBNet so the attribute
    names (RRDB_trunk, trunk_conv, HRconv, conv_last, ...) are available to
    the forward pass below.
    """

    def __init__(self, scale_factor=4, num_block=23, num_grow_ch=32, **kwargs):
        super(ESRGAN, self).__init__(
            scale=scale_factor,
            num_block=num_block,
            num_grow_ch=num_grow_ch,
            **kwargs
        )
        self.scale_factor = scale_factor
        self.conv_first = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
        # Keep the block configuration on the instance for later inspection.
        self.num_rrdb_blocks = num_block
        self.num_grow_ch = num_grow_ch
        # Trunk built via the parent helper from the stored configuration.
        self.RRDB_trunk = self._make_rrdb_blocks(64, self.num_rrdb_blocks, self.num_grow_ch)
        self.trunk_conv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Upsampling chain for the requested scale (reuses the parent helper).
        self.upsampler = self._make_upsampler(64, scale_factor)

    def forward(self, x):
        shallow = self.conv_first(x)
        # RRDB trunk with a global residual connection.
        deep = self.trunk_conv(self.RRDB_trunk(shallow))
        features = shallow + deep
        # Upsample first, then refine at high resolution.
        features = self.lrelu(self.HRconv(self.upsampler(features)))
        return self.conv_last(features)
生成對抗訓練策略
ESRGAN使用GAN損失函數,需要定義生成器和判別器:
# Discriminator definition
class Discriminator(nn.Module):
    """VGG-style discriminator: alternating stride-1 / stride-2 3x3 convs
    (BatchNorm + LeakyReLU after every conv except the first), followed by a
    global-average-pool + 1x1-conv head that emits one real/fake score."""

    def __init__(self, num_in_ch=3, num_feat=64, skip_connection=True):
        super(Discriminator, self).__init__()
        # NOTE(review): skip_connection is stored but never used in this class.
        self.skip_connection = skip_connection
        # First layer has no BatchNorm.
        layers = [
            nn.Conv2d(num_in_ch, num_feat, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
        ]
        # (in_channels, out_channels, stride) for the remaining conv stages.
        stages = [
            (num_feat, num_feat, 2),          # 64 -> 64, downsample
            (num_feat, num_feat * 2, 1),      # 64 -> 128
            (num_feat * 2, num_feat * 2, 2),  # 128 -> 128, downsample
            (num_feat * 2, num_feat * 4, 1),  # 128 -> 256
            (num_feat * 4, num_feat * 4, 2),  # 256 -> 256, downsample
            (num_feat * 4, num_feat * 8, 1),  # 256 -> 512
            (num_feat * 8, num_feat * 8, 2),  # 512 -> 512, downsample
        ]
        for in_ch, out_ch, stride in stages:
            layers.append(nn.Conv2d(in_ch, out_ch, 3, stride, 1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, True))
        self.features = nn.Sequential(*layers)
        # Classification head (real/fake score per image).
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat * 8, num_feat * 16, 1, 1, 0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(num_feat * 16, 1, 1, 1, 0),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
# Perceptual (content) loss
class ContentLoss(nn.Module):
    """VGG-19 perceptual loss: L1 distance between frozen VGG feature maps of
    the super-resolved and ground-truth images.

    Args:
        device: Device to place the frozen VGG feature extractor on.
            BUGFIX: the original read an undefined global ``device`` and its
            __init__ took no arguments, so the ``ContentLoss(device)`` call in
            train() raised a TypeError.  The device is now an explicit
            parameter with a safe default.
    """

    def __init__(self, device='cpu'):
        super(ContentLoss, self).__init__()
        # Frozen, pre-trained VGG-19 feature extractor (first 35 layers).
        vgg = torchvision.models.vgg19(pretrained=True).features[:35].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg.to(device)
        self.criterion = nn.L1Loss()
        # ImageNet statistics: VGG expects normalized inputs.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    def forward(self, sr, hr):
        # Normalize both images to match VGG's training distribution.
        sr_feat = self.vgg(self.normalize(sr))
        hr_feat = self.vgg(self.normalize(hr))
        return self.criterion(sr_feat, hr_feat)
GPU訓練優化技巧
混合精度訓練
from torch.amp import GradScaler, autocast
def train():
    """Run the full ESRGAN training loop.

    Parses CLI args, prepares data/models/losses/optimizers, trains with
    mixed precision, and periodically validates and checkpoints.  Relies on
    project helpers defined elsewhere (parse_args, setup_logger, GANLoss,
    save_checkpoint, init_tensorboard, ...).
    """
    args = parse_args()
    # Logger (DEBUG level for verbose output)
    train_logger = setup_logger(
        logger_name="train",
        log_file="train.log",
        log_dir="logs/train",
        level=logging.DEBUG
    )
    # Datasets (optionally auto-download first)
    if args.download_datasets:
        train_logger.info("開始自動下載數據集...")
        prepare_large_datasets(args.dataset_path)
        train_logger.info("數據集下載完成")
    train_loader, val_loader = create_optimized_dataloaders(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory
    )
    train_logger.info(f"數據集加載完成 - 訓練集: {len(train_loader.dataset)} 樣本, 驗證集: {len(val_loader.dataset)} 樣本")
    # TensorBoard
    tb_writer = init_tensorboard(os.path.join(args.log_dir, "tensorboard"))
    # Device
    device = torch.device(args.device)
    train_logger.info(f"使用設備: {device}")
    # Models
    generator = ESRGAN(scale_factor=args.scale_factor).to(device)
    discriminator = Discriminator().to(device)
    # Losses
    content_criterion = ContentLoss(device)
    gan_criterion = GANLoss(gan_type='vanilla').to(device)
    # Optimizers
    g_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(0.9, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.9, 0.999))
    # Cosine annealing with warm restarts, stepped once per epoch
    g_scheduler = CosineAnnealingWarmRestarts(g_optimizer, T_0=100, T_mult=2)
    d_scheduler = CosineAnnealingWarmRestarts(d_optimizer, T_0=100, T_mult=2)
    # One shared AMP scaler for both optimizers; update() is called once per
    # iteration, after all step() calls, per the multi-optimizer AMP recipe.
    scaler = GradScaler('cuda', enabled=args.use_amp)
    # Resume from checkpoint if requested
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume, map_location=device)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
            d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            train_logger.info(f"從檢查點恢復訓練: {args.resume}, 開始于 epoch {start_epoch}")
        else:
            train_logger.warning(f"未找到檢查點文件: {args.resume}, 從頭開始訓練")
    # Training loop
    train_logger.info("開始訓練...")
    for epoch in range(start_epoch, args.epochs):
        generator.train()
        discriminator.train()
        total_g_loss = 0.0
        total_d_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.epochs}")
        for batch_idx, (lr_imgs, hr_imgs) in enumerate(pbar):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            # ---------------------
            # Train generator
            # ---------------------
            g_optimizer.zero_grad()
            with autocast('cuda', enabled=args.use_amp):
                sr_imgs = generator(lr_imgs)
                content_loss = content_criterion(sr_imgs, hr_imgs)
                fake_pred = discriminator(sr_imgs)
                gan_loss = gan_criterion(fake_pred, True)
                # Content loss dominates; the adversarial term is a small nudge.
                g_loss = content_loss * 0.01 + gan_loss * 0.005
            # BUGFIX: the original back-propagated g_loss TWICE per batch
            # (once divided by grad_accum_steps, once unscaled, both with
            # retain_graph=True) and called scaler.step(g_optimizer) up to
            # twice, while zero_grad at the top of the batch defeated the
            # accumulation anyway.  One backward + one step per batch is
            # correct; real gradient accumulation under AMP needs a dedicated
            # scaler whose update() only runs at the accumulation boundary.
            # The per-batch shape print/assert debug noise was also removed.
            scaler.scale(g_loss).backward()
            scaler.step(g_optimizer)
            # ---------------------
            # Train discriminator
            # ---------------------
            d_optimizer.zero_grad()
            with autocast('cuda', enabled=args.use_amp):
                real_pred = discriminator(hr_imgs)
                real_loss = gan_criterion(real_pred, True)
                fake_pred = discriminator(sr_imgs.detach())  # detach: no gradients into G
                fake_loss = gan_criterion(fake_pred, False)
                d_loss = (real_loss + fake_loss) * 0.5
            scaler.scale(d_loss).backward()
            scaler.step(d_optimizer)
            # Single update per iteration, after all optimizer steps.
            scaler.update()
            # Running loss totals
            total_g_loss += g_loss.item()
            total_d_loss += d_loss.item()
            # Periodic progress / TensorBoard logging
            if batch_idx % args.log_interval == 0:
                avg_g_loss = total_g_loss / (batch_idx + 1)
                avg_d_loss = total_d_loss / (batch_idx + 1)
                pbar.set_postfix({"G Loss": f"{avg_g_loss:.4f}", "D Loss": f"{avg_d_loss:.4f}"})
                global_step = epoch * len(train_loader) + batch_idx
                tb_writer.add_scalar('Loss/Generator', g_loss.item(), global_step)
                tb_writer.add_scalar('Loss/Discriminator', d_loss.item(), global_step)
        # Per-epoch learning-rate schedule step
        g_scheduler.step()
        d_scheduler.step()
        # Epoch averages
        avg_g_loss_epoch = total_g_loss / len(train_loader)
        avg_d_loss_epoch = total_d_loss / len(train_loader)
        train_logger.info(f"Epoch {epoch+1} - G Loss: {avg_g_loss_epoch:.4f}, D Loss: {avg_d_loss_epoch:.4f}")
        # Checkpointing
        if (epoch + 1) % args.save_freq == 0:
            save_checkpoint(
                epoch + 1,
                generator,
                discriminator,
                g_optimizer,
                d_optimizer,
                args.checkpoint_dir,
                train_logger
            )
        # Validation
        if (epoch + 1) % args.val_interval == 0:
            generator.eval()
            val_loss = 0.0
            with torch.no_grad():
                for lr_imgs, hr_imgs in val_loader:
                    lr_imgs = lr_imgs.to(device)
                    hr_imgs = hr_imgs.to(device)
                    sr_imgs = generator(lr_imgs)
                    val_loss += content_criterion(sr_imgs, hr_imgs).item()
            avg_val_loss = val_loss / len(val_loader)
            train_logger.info(f"驗證損失: {avg_val_loss:.4f}")
            tb_writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
            # Example images from the last validation batch
            tb_writer.add_images('LR Images', lr_imgs[:4], epoch)
            tb_writer.add_images('HR Images', hr_imgs[:4], epoch)
            tb_writer.add_images('SR Images', sr_imgs[:4], epoch, dataformats='NCHW')
    # Finished
    train_logger.info("訓練完成!")
    tb_writer.close()
梯度累積與學習率調度
def main():
    """Entry point: parse args, seed RNGs for reproducibility, run training."""
    args = parse_args()
    # Seed CPU and all visible GPUs for reproducibility.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    # train() parses its own args and selects its own device, so the unused
    # local `device` computed here in the original has been removed.
    train()
擴展運行腳本
針對大型實驗的增強版運行腳本:
#!/bin/bash
# run_esrgan_large.sh — launch a large ESRGAN training run in the background
# and periodically append GPU status to the training log.
echo "啟動ESRGAN大型訓練任務..."
# Working directory (adjust as needed)
cd /home/vscode/workspace/py-torch-learning/codes/esrgan-project
# Output directories
mkdir -p data checkpoints logs
# Dependencies
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# Record start time
start_time=$(date +%s)
echo "實驗開始時間: $(date)"
# GPU status before launch
nvidia-smi
# BUGFIX: compute the log file name ONCE.  The original evaluated
# $(date +%Y%m%d_%H%M%S) twice (in the nohup redirect and again for
# $log_file), so the monitor loop appended to a file different from the one
# the training process was writing to.
log_file="training_log_esrgan_$(date +%Y%m%d_%H%M%S).txt"
# Launch training in the background
nohup python3 -u main.py \
--epochs 100 \
--batch_size 8 \
--dataset_path ./data \
--checkpoint_dir ./checkpoints \
--download_datasets \
> "$log_file" 2>&1 &
train_pid=$!
echo "訓練任務已在后台啟動,PID: $train_pid"
echo "日志文件: $log_file"
# GPU monitor: log nvidia-smi every 5 minutes while training is running.
# BUGFIX: the original `while true` loop never terminated, even after the
# training process exited; this one stops when the PID is gone.
(
    while kill -0 "$train_pid" 2>/dev/null; do
        echo "GPU監控: $(date)" >> "$log_file"
        nvidia-smi >> "$log_file" 2>&1
        sleep 300  # 5 minutes
    done
) &
實驗監控與分析
使用TensorBoard可視化訓練過程:
from torch.utils.tensorboard import SummaryWriter
def init_tensorboard(log_dir):
    """Create and return a TensorBoard SummaryWriter rooted at log_dir."""
    return SummaryWriter(log_dir=log_dir)
def log_to_tensorboard(writer, epoch, train_metrics, val_metrics, images):
    """Write per-epoch metrics (and, every 10th epoch, example images) to
    TensorBoard via the given writer."""
    # Scalar metrics are logged every epoch.
    scalar_logs = {
        'Loss/Generator': train_metrics['gen_loss'],
        'Loss/Discriminator': train_metrics['dis_loss'],
        'PSNR/Train': train_metrics['psnr'],
        'PSNR/Validation': val_metrics['psnr'],
        'LearningRate/Generator': train_metrics['gen_lr'],
    }
    for tag, value in scalar_logs.items():
        writer.add_scalar(tag, value, epoch)
    # Example images only every 10th epoch to keep the event file small.
    if epoch % 10 == 0:
        lr_img, sr_img, hr_img = images
        for tag, img in (('Input/LowResolution', lr_img),
                         ('Output/SuperResolution', sr_img),
                         ('Target/HighResolution', hr_img)):
            writer.add_image(tag, img, epoch)
輸出結果匯總
- 核心輸出:模型檢查點(Checkpoint)
內容:訓練過程中保存的生成器(Generator)和判別器(Discriminator)的權重參數、優化器狀態、訓練輪次等。
路徑:由 config.py 中的 --checkpoint_dir 參數指定,默認路徑為:
./checkpoints/
文件名格式為 checkpoint_epoch_{epoch}.pth(例如 checkpoint_epoch_10.pth)。
觸發時機:每訓練 --save_freq 輪(默認 10 輪)保存一次,可通過命令行參數調整。
- 日志記錄
內容:訓練過程中的損失值、驗證指標、關鍵操作日志(如下載 / 解壓進度、模型加載信息等)。
路徑:
文本日志:由 config.py 中的 --log_dir 參數指定,默認路徑為 ./logs/train/train.log。
終端輸出日志:運行腳本 run_esrgan_large.sh 時,會重定向到 training_log_esrgan_$(date).txt(與腳本同目錄)。
- TensorBoard 可視化結果
內容:訓練 / 驗證損失曲線、生成的超分辨率圖像(LR 輸入、HR 真實值、SR 預測值對比)。
路徑:默認存儲在 ./logs/tensorboard/(由 main.py 中 init_tensorboard 函數指定,基于 --log_dir 參數)。
查看方式:運行 tensorboard --logdir=./logs/tensorboard 后在瀏覽器訪問本地端口。
注意 tensorboard 和 protobuf 的版本要匹配。
- 超分辨率結果圖像(可選)
內容:驗證階段或推理時生成的超分辨率圖像(SR Images)。
路徑:由 config.py 中的 --result_dir 參數指定,默認路徑為 ./results/(代碼中已初始化該目錄,可在推理邏輯中補充保存圖像的代碼)。
總結與后續方向
通過本篇實驗,我們實現了:
1. **大型數據集管理**:自動下載、解壓和組合多個大型數據集,優化數據加載流程
2. **復雜模型構建**:實現了基于殘差密集塊的ESRGAN模型,相比SRCNN能生成更豐富的細節
3. **高級訓練策略**:引入混合精度訓練、梯度累積和余弦退火學習率調度,提升GPU利用率
4. **完善監控體系**:結合日志文件、GPU監控和TensorBoard可視化,全面跟蹤實驗過程
后續可探索的方向:
- 嘗試更大規模的模型(如RCAN、SwinIR)
- 引入感知損失和GAN的改進變體(如Relativistic GAN)
- 實現模型并行和數據并行,利用多GPU進行訓練
- 探索模型壓縮和加速技術,實現實時超分辨率
- 嘗試視頻超分辨率任務,考慮時間維度的一致性
下一篇我們將探索更前沿的視覺Transformer模型在超分辨率任務中的應用,進一步提升重建質量。
浙公網安備 33010602011771號