import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('./logs')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
# normalize pixels to [-1, 1] so they match the generator's tanh output range
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])
# download=True fetches MNIST on the first run; later runs reuse ./data
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# drop_last discards the final incomplete batch so every batch holds exactly batch_size
# samples (the training functions below hard-code batch_size in their label tensors)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))
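
# Generator above: 100-dim noise widens through 256 -> 512 -> 1024 units to
# 784 outputs (one per MNIST pixel); tanh bounds them to [-1, 1], the same
# range the Normalize((0.5,), (0.5,)) transform maps real images into.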
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        # pass training=self.training so dropout is switched off by model.eval();
        # functional dropout otherwise defaults to training=True unconditionally
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))
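
# Discriminator above: mirrors the generator, narrowing 784 pixels through
# 1024 -> 512 -> 256 units to a single sigmoid output, the estimated
# probability that the input image is real; dropout regularizes each layer.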
z_dim = 100
# .data replaces the deprecated .train_data attribute; for MNIST this is 28 * 28 = 784
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

# build network
G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

# add the network graphs to TensorBoard (the dummy inputs must live on the same device as the models)
writer.add_graph(G, input_to_model=torch.randn(batch_size, z_dim).to(device))
writer.add_graph(D, input_to_model=torch.randn(batch_size, mnist_dim).to(device))
# optimizer
lr = 0.0002
g_optimizer = optim.Adam(G.parameters(), lr=lr)
d_optimizer = optim.Adam(D.parameters(), lr=lr)
# loss
criterion = nn.BCELoss()
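
# Since D already ends in a sigmoid, plain BCELoss is the matching criterion.
# A common alternative (not used here) is to drop the sigmoid from D and use
# nn.BCEWithLogitsLoss instead, which is more numerically stable.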
def d_train(x):
    D.zero_grad()
    # real samples get label 1
    x_real, y_real = x.view(-1, mnist_dim).to(device), torch.ones(batch_size, 1).to(device)
    d_output = D(x_real)
    d_real_loss = criterion(d_output, y_real)
    d_real_score = d_output
    # generated samples get label 0; detach() keeps gradients out of G during D's update
    z = torch.randn(batch_size, z_dim).to(device)
    x_fake, y_fake = G(z).detach(), torch.zeros(batch_size, 1).to(device)
    d_output = D(x_fake)
    d_fake_loss = criterion(d_output, y_fake)
    d_fake_score = d_output
    d_loss = d_real_loss + d_fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()
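
# d_train implements the standard discriminator objective
#     max  E[log D(x)] + E[log(1 - D(G(z)))]
# as the sum of two BCE terms: real images scored against label 1 and
# generated images scored against label 0.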
def g_train():
    # the original signature took the real batch x but never used it, so it is dropped
    G.zero_grad()
    # the generator is rewarded when D labels its samples as real (1)
    z = torch.randn(batch_size, z_dim).to(device)
    y = torch.ones(batch_size, 1).to(device)
    g_output = G(z)
    d_output = D(g_output)
    g_loss = criterion(d_output, y)
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
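
# g_train uses the non-saturating generator loss, minimizing -log D(G(z))
# (BCE against label 1) rather than maximizing log(1 - D(G(z))), which gives
# stronger gradients early in training when D easily rejects the samples.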
# output directories for checkpoints and the final models
os.makedirs('./checkpoint', exist_ok=True)
os.makedirs('./model', exist_ok=True)

epochs = 10
step = 0
for epoch in range(epochs):
    d_losses, g_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        step += 1
        d_losses.append(d_train(x))
        g_losses.append(g_train())
        print('[%d/%d]: [%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, epochs, batch_idx, len(train_loader),
            torch.mean(torch.FloatTensor(d_losses)), torch.mean(torch.FloatTensor(g_losses))))
        writer.add_scalar('g_loss', torch.mean(torch.FloatTensor(g_losses)), step)
        writer.add_scalar('d_loss', torch.mean(torch.FloatTensor(d_losses)), step)
        if batch_idx % 10 == 0:
            with torch.no_grad():
                test_z = torch.randn(batch_size, z_dim).to(device)
                generated = G(test_z)
                # map tanh output from [-1, 1] back to [0, 1] before logging
                generated = (generated + 1) / 2
                img = torchvision.utils.make_grid(generated.view(generated.size(0), 1, 28, 28))
                writer.add_image(f'mnist_{epoch}_{batch_idx}', img, global_step=step)
    # with epochs = 10 this fires only at epoch 0; widen the interval for longer runs
    if epoch % 10 == 0:
        D.eval()
        G.eval()
        torch.save({
            'epoch': epoch,
            'd_model_state_dict': D.state_dict(),
            'g_model_state_dict': G.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
            'd_loss': d_losses,
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'g_loss': g_losses,
        }, f'./checkpoint/epoch{epoch}_weight.pth')
        D.train()
        G.train()
writer.close()
torch.save(D, './model/discriminator.pt')
torch.save(G, './model/generator.pt')
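
# A minimal sketch of resuming from a saved checkpoint; the path assumes the
# epoch-0 file written by the loop above. Uncomment to use.
# checkpoint = torch.load('./checkpoint/epoch0_weight.pth', map_location=device)
# D.load_state_dict(checkpoint['d_model_state_dict'])
# G.load_state_dict(checkpoint['g_model_state_dict'])
# d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
# g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
# start_epoch = checkpoint['epoch'] + 1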