-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalculate_scores.py
100 lines (80 loc) · 3.11 KB
/
calculate_scores.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import sys
import torch
from ignite.engine import Engine
from ignite.metrics import FID, InceptionScore
from torchvision import transforms, datasets
from dataset_utils import gain_sample
from config import DATASETS_CONFIG
from model import Generator
from train_stylegan_model import DEVICE, INPUT_DIM, LATENT_DIM, MAPPING_LAYER_NUM, STEP, MINI_BATCH_SIZE
"""
This script evaluates a trained generator on a given dataset and reports its FID and Inception Score.
"""
# Build the generator on DEVICE and restore trained weights when available.
generator = Generator(MAPPING_LAYER_NUM, LATENT_DIM, INPUT_DIM).to(DEVICE)
CHECKPOINT_PATH = 'checkpoint/trained.pth'
if os.path.exists(CHECKPOINT_PATH):
    # Load data from last checkpoint
    print('Loading pre-trained model...')
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint
    # saved on a GPU can still be loaded on a CPU-only machine (and vice
    # versa). torch.load unpickles the file: only load trusted checkpoints.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    generator.load_state_dict(checkpoint['generator'])
else:
    # Scores for a randomly initialised generator are rarely what the caller
    # wants, so make this fallback visible instead of silently continuing.
    print(f'Warning: {CHECKPOINT_PATH} not found; '
          'evaluating an untrained generator.', file=sys.stderr)
# Per-image resize pipeline used by interpolate(): tensor -> PIL image ->
# resized to 299x299 -> tensor again.
# NOTE(review): 299x299 is presumably the input size of the Inception network
# behind ignite's FID / InceptionScore; also confirm the incoming tensors are
# in the value range ToPILImage expects for float input.
resize_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(299, 299)),
        transforms.ToTensor(),
    ]
)
def interpolate(image_batch, transform=None):
    """Apply a per-image transform to every image in a batch and re-stack them.

    Used to bring both generated and real images into the format fed to the
    ignite FID / InceptionScore metrics.

    Args:
        image_batch: an iterable of image tensors (iterating a batched tensor
            yields one image per element of its first dimension).
        transform: callable applied to each image. Defaults to the module-level
            ``resize_transform`` when omitted (the original behaviour).

    Returns:
        A tensor stacking the transformed images along a new first dimension.
    """
    if transform is None:
        transform = resize_transform
    # Detach and move to CPU first: the default transform goes through PIL,
    # which cannot consume CUDA tensors in older torchvision releases.
    return torch.stack([transform(image.detach().cpu()) for image in image_batch])
def evaluation_step(data_batch):
    """Generate a batch of fake images and pair it with the batch of real ones.

    Args:
        data_batch: an ``(images, labels)`` pair from the data loader; the
            labels are ignored and ``images.size(0)`` is the batch size.

    Returns:
        ``(generated_images, real_images)``, both passed through
        ``interpolate`` so the metrics receive images in the same format.
    """
    image_batch, _ = data_batch
    # Generate exactly as many fakes as there are real images in THIS batch:
    # the last batch of an epoch can be smaller than MINI_BATCH_SIZE, and
    # ignite's FID rejects updates whose two feature sets differ in size.
    batch_size = image_batch.size(0)
    alpha = 0  # not implemented, only for style mixing purpose
    # NOTE(review): confirm how Generator interprets alpha; in progressive-
    # growing generators it often blends resolutions rather than styles.
    with torch.no_grad():
        # One noise map per synthesis step; because of the upsampling the
        # side length doubles at every step (4, 8, 16, ...).
        noise_sample = [
            torch.randn((batch_size, 1, 4 * 2 ** m, 4 * 2 ** m), device=DEVICE)
            for m in range(STEP + 1)
        ]
        latent_sample = [torch.randn((batch_size, LATENT_DIM), device=DEVICE)]
        gen_img = generator(latent_sample, STEP, alpha, noise_sample)
        generated_images = interpolate(gen_img)
        real_images = interpolate(image_batch)
    return generated_images, real_images
if __name__ == '__main__':
    if len(sys.argv) != 2:
        print("Usage: python calculate_scores.py dataset_name {flowers}")
        sys.exit(1)
    dataset_name = sys.argv[1]
    # Report an unknown dataset clearly instead of crashing with a KeyError.
    if dataset_name not in DATASETS_CONFIG:
        print("Unknown dataset name:", dataset_name)
        print("Available datasets:", ", ".join(map(str, DATASETS_CONFIG)))
        sys.exit(1)
    dataset_config = DATASETS_CONFIG[dataset_name]
    print("Dataset name:", dataset_name)

    fid = FID(device=DEVICE)
    # Inception Score only looks at the generated images, i.e. the first
    # element of the (generated, real) pair returned by evaluation_step.
    inception = InceptionScore(device=DEVICE, output_transform=lambda x: x[0])
    # ignite calls the process function as process_function(engine, batch);
    # evaluation_step takes only the batch, so adapt the signature here.
    evaluator = Engine(lambda engine, batch: evaluation_step(batch))
    fid.attach(evaluator, "fid")
    inception.attach(evaluator, "inception")

    image_folder_path = dataset_config['image_folder_path']
    dataset = datasets.ImageFolder(image_folder_path)
    # Real images are sampled at the generator's output resolution for STEP.
    resolution = 4 * 2 ** STEP
    origin_loader = gain_sample(dataset, MINI_BATCH_SIZE, resolution)

    # The generator was already moved to DEVICE when it was built.
    generator.eval()
    evaluator.run(origin_loader, max_epochs=1)

    metrics = evaluator.state.metrics
    fid_score = metrics['fid']
    is_score = metrics['inception']
    print(fid_score, is_score)