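This article introduces a generative adversarial network (GAN) project implemented in PyTorch and explains in detail how GANs work, how the code is implemented, and where GANs can be applied. The project's code can be found at the GitHub link.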
1. Introduction
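A generative adversarial network is a powerful deep learning model made up of two competing networks: a generator and a discriminator. The generator produces realistic-looking data samples, while the discriminator tries to tell generated samples apart from real ones. As the two networks keep competing and learning from each other, the GAN learns to produce high-quality synthetic data. In this project, we use the PyTorch framework to implement a basic GAN and apply it to an image generation task: generating MNIST handwritten digits.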
2. How GANs work
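The core idea of a GAN is adversarial training between two networks, the generator and the discriminator. The generator takes random noise as input and produces fake data samples. The discriminator receives both real samples and samples produced by the generator and tries to tell them apart. Over many rounds of this adversarial process, the generator gradually improves its samples so that they fool the discriminator as often as possible.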
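For reference, the adversarial game described above is commonly written as the minimax objective from the original GAN paper (Goodfellow et al., 2014), where $D(x)$ is the discriminator's probability that a sample $x$ is real and $G(z)$ is the sample the generator produces from noise $z$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

In practice both networks minimize per-batch losses. For $m$ real images $x_i$ and $m$ noise vectors $z_i$, the discriminator minimizes

$$L_D = -\frac{1}{m}\sum_{i=1}^{m} \log D(x_i) - \frac{1}{m}\sum_{i=1}^{m} \log\left(1 - D(G(z_i))\right)$$

which is equivalent to maximizing $V$ over $D$. Instead of minimizing $\log(1 - D(G(z)))$, the generator usually minimizes the "non-saturating" loss

$$L_G = -\frac{1}{m}\sum_{i=1}^{m} \log D(G(z_i))$$

because it gives stronger gradients early in training, when the discriminator easily rejects the generator's samples. These two losses are what the training code later in this article computes with `nn.BCELoss`.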
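The GAN training process can be summarized in the following steps:
- The generator produces a batch of fake samples.
- The discriminator classifies both the real samples and the generated samples.
- The losses of the generator and the discriminator are computed from the discriminator's classification results.
- The weights of the generator and the discriminator are updated.
- The first four steps are repeated until a preset number of training epochs is reached or the losses converge.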
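The loss computation in the third step is usually done with binary cross-entropy (BCE). The sketch below uses made-up discriminator outputs, 0.9 for a real image and 0.2 for a generated one, to check that `nn.BCELoss` with target 1 gives $-\log D(x)$, with target 0 gives $-\log(1 - D(G(z)))$, and that scoring the generated image against target 1 gives the generator's loss $-\log D(G(z))$. Here $D(\cdot)$ denotes the discriminator's probability that its input is real.

```python
import math

import torch
from torch import nn

criterion = nn.BCELoss()
d_real = torch.tensor([[0.9]])  # made-up D(x): probability assigned to a real image
d_fake = torch.tensor([[0.2]])  # made-up D(G(z)): probability assigned to a generated image

real_loss = criterion(d_real, torch.ones(1, 1))   # target 1 -> -log D(x)
fake_loss = criterion(d_fake, torch.zeros(1, 1))  # target 0 -> -log(1 - D(G(z)))
g_loss = criterion(d_fake, torch.ones(1, 1))      # generator: target 1 on fakes -> -log D(G(z))

print(f"{real_loss.item():.4f} {fake_loss.item():.4f} {g_loss.item():.4f}")  # 0.1054 0.2231 1.6094
print(f"{-math.log(0.9):.4f} {-math.log(0.8):.4f} {-math.log(0.2):.4f}")     # 0.1054 0.2231 1.6094
```

The discriminator's total loss is the sum of the first two terms; the generator only ever sees the third.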
3. Code implementation
Below are the key code snippets of the GAN implemented in PyTorch.
Import the required libraries and modules:
import torch
from torch import nn
from torch.autograd.variable import Variable
import torchvision
import torchvision.transforms as transforms
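# Preprocess: ToTensor() scales pixels to [0, 1], then Normalize((0.5,), (0.5,)) maps them
# to [-1, 1], the same range as the output of the generator's final Tanh layer
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]
)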
# Training data
train_set = torchvision.datasets.MNIST(
root='.', train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=32, shuffle=True
)
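The discriminator. The discriminator network classifies whether an image is real or fake.
Input: a 28×28-pixel image, flattened into a vector of length 784. Output: a single value between 0 and 1 (from the final Sigmoid) indicating whether the image is a real MNIST digit.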
# Our Discriminator class
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the input image is real
        )
    def forward(self, x):
        # Flatten each 28x28 image into a 784-dim vector
        out = self.model(x.view(x.size(0), 784))
        # Shape (batch_size, 1); the output stays on the same device as the model and input
        out = out.view(out.size(0), -1)
        return out
discriminator = Discriminator()
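The `x.view(x.size(0), 784)` call in `forward` turns a batch of shape `(N, 1, 28, 28)` into the `(N, 784)` matrix that the first `nn.Linear(784, 1024)` layer expects. The sketch below uses a random batch as a stand-in for MNIST images, checks that this is the same as `nn.Flatten()`, and shows that a Linear-plus-Sigmoid head yields one value in [0, 1] per image. The single-layer head is a simplification for illustration, not the architecture above.

```python
import torch
from torch import nn

torch.manual_seed(0)
images = torch.randn(8, 1, 28, 28)  # random stand-in for a batch of 8 MNIST images

flat_view = images.view(images.size(0), 784)  # what Discriminator.forward does
flat_module = nn.Flatten()(images)            # flattens every dim after the batch dim
print(flat_view.shape, torch.equal(flat_view, flat_module))  # torch.Size([8, 784]) True

# Simplified head for illustration: one Linear layer plus Sigmoid -> one probability per image
head = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())
probs = head(flat_view)
print(probs.shape)  # torch.Size([8, 1])
```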
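The generator. The generator network is responsible for creating the images.
Input: a vector of length 100 (pure random noise). Output: a vector of length 784, which is reshaped into a 28×28-pixel image.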
# Our Generator class
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized training images
        )
    def forward(self, x):
        x = x.view(x.size(0), 100)
        # The output stays on the same device as the model and the input noise
        out = self.model(x)
        return out
generator = Generator()
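Because the generator ends with `nn.Tanh()`, its 784 outputs lie in [-1, 1]. To view them as pictures, each output vector is reshaped back to a 28×28 image and mapped back to [0, 1]. The sketch below shows that post-processing on a small stand-in generator; the hidden width of 128 is an arbitrary assumption, not the architecture above.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Stand-in with the same input/output contract as the Generator above:
# 100-dim noise in, 784 values in [-1, 1] out
toy_generator = nn.Sequential(
    nn.Linear(100, 128),
    nn.ReLU(),
    nn.Linear(128, 784),
    nn.Tanh(),
)

noise = torch.randn(16, 100)     # 16 noise vectors
with torch.no_grad():            # no gradients are needed just to inspect samples
    flat = toy_generator(noise)  # shape (16, 784), values in [-1, 1]

images = flat.view(-1, 1, 28, 28)  # back to (N, C, H, W) image layout
images = (images + 1) / 2          # invert Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1]
print(images.shape)                # torch.Size([16, 1, 28, 28])
```

A tensor in this layout and range can then be written to disk as an image grid, for example with `torchvision.utils.save_image`.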
Move the models to the GPU if one is available:
# If we have a GPU with CUDA, use it
if torch.cuda.is_available():
    print("Using CUDA")
    discriminator.cuda()
    generator.cuda()
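Calling `.cuda()` directly ties code to a CUDA GPU. A common device-agnostic alternative is to pick a `torch.device` once and move modules and tensors with `.to(device)`, which falls back to the CPU when CUDA is unavailable. A minimal sketch of the pattern, with a stand-in `nn.Linear` layer in place of the models above:

```python
import torch
from torch import nn

# Choose the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(100, 784).to(device)       # moves the module's parameters to the device
noise = torch.randn(32, 100, device=device)  # create tensors directly on the same device
out = model(noise)                           # inputs and parameters must share a device
print(out.device)
```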
# Setup loss function and optimizers
lr = 0.01
num_epochs = 40
num_batches = len(train_loader)
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
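`nn.BCELoss` expects probabilities, which is why the discriminator ends with `nn.Sigmoid()`. PyTorch also provides `nn.BCEWithLogitsLoss`, which fuses the sigmoid and the binary cross-entropy into a single, numerically more stable operation; using it would mean removing the final `nn.Sigmoid()` from the discriminator, a change this article does not make. The sketch below, on random scores, checks that the two formulations give the same loss:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(32, 1)                    # raw discriminator scores, before any sigmoid
labels = torch.randint(0, 2, (32, 1)).float()  # random mix of "real" (1) and "fake" (0) targets

loss_sigmoid_then_bce = nn.BCELoss()(torch.sigmoid(logits), labels)
loss_fused = nn.BCEWithLogitsLoss()(logits, labels)
print(loss_sigmoid_then_bce.item(), loss_fused.item())  # equal up to float rounding
```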
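Now for the training loop. The key to training a GAN is that the generator and the discriminator are updated in turn inside the same loop: first the discriminator, then the generator.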
# Convenience function for training our Discriminator
def train_discriminator(discriminator, real_images, real_labels, fake_images, fake_labels):
    discriminator.zero_grad()
    # Get the predictions, loss, and score of the real images
    predictions = discriminator(real_images)
    real_loss = criterion(predictions, real_labels)
    real_score = predictions
    # Get the predictions, loss, and score of the fake images;
    # detach() keeps this loss from backpropagating into the generator
    predictions = discriminator(fake_images.detach())
    fake_loss = criterion(predictions, fake_labels)
    fake_score = predictions
    # Sum the two losses, backpropagate, and update the discriminator's weights
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss, real_score, fake_score
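# Convenience function for training our Generator
def train_generator(generator, discriminator_outputs, real_labels):
    generator.zero_grad()
    # The generator is rewarded when the discriminator labels its fakes as real (label 1)
    g_loss = criterion(discriminator_outputs, real_labels)
    g_loss.backward()
    # Only the generator's weights are updated here; the discriminator gradients produced by
    # this backward pass are cleared by discriminator.zero_grad() in the next discriminator step
    g_optimizer.step()
    return g_loss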
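When the discriminator is trained on generated images, calling `.detach()` on them cuts the computation graph at the generator's output, so the discriminator's backward pass does not compute gradients for the generator's parameters. The sketch below demonstrates the effect with tiny stand-in networks; the layer sizes are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Tiny stand-ins for the generator and discriminator
gen = nn.Linear(4, 3)
disc = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
criterion = nn.BCELoss()

fake = gen(torch.randn(5, 4))  # "generated" samples that still carry the generator's graph
fake_labels = torch.zeros(5, 1)

# Without detach(): the discriminator loss also backpropagates into the generator
criterion(disc(fake), fake_labels).backward()
grad_leaked = gen.weight.grad is not None

# Reset all gradients to None before the second experiment
gen.zero_grad(set_to_none=True)
disc.zero_grad(set_to_none=True)

# With detach(): only the discriminator receives gradients
criterion(disc(fake.detach()), fake_labels).backward()
grad_isolated = gen.weight.grad is None and disc[0].weight.grad is not None

print(grad_leaked, grad_isolated)  # True True
```

Without `detach()` the generator's stray gradients are harmless as long as `generator.zero_grad()` runs before its own update, but they cost an extra backward pass through the generator on every discriminator step.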
# Use whichever device the models ended up on (the GPU if it was available, otherwise the CPU)
device = next(discriminator.parameters()).device
for epoch in range(num_epochs):
    for n, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)  # read from the data, so a smaller last batch also works
        # (1) Prepare the real data for the Discriminator
        real_images = images.to(device)
        real_labels = torch.ones(batch_size, 1, device=device)
        # (2) Prepare the random noise data for the Generator
        noise = torch.randn(batch_size, 100, device=device)
        # (3) Prepare the fake data for the Discriminator
        fake_images = generator(noise)
        fake_labels = torch.zeros(batch_size, 1, device=device)
        # (4) Train the discriminator on real and fake data
        d_loss, real_score, fake_score = train_discriminator(
            discriminator,
            real_images, real_labels,
            fake_images, fake_labels
        )
        # (5a) Generate some new fake images from the Generator.
        # (5b) Get the label predictions of the Discriminator on that fake data.
        noise = torch.randn(batch_size, 100, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        # (6) Train the generator
        g_loss = train_generator(generator, outputs, real_labels)
    # Report the losses of the last batch of each epoch
    print(f"Epoch {epoch + 1}/{num_epochs}  d_loss: {d_loss.item():.4f}  g_loss: {g_loss.item():.4f}")
Finally, save the trained weights of both models:
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
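To use the trained generator later, the saved `state_dict` can be loaded into a freshly constructed `Generator` and sampled with new noise. The sketch below repeats the `Generator` layers from above in the same order so that the state_dict keys match, and writes its own `generator.pth` into a temporary directory as a stand-in for the file produced by training. `map_location="cpu"` lets weights saved from a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
from torch import nn


class Generator(nn.Module):
    # Same layers, in the same order, as the Generator above, so state_dict keys match
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 512), nn.ReLU(inplace=True),
            nn.Linear(512, 1024), nn.ReLU(inplace=True),
            nn.Linear(1024, 784), nn.Tanh(),
        )

    def forward(self, x):
        return self.model(x.view(x.size(0), 100))


trained = Generator()  # stand-in for the generator after training

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "generator.pth")
    torch.save(trained.state_dict(), path)

    restored = Generator()
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # inference mode (matters for layers such as Dropout)

with torch.no_grad():
    samples = restored(torch.randn(16, 100)).view(-1, 1, 28, 28)

print(samples.shape)  # torch.Size([16, 1, 28, 28])
```

The samples are still in [-1, 1]; map them with `(samples + 1) / 2` before saving or displaying them as images.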