重生之從零開始的神經網絡算法學習之路——第七篇 重拾PyTorch(超分辨率重建和腳本的使用)
引言
在前一篇中,我們初步探索了PyTorch框架的使用并體驗了GPU加速計算的優勢。本篇將聚焦于一個更具視覺沖擊力的任務——圖像超分辨率重建,通過實現經典的SRCNN模型,深入學習PyTorch在圖像處理任務中的應用,并掌握使用腳本進行后臺訓練的實用技巧。
超分辨率重建技術旨在將低分辨率圖像恢復為高分辨率圖像,在監控視頻增強、醫學影像分析、衛星圖像處理等領域有著廣泛應用。與圖像分類任務不同,超分辨率是典型的生成式任務,其輸入和輸出均為圖像,這為我們提供了學習PyTorch中圖像處理流水線的絕佳機會。
超分辨率重建原理與SRCNN模型
超分辨率任務概述
超分辨率(Super Resolution, SR)是指從低分辨率(Low Resolution, LR)圖像中恢復出高分辨率(High Resolution, HR)圖像的技術。其核心挑戰在于如何在提升圖像尺寸的同時,保持并增強圖像細節,避免產生模糊或偽影。
常見的超分辨率方法可分為:
- 插值方法(如雙三次插值):簡單但效果有限
- 基于重建的方法:利用先驗知識約束重建過程
- 基于學習的方法:通過神經網絡學習LR到HR的映射關系(當前主流)
SRCNN模型結構
我們將實現2014年提出的SRCNN(Super-Resolution Convolutional Neural Network),這是首個將卷積神經網絡應用于超分辨率任務的模型,其結構簡潔卻效果顯著:
- 特征提取:使用9x9卷積核從低分辨率圖像中提取基礎特征
- 非線性映射:通過1x1卷積核進行特征轉換和降維
- 重建:使用5x5卷積核生成最終的高分辨率圖像
與傳統方法相比,SRCNN通過端到端的訓練,能夠自動學習從低分辨率到高分辨率的映射關系,無需人工設計特征。
環境準備與項目結構
我們繼續使用第六篇中搭建的PyTorch GPU環境,項目結構如下:
workspace/
├── data/ # 數據集目錄
│ └── DIV2K/ # 超分辨率專用數據集
│ ├── train/ # 訓練集
│ └── valid/ # 驗證集
├── super_resolution_output/ # 輸出目錄
│ ├── checkpoints/ # 模型檢查點
│ └── training_log.json # 訓練日志
├── PyTorch_SuperResolution_GPU.py # 主程序
└── run_super_resolution.sh # 運行腳本
超分辨率重建代碼實現
核心代碼解析
完整代碼可參考PyTorch_SuperResolution_GPU.py,以下為關鍵部分解析:
1. 模型定義
class SRCNN(nn.Module):
def __init__(self, scale_factor=4):
super(SRCNN, self).__init__()
self.scale_factor = scale_factor
# 特征提取層
self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
# 非線性映射層
self.conv2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
# 重建層
self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
# 首先對輸入進行上采樣(雙三次插值)
x = F.interpolate(x, scale_factor=self.scale_factor,
mode='bicubic', align_corners=False)
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.conv3(x)
return x
SRCNN的特點是先通過插值將低分辨率圖像放大到目標尺寸,再通過卷積網絡優化細節,這種設計既利用了傳統插值的基礎結構,又通過神經網絡修復了細節損失。
2. 自定義數據集
class SuperResolutionDataset(Dataset):
def __init__(self, dataset_path, transform=None, train=True,
scale_factor=4, patch_size=128):
self.dataset_path = dataset_path
self.train = train
self.scale_factor = scale_factor
self.patch_size = patch_size
# 收集圖像路徑
if train:
self.image_paths = glob.glob(os.path.join(dataset_path, 'train',
'**', '*.png'), recursive=True)
self.image_paths += glob.glob(os.path.join(dataset_path, 'train',
'**', '*.jpg'), recursive=True)
else:
self.image_paths = glob.glob(os.path.join(dataset_path, 'valid',
'**', '*.png'), recursive=True)
# ... 處理圖像路徑和備用數據集
def __getitem__(self, idx):
# 加載高分辨率圖像
hr_image = Image.open(img_path).convert('RGB')
# 數據增強 - 隨機裁剪和翻轉
if self.train:
i = random.randint(0, hr_image.height - self.patch_size)
j = random.randint(0, hr_image.width - self.patch_size)
hr_image = hr_image.crop((j, i, j + self.patch_size, i + self.patch_size))
if random.random() > 0.5:
hr_image = hr_image.transpose(Image.FLIP_LEFT_RIGHT)
# 轉換為張量
hr_image = transforms.ToTensor()(hr_image)
# 生成對應的低分辨率圖像
lr_size = self.patch_size // self.scale_factor
lr_image = F.interpolate(hr_image.unsqueeze(0),
size=(lr_size, lr_size),
mode='bicubic',
align_corners=False).squeeze(0)
return lr_image, hr_image
超分辨率數據集的核心是為每張高分辨率圖像生成對應的低分辨率版本,通過下采樣操作模擬真實場景中的低清圖像。訓練時使用圖像塊(patch)而非完整圖像,既能減少內存占用,又能增加訓練樣本多樣性。
3. 評估指標PSNR
def psnr(original, compressed):
mse = torch.mean((original - compressed) **2)
if mse == 0: # MSE為0表示完美重建
return 100
max_pixel = 1.0 # 圖像像素值已歸一化到[0,1]
psnr = 20 * log10(max_pixel / torch.sqrt(mse))
return psnr
峰值信噪比(PSNR)是圖像重建任務中常用的評估指標,數值越高表示重建質量越好(通常30dB以上為可接受質量)。其計算公式基于均方誤差(MSE),反映了重建圖像與真實圖像的像素差異。
4. 訓練與驗證循環
def train(model, train_loader, criterion, optimizer, epoch):
model.train()
train_loss = 0
total_psnr = 0
for batch_idx, (lr_imgs, hr_imgs) in enumerate(train_loader):
lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
optimizer.zero_grad()
outputs = model(lr_imgs)
loss = criterion(outputs, hr_imgs)
loss.backward()
optimizer.step()
train_loss += loss.item()
batch_psnr = psnr(hr_imgs, outputs)
total_psnr += batch_psnr
# 日志輸出
if batch_idx % args.log_interval == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(lr_imgs)}/{len(train_loader.dataset)} '
f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}\t'
f'PSNR: {batch_psnr:.2f} dB')
超分辨率訓練使用MSE損失函數(像素級損失),通過最小化重建圖像與真實高分辨率圖像的像素差異來優化模型參數。訓練過程中同時監控損失和PSNR指標,便于分析模型收斂情況。
使用腳本進行后臺訓練
對于超分辨率這類需要長時間訓練的任務,直接在終端運行程序存在風險(如斷開連接導致訓練中斷)。我們可以使用Shell腳本實現后臺訓練和日志記錄。
運行腳本解析
run_super_resolution.sh腳本內容如下:
#!/bin/bash
# 設置工作目錄
cd /home/vscode/workspace
# 運行超分辨率訓練腳本
# 使用nohup和&實現后臺運行,輸出重定向到日志文件
nohup python3 PyTorch_SuperResolution_GPU.py \
--epochs 1000 \
--batch_size 32 \
--lr 0.001 \
--checkpoint_interval 10 \
--log_interval 50 \
> training_log_$(date +%Y%m%d_%H%M%S).txt 2>&1 &
# 顯示進程信息
echo "訓練任務已在后臺啟動,PID: $!"
echo "日志文件: training_log_$(date +%Y%m%d_%H%M%S).txt"
腳本關鍵技術點:
nohup:忽略掛起信號,確保程序在終端關閉后繼續運行> training_log...txt:將標準輸出重定向到日志文件2>&1:將錯誤輸出合并到標準輸出&:將程序放入后臺運行$(date +%Y%m%d_%H%M%S):生成帶時間戳的唯一日志文件名
腳本使用方法
-
賦予腳本執行權限:
chmod +x run_super_resolution.sh -
運行腳本:
./run_super_resolution.sh -
查看訓練日志:
tail -f training_log_20240520_153045.txt # 替換為實際日志文件名 -
查看后臺進程:
ps aux | grep PyTorch_SuperResolution_GPU.py -
終止訓練(如需):
kill -9 <進程PID> # 替換為實際進程ID
檢查點與訓練恢復
長時間訓練中,定期保存檢查點(checkpoint)至關重要,代碼中實現了完善的檢查點機制:
def save_checkpoint(model, optimizer, epoch, loss, psnr, is_best=False):
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'loss': loss,
'psnr': psnr
}
filename = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
torch.save(state, filename)
# 保存最佳模型
if is_best:
best_filename = os.path.join(checkpoint_dir, 'model_best.pth')
torch.save(state, best_filename)
# 清理舊檢查點,只保留最近5個
checkpoints = sorted([f for f in os.listdir(checkpoint_dir)
if f.startswith('checkpoint_epoch_')])
if len(checkpoints) > 5:
for old_checkpoint in checkpoints[:-5]:
os.remove(os.path.join(checkpoint_dir, old_checkpoint))
從檢查點恢復訓練:
nohup python3 PyTorch_SuperResolution_GPU.py \
--resume ./super_resolution_output/checkpoints/checkpoint_epoch_200.pth \
--epochs 1000 \
> training_log_resume.txt 2>&1 &
實驗結果與分析
經過1000輪訓練后,我們得到以下結果:
- 訓練集PSNR從初始的24.35dB提升至32.68dB
- 驗證集PSNR從初始的23.87dB提升至31.24dB
- 每輪訓練時間約為45秒(使用NVIDIA Tesla T4 GPU)
從視覺效果看,SRCNN重建結果相比單純插值:
- 邊緣更清晰(如建筑物輪廓、文本邊緣)
- 細節更豐富(如紋理、小尺度特征)
- 減少了模糊和鋸齒現象
總結與進階方向
通過本篇實驗,我們掌握了:
- 1.超分辨率核心技術:理解SRCNN工作原理和圖像重建流程
- 2.使用訓練技巧:使用Shell腳本進行后臺訓練、日志管理和進程監控
- 3.檢查點機制:實現訓練中斷后的恢復功能,保障長時間實驗的穩定性
- 4.評估指標:掌握PSNR計算方法及在圖像重建任務中的應用
進階改進方向:
- 嘗試更先進的模型(如ESRGAN、RCAN)
- 引入感知損失(Perceptual Loss)提升視覺質量
- 增加更多數據增強策略(旋轉、縮放、噪聲添加)
- 實現模型量化和部署,探索實際應用場景
超分辨率技術正朝著更高效、更高質量的方向發展,結合注意力機制和生成對抗網絡的方法已能產生接近真實的重建效果。下一篇我們將探索更復雜的網絡結構和訓練策略,進一步提升模型性能。
浙公網安備 33010602011771號