-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathinference.py
123 lines (93 loc) · 3.62 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from sd_pipeline import GuidedSDPipeline
import torch
import numpy as np
from PIL import Image
import PIL
from typing import Callable, List, Optional, Union
from dataset import CustomCIFAR10Dataset, CustomLatentDataset
from vae import encode
import os
import wandb
import argparse
def parse(argv=None):
    """Parse command-line options for guided-diffusion inference.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``, so existing callers
            that invoke ``parse()`` with no arguments are unaffected.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--device", default="cuda")
    # Target reward value that the guided sampler steers generations toward.
    parser.add_argument("--target", type=float, default=0.)
    # Guidance strength; 0 effectively disables reward guidance.
    parser.add_argument("--guidance", type=float, default=0.)
    parser.add_argument("--prompt", type=str, default="a nice photo")
    # Empty string means "derive output dir from target/guidance" (see caller).
    parser.add_argument("--out_dir", type=str, default="")
    parser.add_argument("--num_images", type=int, default=4)
    parser.add_argument("--bs", type=int, default=4)
    # Seeds <= 0 mean "unseeded": the pipeline draws fresh random latents.
    parser.add_argument("--seed", type=int, default=-1)
    return parser.parse_args(argv)
######### preparation ##########
args = parse()
device = args.device
save_file = True  # dump the generated images to disk at the end of the run
reward_model_file = 'convnet.pth'  # surrogate reward model used for guidance

## Image Seeds
if args.seed > 0:
    # Fixed seed: pre-sample every batch's initial latents up front so the
    # run is reproducible. Shape is (num_batches, batch_size, 4, 64, 64) —
    # presumably the SD v1.5 latent resolution; confirm against the pipeline.
    torch.manual_seed(args.seed)
    shape = (args.num_images // args.bs, args.bs, 4, 64, 64)
    init_latents = torch.randn(shape, device=device)
else:
    # Unseeded: let the pipeline draw fresh random latents for each batch.
    init_latents = None

if args.out_dir == "":
    args.out_dir = f'imgs/target{args.target}guidance{args.guidance}'
# exist_ok=True replaces the previous bare try/except, which silently
# swallowed *every* makedirs failure (e.g. permissions), not just
# "directory already exists".
os.makedirs(args.out_dir, exist_ok=True)

wandb.init(project="guided_dm", config={
    'target': args.target,
    'guidance': args.guidance,
    'prompt': args.prompt,
    'num_images': args.num_images
})
# Load the guided SD pipeline from a local snapshot (no network access).
sd_model = GuidedSDPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", local_files_only=True)
sd_model.to(device)

# map_location prevents a hard failure when the checkpoint was serialized on
# a device that is unavailable here (e.g. saved on cuda, run on cpu).
reward_model = torch.load(reward_model_file, map_location=device)
reward_model.eval()

# Wire the surrogate reward model into the sampler's guidance loop.
sd_model.setup_reward_model(reward_model)
sd_model.set_target(args.target)
sd_model.set_guidance(args.guidance)

# Generate images batch by batch. Note: any remainder (num_images % bs) is
# dropped, matching how init_latents was pre-shaped above.
image = []
for i in range(args.num_images // args.bs):
    init_i = None if init_latents is None else init_latents[i]
    image_ = sd_model(args.prompt, num_images_per_prompt=args.bs, latents=init_i).images  # list of PIL.Image
    image.extend(image_)
###### evaluation and metric #####
# Two views of the same generated images: one fed to the ground-truth reward
# model as-is, one encoded to VAE latents for the surrogate reward model.
gt_dataset = CustomCIFAR10Dataset(image)
gt_dataloader = torch.utils.data.DataLoader(gt_dataset, batch_size=20, shuffle=False, num_workers=8)
pred_dataset = CustomLatentDataset(image)
pred_dataloader = torch.utils.data.DataLoader(pred_dataset, batch_size=20, shuffle=False, num_workers=8)

# map_location avoids a device-mismatch failure when the checkpoint was
# saved on a device that is not available at inference time.
ground_truth_reward_model = torch.load('reward_model.pth', map_location=device)
ground_truth_reward_model.eval()

with torch.no_grad():
    # Ground-truth rewards, accumulated batch by batch then flattened.
    total_reward_gt = []
    for inputs in gt_dataloader:
        inputs = inputs.to(device)
        gt_rewards = ground_truth_reward_model(inputs)
        total_reward_gt.append(gt_rewards.cpu().numpy())
    total_reward_gt = np.concatenate(total_reward_gt, axis=None)
    wandb.log({"gt_reward_mean": np.mean(total_reward_gt),
               "gt_reward_std": np.std(total_reward_gt)})

with torch.no_grad():
    # Surrogate (guidance) rewards, computed in VAE latent space.
    total_reward_pred = []
    for inputs in pred_dataloader:
        inputs = inputs.to(device)
        inputs = encode(inputs)  # surrogate model scores latents, not pixels
        pred_rewards = reward_model(inputs)
        total_reward_pred.append(pred_rewards.cpu().numpy())
    total_reward_pred = np.concatenate(total_reward_pred, axis=None)
    wandb.log({"pred_reward_mean": np.mean(total_reward_pred),
               "pred_reward_std": np.std(total_reward_pred)})

if save_file:
    # Filename embeds both reward scores so images can be sorted/inspected
    # by eye. os.path.join replaces manual '/' concatenation.
    for idx, im in enumerate(image):
        im.save(os.path.join(
            args.out_dir,
            f'{idx}_gt_{total_reward_gt[idx]:.4f}_pred_{total_reward_pred[idx]:.4f}.png'))