-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
38 lines (34 loc) · 919 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torchvision
from dataset_class import CGAN_dataset
from torch.utils.data import DataLoader
from torch import Tensor
def LSGAN_D(real, fake):
    """Least-squares GAN loss for the discriminator.

    Computes ``mean((real - 1)**2) + mean(fake**2)``, i.e. scores in ``real``
    are pulled toward 1 and scores in ``fake`` toward 0. Unlike some LSGAN
    formulations, no 0.5 factor is applied.

    Args:
        real: Tensor of scores for real samples (presumably discriminator
            outputs — confirm against the caller).
        fake: Tensor of scores for generated samples.

    Returns:
        0-dim tensor holding the summed loss.
    """
    real_term = torch.mean((real - 1) ** 2)
    fake_term = torch.mean(fake ** 2)
    return real_term + fake_term
def LSGAN_G(fake):
    """Least-squares GAN loss for the generator.

    Returns ``mean((fake - 1)**2)``, which pulls the scores of generated
    samples toward 1.

    Args:
        fake: Tensor of scores for generated samples.

    Returns:
        0-dim tensor holding the loss.
    """
    deviation = fake - 1
    return torch.mean(deviation ** 2)
def PSNR(img1, img2, range_value):
    """Return the peak signal-to-noise ratio between two images, in decibels.

    Computes ``20 * log10(range_value / sqrt(MSE))``, where MSE is the mean
    squared difference over every element of ``img1 - img2``.

    Args:
        img1: Image tensor.
        img2: Image tensor, broadcast-compatible with ``img1``.
        range_value: Peak possible signal value. Examples are 1.0 for data in
            [0, 1] and 255 for 8-bit data. The caller must pass the range that
            matches the data.

    Returns:
        0-dim tensor with the PSNR in dB. Identical images give MSE == 0,
        so the result is ``inf``.
    """
    # Subtracting integer tensors (e.g. uint8 images) wraps around modulo
    # the dtype range, and torch.mean rejects integer dtypes outright.
    # Promote such inputs to float. Floating inputs are left as-is so their
    # precision (e.g. float64) is preserved.
    if not torch.is_floating_point(img1):
        img1 = img1.float()
    if not torch.is_floating_point(img2):
        img2 = img2.float()
    mse = torch.mean((img1 - img2) ** 2)
    return 20 * torch.log10(range_value / torch.sqrt(mse))
def get_loaders(
    get_dir,
    batch_size,
    img_transform,
    data_shuffle,
):
    """Wrap a ``CGAN_dataset`` rooted at ``get_dir`` in a ``DataLoader``.

    Args:
        get_dir: Directory passed to ``CGAN_dataset`` as ``img_dir``.
        batch_size: Number of samples per batch.
        img_transform: Transform passed to ``CGAN_dataset`` as ``transform``.
        data_shuffle: Whether the loader reshuffles the data (``shuffle``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset.
    """
    return DataLoader(
        CGAN_dataset(img_dir=get_dir, transform=img_transform),
        batch_size=batch_size,
        shuffle=data_shuffle,
    )