main_cnn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
# Ref: https://github.com/lyeoni/pytorch-mnist-VAE/blob/master/pytorch-mnist-VAE.ipynb
# Ref: https://towardsdatascience.com/building-a-convolutional-vae-in-pytorch-a0f54c947f71
class VAE(nn.Module):
    def __init__(self, imgChannels=1, featureDim=32*20*20, zDim=256):
        super(VAE, self).__init__()
        # Two convolutional layers and two fully-connected layers for the encoder.
        # With 28x28 MNIST inputs, the two 5x5 convolutions (no padding) produce 24x24 and
        # then 20x20 feature maps, hence featureDim = 32 * 20 * 20.
        self.encConv1 = nn.Conv2d(imgChannels, 16, 5)
        self.encConv2 = nn.Conv2d(16, 32, 5)
        self.encFC1 = nn.Linear(featureDim, zDim)
        self.encFC2 = nn.Linear(featureDim, zDim)
        # One fully-connected layer and two transposed convolutional layers for the decoder
        self.decFC1 = nn.Linear(zDim, featureDim)
        self.decConv1 = nn.ConvTranspose2d(32, 16, 5)
        self.decConv2 = nn.ConvTranspose2d(16, imgChannels, 5)
    def encoder(self, x):
        # The input is passed through the two convolutional layers sequentially.
        # The resulting feature maps are flattened and fed into two fully-connected layers
        # that predict the mean (mu) and log-variance (logVar) of the approximate posterior.
        # mu and logVar are used to sample the latent representation z and to compute the KL divergence loss.
        x = F.relu(self.encConv1(x))
        x = F.relu(self.encConv2(x))
        x = x.view(-1, 32 * 20 * 20)
        mu = self.encFC1(x)
        logVar = self.encFC2(x)
        return mu, logVar
    def sampling(self, mu, log_var):
        # Reparameterization trick: sample z = mu + std * eps with eps ~ N(0, I)
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + std * eps
    def decoder(self, z):
        # z is fed into a fully-connected layer and then into two transposed convolutional layers.
        # The generated output has the same size as the original input.
        x = F.relu(self.decFC1(z))
        x = x.view(-1, 32, 20, 20)
        x = F.relu(self.decConv1(x))
        x = torch.sigmoid(self.decConv2(x))
        return x
    def forward(self, x):
        # The entire VAE pipeline: encoder -> reparameterization -> decoder.
        # The reconstruction, mu, and logVar are returned for the loss computation.
        mu, log_var = self.encoder(x)
        z = self.sampling(mu, log_var)
        out = self.decoder(z)
        return out, mu, log_var
def loss_function(recon_x, x, mu, log_var):
    # Reconstruction term (pixel-wise binary cross-entropy) plus the closed-form KL divergence
    # between the approximate posterior N(mu, sigma^2) and the standard normal prior:
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD
def train(epoch, vae, train_loader, optimizer, device):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))
def test(vae, test_loader, device):
    vae.eval()
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = vae(data)
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
def main():
    print('this is hw3')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    bs = 100
    # MNIST Dataset
    train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)
    # Data Loader (Input Pipeline)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)
    vae = VAE().to(device)
    optimizer = optim.Adam(vae.parameters())
    # Training
    for epoch in range(1, 2):
        train(epoch, vae, train_loader, optimizer, device)
        test(vae, test_loader, device)
    # Visualize one test image and its reconstruction
    images, labels = next(iter(test_loader))
    sample_pic = images[0]
    print(sample_pic.shape)
    plt.imshow(sample_pic[0].reshape(28, 28), cmap="gray")
    plt.show()
    with torch.no_grad():
        sample_pic = sample_pic.unsqueeze(0).to(device)
        print(sample_pic.shape)
        recon, _, _ = vae(sample_pic)
        print('Got result')
        recon = recon.cpu()
    plt.imshow(recon.reshape(28, 28), cmap="gray")
    plt.show()
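# Not part of the original script: a minimal sketch of how the trained decoder could be used
# to generate new digits by decoding latent vectors drawn from the standard normal prior.
# The helper name generate_from_prior and its defaults are illustrative assumptions, not
# something defined or called elsewhere in this file; z_dim must match the zDim the VAE was built with.
def generate_from_prior(vae, device, n_samples=8, z_dim=256):
    vae.eval()
    with torch.no_grad():
        # Sample z ~ N(0, I) and decode it into images of shape (n_samples, 1, 28, 28)
        z = torch.randn(n_samples, z_dim, device=device)
        samples = vae.decoder(z).cpu()
    for i in range(n_samples):
        plt.imshow(samples[i, 0], cmap="gray")
        plt.show()
    return samples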
if __name__ == '__main__':
    main()