Python全系列 教程
3567个小节阅读:5930.4k
目录
鸿蒙应用开发
C语言快速入门
JAVA全系列 教程
面向对象的程序设计语言
Python全系列 教程
Python3.x版本,未来主流的版本
人工智能 教程
顺势而为,AI创新未来
大厂算法 教程
算法,程序员自我提升必经之路
C++ 教程
一门通用计算机编程语言
微服务 教程
目前业界流行的框架组合
web前端全系列 教程
通向WEB技术世界的钥匙
大数据全系列 教程
站在云端操控万千数据
AIGC全能工具班
A A
White Night
xxxxxxxxxx
d = D(3,32) # 创建鉴别器对象
g = G(3,128,1024,100) # 创建生成器对象
criterion = nn.BCELoss() # 二分类交叉熵损失函数
d_optimizer = torch.optim.Adam(d.parameters(),lr=0.0003) # 鉴别器的优化器对象
g_optimizer = torch.optim.Adam(g.parameters(),lr=0.0003) # 生成器的优化器对象
# 训练函数
def train(d,g,criterion,d_optimizer,g_optimizer,
epochs=1,show_every=100,print_every=10):
iter_count = 0
for epoch in range(epochs): # 外层循环“轮”
for inputs,_ in trainloader: # 每批抽5张图
real_inputs = inputs # 真图样本的特征
fake_inputs = g(torch.randn(5,100)) # 生成器生成假图
real_labels = torch.ones(real_inputs.size(0)) # 真图标签记为1
fake_labels = torch.zeros(fake_inputs.size(0)) # 假图标记为0
real_outputs = d(real_inputs) # 使用鉴别器对真图进行鉴别,每张图鉴别结果为0-1之间的数
d_loss_real = criterion(real_outputs,real_labels) # 鉴别器对真图的损失值
real_scores = real_outputs
fake_outputs = d(fake_inputs) # 使用鉴别器对假图进行鉴别,每张图鉴别结果为0-1之间的数
d_loss_fake = criterion(fake_outputs,fake_labels) # 鉴别器对假图的损失值
fake_scores = fake_outputs
d_loss = d_loss_real + d_loss_fake # 鉴别器的总的损失值
d_optimizer.zero_grad() # 清空上一次的梯度
d_loss.backward() # 反向传播,计算梯度
d_optimizer.step() # 更新鉴别器参数
fake_inputs = g(torch.randn(5,100)) # 生成器生成假图
fake_outputs = d(fake_inputs)
g_loss = criterion(fake_outputs,real_labels)
g_optimizer.zero_grad() # 清空上一次的梯度
g_loss.backward() # 反向传播,计算梯度
g_optimizer.step() # 更新生成器参数
if (iter_count % show_every == 0):
print('Epoch:{},Iter: {}, D: {:.4}, G:{:.4}'.format(epoch,iter_count, d_loss.item(), g_loss.item()))
picname = "Epoch_"+str(epoch)+"Iter_"+str(iter_count)
imshow(torchvision.utils.make_grid(fake_inputs.data),picname)
if (iter_count%print_every == 0):
print('Epoch:{},Iter: {}, D: {:.4}, G:{:.4}'.format(epoch,iter_count, d_loss.item(), g_loss.item()))
iter_count += 1
print("Finish Training,OK!")
train(d,g,criterion,d_optimizer,g_optimizer,epochs=1)