T4_0414N_資料、模型、訓練流程全面優化
你已經成功建立與訓練一個 MIMO-UNet 模型來處理模糊圖片,但你觀察到推論出來的圖片依然模糊,代表 模型的學習效果有限或尚未收斂。這很常見,我幫你整理出一套「強化訓練效果」的策略,從資料、模型、訓練流程全面優化。
🔧 一、資料相關調整
✅ 1. 增加資料多樣性與數量
- 資料量太少會導致模型無法泛化,建議:
- 收集更多類型的模糊圖片(不同模糊程度、方向、背景)
- 使用模糊核自動產生模糊圖片 (
cv2.GaussianBlur、motion blur)
✅ 2. 增加資料增強(Data Augmentation)
目前 Dataset 中僅做了 Resize 與 ToTensor,建議加上:
self.transform = T.Compose([
T.Resize((256, 256)),
T.RandomHorizontalFlip(),
T.RandomRotation(10),
T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
T.ToTensor()
])
🧠 二、模型架構強化(可選升級)
✅ 1. 加入 BatchNorm2d:提升穩定性
修改 ResidualBlock:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1),
nn.BatchNorm2d(channels),
nn.ReLU(),
nn.Conv2d(channels, channels, 3, padding=1),
nn.BatchNorm2d(channels)
)
def forward(self, x): return x + self.block(x)
✅ 2. 替換 Loss 函數為更感知的損失
目前用的是 MSELoss,但這對視覺品質不夠敏感。建議改用:
import torchvision.models as models
class PerceptualLoss(nn.Module):
def __init__(self):
super().__init__()
vgg = models.vgg16(pretrained=True).features[:16].eval()
for param in vgg.parameters(): param.requires_grad = False
self.vgg = vgg
self.loss_fn = nn.MSELoss()
def forward(self, pred, target):
vgg_pred = self.vgg(pred)
vgg_target = self.vgg(target)
return self.loss_fn(vgg_pred, vgg_target)
或者組合兩者:
loss_fn = lambda pred, target: mse_loss(pred, target) + 0.1 * perceptual_loss(pred, target)
🚀 三、訓練流程強化建議
✅ 1. 增加 Epoch 數
你的 30 epochs 可能不足,建議先訓練 100+ epochs(儲存最佳模型)。
✅ 2. 使用學習率衰減策略
加上:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
在每個 epoch 結尾加入:
scheduler.step(total_loss)
✅ 3. 加入 Validation 集與評估指標(PSNR / SSIM)
這有助於觀察模型是否過擬合:
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
def evaluate(model, loader, device):
model.eval()
psnr_total, ssim_total = 0, 0
with torch.no_grad():
for blur, sharp in loader:
blur, sharp = blur.to(device), sharp.to(device)
output = model(blur).clamp(0, 1)
for pred, gt in zip(output, sharp):
pred_img = pred.permute(1, 2, 0).cpu().numpy()
gt_img = gt.permute(1, 2, 0).cpu().numpy()
psnr_total += psnr(gt_img, pred_img, data_range=1.0)
ssim_total += ssim(gt_img, pred_img, channel_axis=-1, data_range=1.0)
print(f"[Eval] PSNR: {psnr_total/len(loader.dataset):.2f}, SSIM: {ssim_total/len(loader.dataset):.4f}")
✅ 小結(調整重點整理)
| 項目 | 建議 |
|---|---|
| 📁 資料 | 增加樣本數、多樣性、資料增強 |
| 🧠 模型 | 加入 BatchNorm、可選 Perceptual Loss |
| 🔧 損失函數 | MSE + Perceptual Loss 組合更穩定 |
| 📊 評估 | 加入 PSNR/SSIM、Validation dataset |
| 🔁 訓練 | Epoch 100+、使用學習率調整策略 |
🎁 Bonus:我可以幫你
- 加入完整的 可視化曲線 + 評估指標
- 幫你整合成 GUI、CLI 或 Web API
- 產生一份訓練報告(含訓練圖表與效果圖)
需要我幫你產生優化後的 train.py 或重新整理資料集架構嗎?你只要告訴我你目前的資料集規劃,我就能幫你改完整程式。