网络构建
该版本的网络全部由全连接层构成，激活函数采用 LeakyReLU。
from torch import nn
class D_Net(nn.Module):
    """Fully connected discriminator.

    Maps each flattened input vector to a single sigmoid score. The stack is
    Linear -> LeakyReLU -> Linear -> LeakyReLU -> Linear -> Sigmoid.

    Args:
        in_features: Length of each flattened input vector. Defaults to 784,
            the value that used to be hard-coded.
        hidden_dims: Widths of the two hidden layers, in order. Defaults to
            ``(512, 256)``, the values that used to be hard-coded.
    """

    def __init__(self, in_features: int = 784, hidden_dims=(512, 256)):
        super().__init__()
        first, second = hidden_dims
        # The attribute name ``dnet`` and the layer order are kept so that
        # state-dict keys (``dnet.0.weight`` ...) stay compatible with
        # checkpoints saved by earlier versions of this class.
        self.dnet = nn.Sequential(
            nn.Linear(in_features, first),
            nn.LeakyReLU(),
            nn.Linear(first, second),
            nn.LeakyReLU(),
            nn.Linear(second, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return one score per row."""
        return self.dnet(x)
class G_Net(nn.Module):
    """Fully connected generator.

    Maps each latent vector to a flat output vector. The stack is
    Linear -> LeakyReLU -> Linear -> LeakyReLU -> Linear; there is no output
    activation, so the returned values are unbounded.

    Args:
        latent_dim: Length of each input noise vector. Defaults to 128, the
            value that used to be hard-coded.
        out_features: Length of each generated vector. Defaults to 784, the
            value that used to be hard-coded.
        hidden_dims: Widths of the two hidden layers, in order. Defaults to
            ``(256, 512)``, the values that used to be hard-coded.
    """

    def __init__(self, latent_dim: int = 128, out_features: int = 784,
                 hidden_dims=(256, 512)):
        super().__init__()
        first, second = hidden_dims
        # The attribute name ``gnet`` and the layer order are kept so that
        # state-dict keys (``gnet.0.weight`` ...) stay compatible with
        # checkpoints saved by earlier versions of this class.
        self.gnet = nn.Sequential(
            nn.Linear(latent_dim, first),
            nn.LeakyReLU(),
            nn.Linear(first, second),
            nn.LeakyReLU(),
            nn.Linear(second, out_features),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the generated rows."""
        return self.gnet(x)
模型训练
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
from torchvision.utils import save_image
import os
import torch
from torch import nn
from model import D_Net,G_Net
def train(batch_size=100, num_epoch=100):
    """Train the discriminator/generator pair on MNIST.

    Creates the ``img`` and ``./params`` directories if needed, resumes from
    ``./params/d_net.pth`` / ``./params/g_net.pth`` when they exist, and
    alternates one discriminator update and one generator update per batch.
    Every 10th batch it logs both losses, writes a grid of real and generated
    images to ``img/`` and overwrites both checkpoints.

    Args:
        batch_size: Number of images per training batch.
        num_epoch: Number of passes over the training set.
    """
    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("img", exist_ok=True)
    os.makedirs("./params", exist_ok=True)

    mnist_data = datasets.MNIST("/data", train=True, transform=transforms.ToTensor(), download=True)
    train_loader = DataLoader(mnist_data, batch_size, shuffle=True)

    d_net = D_Net().to(device)
    g_net = G_Net().to(device)
    # map_location lets a checkpoint written on a GPU be resumed on a CPU host.
    if os.path.exists("./params/d_net.pth"):
        d_net.load_state_dict(torch.load("./params/d_net.pth", map_location=device))
    if os.path.exists("./params/g_net.pth"):
        g_net.load_state_dict(torch.load("./params/g_net.pth", map_location=device))

    loss_fun = nn.BCELoss()
    d_opt = torch.optim.Adam(d_net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_opt = torch.optim.Adam(g_net.parameters(), lr=0.0002, betas=(0.5, 0.999))

    k = 0  # running index used in the saved image file names
    for epoch in range(num_epoch):
        for i, (img, _) in enumerate(train_loader):
            n = img.size(0)  # the last batch may be smaller than batch_size
            real_img = img.reshape(-1, 784).to(device)

            # Target 1 for real samples, target 0 for generated samples.
            real_label = torch.ones(n, 1, device=device)
            fake_label = torch.zeros(n, 1, device=device)

            # ---- Discriminator step ----
            real_out = d_net(real_img)
            d_loss_real = loss_fun(real_out, real_label)

            z = torch.randn(n, 128, device=device)
            # detach(): the discriminator update must not backpropagate into
            # the generator; without it every step also computed (and then
            # discarded) generator gradients.
            fake_img = g_net(z).detach()
            fake_out = d_net(fake_img)
            d_loss_fake = loss_fun(fake_out, fake_label)

            d_loss = d_loss_real + d_loss_fake
            d_opt.zero_grad()
            d_loss.backward()
            d_opt.step()

            # ---- Generator step ----
            # Fresh noise; the generator is rewarded when the discriminator
            # scores its output as real, hence the real_label target.
            z = torch.randn(n, 128, device=device)
            fake_img = g_net(z)
            g_fake_out = d_net(fake_img)
            g_loss = loss_fun(g_fake_out, real_label)
            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()

            # NOTE(review): the original indentation was lost; everything from
            # the log line through the counter increment is assumed to belong
            # to this periodic branch.
            if i % 10 == 0:
                # .item() logs plain numbers rather than tensor reprs.
                print("Epoch:{0},d_loss{1},g_loss{2}".format(epoch, d_loss.item(), g_loss.item()))
                real_grid = real_img.reshape(-1, 1, 28, 28)
                fake_grid = fake_img.detach().reshape(-1, 1, 28, 28)
                save_image(real_grid, "img/{}-real_img.jpg".format(k), nrow=10, normalize=True, scale_each=True)
                save_image(fake_grid, "img/{}-fake_img.jpg".format(k), nrow=10, normalize=True, scale_each=True)
                torch.save(d_net.state_dict(), "./params/d_net.pth")
                torch.save(g_net.state_dict(), "./params/g_net.pth")
                k += 1


if __name__ == '__main__':
    train(batch_size=100, num_epoch=100)
模型运行
from torchvision.utils import save_image
import os
import torch
from model import G_Net
def generate(batch_size=100, num_epoch=10):
    """Sample images from the generator and write them to ``test_img/``.

    Creates the ``test_img`` and ``./params`` directories if needed, loads
    ``./params/g_net.pth`` when it exists, then writes ``num_epoch`` image
    grids, each built from ``batch_size`` random latent vectors.

    Args:
        batch_size: Number of images in each saved grid.
        num_epoch: Number of grids to generate (no training happens here; the
            name is kept from the original script).
    """
    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("test_img", exist_ok=True)
    os.makedirs("./params", exist_ok=True)

    g_net = G_Net().to(device)
    if os.path.exists("./params/g_net.pth"):
        # map_location lets a checkpoint written on a GPU load on a CPU host.
        g_net.load_state_dict(torch.load("./params/g_net.pth", map_location=device))
    else:
        # Still run (best effort), but make it obvious why the output is noise.
        print("warning: ./params/g_net.pth not found, sampling from an untrained generator")

    g_net.eval()
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(num_epoch):
            z = torch.randn(batch_size, 128, device=device)
            fake_img = g_net(z).reshape(-1, 1, 28, 28)
            save_image(fake_img, "test_img/{}-fake_img.jpg".format(i), nrow=10, normalize=True, scale_each=True)
            print(i)


if __name__ == '__main__':
    generate(batch_size=100, num_epoch=10)