-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtemp_disc.py
71 lines (62 loc) · 2.74 KB
/
temp_disc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import functools
from utils import Attention, DBlock, snconv3d, snlinear, Conv3_1d, MDmin
class Discriminator(nn.Module):
    """3-D BigGAN-style discriminator built from stacked ``DBlock`` stages.

    Pipeline:
    - an input ``snconv3d``;
    - five ``DBlock`` stages, each optionally followed by an ``Attention``
      layer;
    - a ReLU;
    - global sum pooling over the three trailing axes;
    - unless ``params.triplet`` is set, a final ``snlinear`` with one output
      unit.

    Args:
        params: Configuration object. Attributes read here:
            ``filterD``: base channel multiplier for every stage.
            ``triplet``: when truthy, the input conv takes 1 channel and no
                linear head is built, so ``forward`` returns pooled features.
                Otherwise the input conv takes 3 channels.
    """

    def __init__(self, params):
        super().__init__()
        self.p = params
        width = self.p.filterD
        # Resolutions at which an Attention layer follows the DBlock
        # (equivalent to the BigGAN-style spec string '16').
        attention_resolutions = {16}
        # Architecture
        self.arch = {
            'in_channels': [m * width for m in [1, 2, 4, 8, 16]],
            'out_channels': [m * width for m in [2, 4, 8, 16, 16]],
            # These two lists are one entry longer than the channel lists.
            # Only the first len(out_channels) entries are consumed below.
            'downsample': [True] * 5 + [False],
            'resolution': [64, 32, 16, 8, 4, 4],
            'attention': {2 ** i: 2 ** i in attention_resolutions
                          for i in range(2, 8)},
        }

        # Prepare model
        input_channels = 1 if self.p.triplet else 3
        self.input_conv = snconv3d(input_channels, self.arch['in_channels'][0])

        # zip() stops at the shortest list, i.e. one stage per out_channels entry.
        stage_specs = zip(self.arch['in_channels'],
                          self.arch['out_channels'],
                          self.arch['downsample'],
                          self.arch['resolution'])
        stages = []
        for stage_in, stage_out, do_downsample, resolution in stage_specs:
            layers = [DBlock(in_channels=stage_in,
                             out_channels=stage_out,
                             preactivation=True,
                             downsample=nn.AvgPool3d(2) if do_downsample else None)]
            if self.arch['attention'][resolution]:
                layers.append(Attention(stage_out))
            stages.append(nn.ModuleList(layers))
        # NOTE: a minibatch-discrimination variant was stubbed out behind
        # `if False:` and never ran. It applied MDmin before the penultimate
        # stage, which took in_channels + 1 there. It has been removed.
        self.blocks = nn.ModuleList(stages)

        if not self.p.triplet:
            self.linear = snlinear(self.arch['out_channels'][-1], 1)
        self.activation = nn.ReLU(inplace=True)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise every Conv3d/Linear weight and count params.

        Sets ``self.param_count`` to the total number of parameters (weights
        and biases) owned by those Conv3d/Linear modules, then prints it.
        """
        self.param_count = 0
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.Linear)):
                init.orthogonal_(module.weight)
                self.param_count += sum(p.numel() for p in module.parameters())
        # The original literal 'D''s' was two adjacent strings and printed "Ds".
        print(f"Param count for D's initialized parameters: {self.param_count}")

    def forward(self, x):
        """Run the discriminator.

        Args:
            x: Input batch. The pooling below sums over dims 2, 3 and 4, so
                ``x`` must be at least 5-D. Presumably (N, C, D, H, W), with
                C = 1 in triplet mode and 3 otherwise (TODO: confirm).

        Returns:
            The output of ``self.linear`` applied to the pooled features.
            In triplet mode, the pooled features themselves.
        """
        # Run input conv
        h = self.input_conv(x)
        # Each stage is a DBlock, optionally followed by Attention.
        for stage in self.blocks:
            for layer in stage:
                h = layer(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3, 4])
        if not self.p.triplet:
            h = self.linear(h)
        return h