在本节中,我们将创建一个使用 PyTorch 的图像生成项目。我们将使用生成对抗网络(GANs)来生成图像。这是一个相对简单的示例,旨在帮助您快速上手 PyTorch 并理解其基本概念。
1. 环境准备
在开始之前,确保您已安装必要的库。可以使用以下命令来安装 PyTorch 和其他依赖库:
1
| pip install torch torchvision matplotlib
|
2. 数据集
我们将使用 CIFAR-10
数据集,该数据集包含 60,000 张 32x32 的彩色图像分为 10 个类别。我们将使用 torchvision
来下载和加载数据集。
1 2 3 4 5 6 7 8 9 10 11 12
| import torchvision import torchvision.transforms as transforms
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
|
3. 定义生成对抗网络
3.1 判别器(Discriminator)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| import torch import torch.nn as nn
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(32*32*3, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1), nn.Sigmoid() )
def forward(self, x): x = x.view(x.size(0), -1) return self.model(x)
|
3.2 生成器(Generator)
1 2 3 4 5 6 7 8 9 10 11 12
| class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.model = nn.Sequential( nn.Linear(100, 128), nn.ReLU(), nn.Linear(128, 32*32*3), nn.Tanh() )
def forward(self, z): return self.model(z).view(-1, 3, 32, 32)
|
4. 训练模型
4.1 定义训练过程
我们将实现 GAN 的训练过程,包括更新生成器和判别器。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
| import torch.optim as optim
discriminator = Discriminator() generator = Generator()
criterion = nn.BCELoss()
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
num_epochs = 50 for epoch in range(num_epochs): for i, (images, _) in enumerate(trainloader): discriminator.zero_grad() real_labels = torch.ones(images.size(0), 1) fake_labels = torch.zeros(images.size(0), 1) outputs = discriminator(images) d_loss_real = criterion(outputs, real_labels) z = torch.randn(images.size(0), 100) fake_images = generator(z) outputs = discriminator(fake_images.detach()) d_loss_fake = criterion(outputs, fake_labels)
d_loss = d_loss_real + d_loss_fake d_loss.backward() optimizer_d.step()
generator.zero_grad() outputs = discriminator(fake_images) g_loss = criterion(outputs, real_labels) g_loss.backward() optimizer_g.step()
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
|
5. 生成与可视化图像
我们将使用生成器生成图像并将其可视化。
1 2 3 4 5 6 7 8 9 10 11 12
| import matplotlib.pyplot as plt
z = torch.randn(64, 100) fake_images = generator(z).detach()
grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True) plt.imshow(grid.permute(1, 2, 0).numpy()) plt.title('Generated Images') plt.axis('off') plt.show()
|
6. 总结
在本节中,我们学习了如何使用 PyTorch 实现简单的生成对抗网络(GAN)。我们定义了判别器和生成器模型,加载了 CIFAR-10 数据集,并实现了训练过程。最后,我们使用生成器生成并可视化了一组图像。
继续学习可以探索更复杂的网络架构和技术,比如条件 GANs、生成对抗网络的改进等。