人工智能【第31篇】生成对抗网络GAN入门:AI的创造力之源
作者的话在前面的文章中我们学习了各种监督学习和无监督学习算法以及深度学习中的CNN、RNN等架构。今天我们将进入一个充满想象力的领域——生成对抗网络GAN。GAN让AI拥有了创造力可以生成逼真的图像、音乐、文本甚至视频。从DeepFake到AI绘画从风格迁移到超分辨率GAN的应用无处不在。让我们一起探索这个让AI学会造假的神奇技术一、什么是生成对抗网络GAN1.1 GAN的诞生2014年Ian Goodfellow等人在论文《Generative Adversarial Nets》中提出了GAN这是深度学习领域最具革命性的创新之一。核心思想通过两个神经网络的对抗训练让生成器学会创造逼真的数据。类比理解生成器Generator 假币制造者试图制造逼真的假币判别器Discriminator 警察试图识别真假货币两者不断对抗最终假币制造者技术越来越高超警察也越来越难分辨1.2 GAN的基本架构随机噪声 z ~ N(0,1) ↓ ┌──────────────────┐ │ 生成器 G │ ← 学习从噪声生成假样本 │ (逆卷积网络) │ └────────┬─────────┘ ↓ G(z) 假样本 │ ┌─────┴─────┐ ↓ ↓ 真实样本x 假样本G(z) │ │ └─────┬─────┘ ↓ ┌──────────────────┐ │ 判别器 D │ ← 区分真实样本和生成样本 │ (卷积分类器) │ └────────┬─────────┘ ↓ D(x) → 1 (真实) D(G(z)) → 0 (虚假)1.3 GAN的数学原理目标函数Minimax Gamemin_G max_D V(D, G) E[log D(x)] E[log(1 - D(G(z)))]直观理解组件目标优化方向判别器 D最大化V正确区分真假样本生成器 G最小化V让D无法区分真假1.4 GAN vs 传统生成模型特性GANVAE自回归模型扩散模型训练稳定性较难较易中等较易生成质量高中等高很高多样性好中等好很好推理速度快快慢慢二、GAN的核心组件详解2.1 生成器Generator功能将随机噪声映射为目标数据分布class Generator(nn.Module): def __init__(self, latent_dim100, img_shape(1, 28, 28)): super(Generator, self).__init__() self.img_shape img_shape self.model nn.Sequential( nn.Linear(latent_dim, 128), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(128, 256), nn.BatchNorm1d(256, 0.8), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(256, 512), nn.BatchNorm1d(512, 0.8), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(512, 1024), nn.BatchNorm1d(1024, 0.8), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))), nn.Tanh() # 输出范围[-1, 1] ) def forward(self, z): img self.model(z) img img.view(img.size(0), *self.img_shape) return img2.2 判别器Discriminatorclass Discriminator(nn.Module): def __init__(self, img_shape(1, 28, 28)): super(Discriminator, self).__init__() self.model nn.Sequential( nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(256, 1), nn.Sigmoid() # 输出概率 ) def forward(self, img): img_flat img.view(img.size(0), -1) validity self.model(img_flat) return validity2.3 DCGAN深度卷积GAN对于图像生成使用卷积层效果更好class DCGAN_Generator(nn.Module): def __init__(self, latent_dim100, channels1): super(DCGAN_Generator, self).__init__() self.init_size 7 self.l1 nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2)) self.conv_blocks nn.Sequential( nn.BatchNorm2d(128), nn.Upsample(scale_factor2), nn.Conv2d(128, 128, 3, stride1, padding1), nn.BatchNorm2d(128, 0.8), nn.LeakyReLU(0.2, inplaceTrue), nn.Upsample(scale_factor2), nn.Conv2d(128, 64, 3, stride1, padding1), nn.BatchNorm2d(64, 0.8), nn.LeakyReLU(0.2, inplaceTrue), nn.Conv2d(64, channels, 3, stride1, padding1), nn.Tanh() ) def forward(self, z): out self.l1(z) out out.view(out.shape[0], 128, self.init_size, self.init_size) img self.conv_blocks(out) return img三、GAN训练实战3.1 训练循环代码# 训练循环 for epoch in range(n_epochs): for i, (imgs, _) in enumerate(dataloader): batch_size imgs.size(0) # 真实标签和假标签 real torch.ones(batch_size, 1).to(device) fake torch.zeros(batch_size, 1).to(device) # 真实图像 real_imgs imgs.to(device) # # 训练生成器 # optimizer_G.zero_grad() # 采样随机噪声 z torch.randn(batch_size, latent_dim).to(device) # 生成图像 gen_imgs generator(z) # 计算生成器损失 g_loss adversarial_loss(discriminator(gen_imgs), real) g_loss.backward() optimizer_G.step() # # 训练判别器 # optimizer_D.zero_grad() # 真实图像的损失 real_loss adversarial_loss(discriminator(real_imgs), real) # 生成图像的损失 fake_loss adversarial_loss(discriminator(gen_imgs.detach()), fake) # 总判别器损失 d_loss (real_loss fake_loss) / 2 d_loss.backward() optimizer_D.step() # 打印进度 if i % 100 0: print(f[Epoch {epoch}/{n_epochs}] f[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}])3.2 训练技巧技巧具体做法效果标签平滑真实标签设为0.9而非1.0防止判别器过度自信学习率调整生成器学习率稍高帮助生成器追赶梯度惩罚使用WGAN-GP提高训练稳定性历史平均使用生成器历史版本增加多样性四、GAN的变体与演进4.1 条件GANCGAN创新在输入中加入条件信息如类别标签实现可控生成class CGAN_Generator(nn.Module): def __init__(self, latent_dim100, num_classes10): super(CGAN_Generator, self).__init__() self.label_emb nn.Embedding(num_classes, num_classes) self.model nn.Sequential( nn.Linear(latent_dim num_classes, 128), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(128, 256), nn.BatchNorm1d(256, 0.8), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(256, 512), nn.BatchNorm1d(512, 0.8), nn.LeakyReLU(0.2, inplaceTrue), nn.Linear(512, 784), # 28x28 nn.Tanh() ) def forward(self, noise, labels): # 将标签嵌入与噪声拼接 label_input self.label_emb(labels) gen_input torch.cat((label_input, noise), -1) img self.model(gen_input) img img.view(img.size(0), 1, 28, 28) return img # 使用示例生成数字7 z torch.randn(1, 100).to(device) label torch.tensor([7]).to(device) generated_img generator(z, label)4.2 Wasserstein GANWGAN问题原始GAN使用JS散度训练不稳定容易出现梯度消失解决方案使用Wasserstein距离Earth Movers Distance原始GANWGANSigmoid输出线性输出BCE Loss直接优化W距离判别器叫Discriminator叫Critic权重裁剪梯度惩罚WGAN-GP4.3 其他重要变体变体年份核心创新应用场景DCGAN2015使用卷积层图像生成基础CGAN2014条件控制可控生成WGAN2017Wasserstein距离稳定训练CycleGAN2017循环一致性风格迁移StyleGAN2018渐进式增长高分辨率人脸五、GAN的应用场景5.1 图像生成应用描述代表工作人脸生成生成逼真的人脸图像StyleGAN、StyleGAN2艺术创作AI绘画、风格迁移DALL-E、Midjourney数据增强扩充训练数据集各种条件GAN超分辨率图像放大不失真SRGAN、ESRGAN5.2 风格迁移CycleGAN原理学习两个域之间的映射无需成对数据照片 → 油画风格 马 → 斑马 夏天 → 冬天 苹果 → 橙子5.3 超分辨率重建SRGAN应用将低分辨率图像恢复为高分辨率优势传统方法模糊、细节丢失GAN方法感知质量更好细节更丰富六、GAN的挑战与解决方案6.1 模式坍塌Mode Collapse现象生成器只生成少数几种样本缺乏多样性原因生成器找到了能欺骗判别器的捷径方法原理效果WGAN改善损失函数中等Minibatch Discrimination批量内比较较好Spectral Normalization谱归一化好6.2 训练不稳定现象损失震荡、无法收敛、生成质量差解决方案学习率调整判别器学习率0.0001生成器学习率0.0002网络架构使用DCGAN架构避免全连接层标签平滑真实标签0.9假标签0.16.3 评估指标指标原理优点缺点Inception Score (IS)分类置信度多样性计算简单对模式敏感Fréchet Inception Distance (FID)特征分布距离与人类感知相关需要预训练模型七、实战项目生成手写数字7.1 完整训练代码import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms from torchvision.utils import save_image # 超参数 latent_dim 100 img_size 28 batch_size 64 lr 0.0002 n_epochs 100 device torch.device(cuda if torch.cuda.is_available() else cpu) # 数据加载 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ]) dataloader DataLoader( datasets.MNIST(./data, trainTrue, downloadTrue, transformtransform), batch_sizebatch_size, shuffleTrue ) # 初始化模型 generator Generator().to(device) discriminator Discriminator().to(device) # 损失函数和优化器 adversarial_loss nn.BCELoss() optimizer_G optim.Adam(generator.parameters(), lrlr, betas(0.5, 0.999)) optimizer_D optim.Adam(discriminator.parameters(), lrlr, betas(0.5, 0.999)) # 训练循环同上 # ... print(训练完成)7.2 训练结果分析正常训练的迹象D loss 在 0.5 附近波动G loss 逐渐下降生成的图像越来越清晰问题症状解决方案D太强D loss≈0, G loss很高降低D的学习率减少D的训练次数G太强G loss≈0, 图像模式单一增加D的学习率检查模式坍塌训练不稳定loss剧烈震荡使用WGAN-GP调整学习率八、总结与展望8.1 GAN的核心要点对抗训练生成器和判别器相互博弈共同进步损失函数Minimax博弈达到纳什均衡训练技巧标签平滑、学习率调整、架构设计评估指标IS、FID等衡量生成质量8.2 GAN vs 扩散模型对比项GAN扩散模型生成质量高更高训练稳定性较难较易推理速度快单步慢多步去噪当前主流逐渐减少成为主流现状虽然扩散模型如Stable Diffusion在图像生成领域逐渐取代GAN但GAN在特定任务如实时生成、风格迁移上仍有优势。8.3 学习建议从简单开始先用全连接GAN理解原理再用DCGAN生成图像调参耐心GAN训练需要耐心多尝试不同的超参数可视化经常查看生成结果及时发现问题下一篇预告【第32篇】GAN实战进阶图像风格迁移与超分辨率重建我们将深入实践CycleGAN和SRGAN体验GAN在图像变换中的强大能力本文为系列第31篇详细讲解了GAN的原理与实战。有任何问题欢迎在评论区交流标签GAN、生成对抗网络、深度学习、图像生成、神经网络、AI创造力、PyTorch