Python全系列 教程
3567个小节阅读:5931.1k
目录
鸿蒙应用开发
C语言快速入门
JAVA全系列 教程
面向对象的程序设计语言
Python全系列 教程
Python3.x版本,未来主流的版本
人工智能 教程
顺势而为,AI创新未来
大厂算法 教程
算法,程序员自我提升必经之路
C++ 教程
一门通用计算机编程语言
微服务 教程
目前业界流行的框架组合
web前端全系列 教程
通向WEB技术世界的钥匙
大数据全系列 教程
站在云端操控万千数据
AIGC全能工具班
A A
White Night
TorchGAN是基于PyTorch开发的GAN设计框架,它可以快速开发和定制GAN
安装方式:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torchgan==0.1.0
xxxxxxxxxx
import os
import matplotlib.pyplot as plt
import numpy as np
# Pytorch and Torchvision Imports
import torch
import torch.nn as nn
import torchvision
from torch.optim import Adam
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision.utils as vutils
# Torchgan Imports
import torchgan
from torchgan.models import *
from torchgan.losses import *
from torchgan.trainer import Trainer
# 防止核崩溃
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
# 数据预处理
data_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# 加载数据
trainset = dsets.ImageFolder('faces', data_transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True)
# Plot some of the training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()
# 定义GAN网络
dcgan_network = {
# 定义生成器
"generator": {
"name": DCGANGenerator, # 深度卷积生成器
# DCGANGenerator的参数设置
"args": {
"encoding_dims": 100,
"out_channels": 3,
"step_channels": 32,
"nonlinearity": nn.LeakyReLU(0.2),
"last_nonlinearity": nn.Tanh()
},
"optimizer": {
"name": Adam,
"args": {
"lr": 0.0001,
"betas": (0.5, 0.999)
}
}
},
# 鉴别器
"discriminator": {
"name": DCGANDiscriminator, # 深度卷积鉴别器
"args": {
"in_channels": 3,
"step_channels": 32,
"nonlinearity": nn.LeakyReLU(0.2),
"last_nonlinearity": nn.LeakyReLU(0.2)
},
"optimizer": {
"name": Adam,
"args": {
"lr": 0.0003,
"betas": (0.5, 0.999)
}
}
}
}
# TorchGAN支持的损失函数
wgangp_losses = [WassersteinGeneratorLoss(), WassersteinDiscriminatorLoss(), WassersteinGradientPenalty()]
# 检查设备是否支持CUDA,并设置训练的轮数
if torch.cuda.is_available():
device = torch.device("cuda:0")
# Use deterministic cudnn algorithms
torch.backends.cudnn.deterministic = True
epochs = 400
else:
device = torch.device("cpu")
epochs = 100
print("Device: {}".format(device))
print("Epochs: {}".format(epochs))
# 创建训练器对象
trainer = Trainer(dcgan_network, wgangp_losses, sample_size=64, epochs=epochs, device=device)
# 开始训练,训练后,会在代码根目录下生成一个images文件夹
trainer(dataloader)