'''
symbol_net1.py

Reproduces the paper:
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
import mxnet as mx


def residual_unit(data, data_prev, num_filter, stride, dim_match, name, bottle_neck=True, bn_mom=0.9, workspace=512):
    """Return a ResNet unit symbol for building the network

    Parameters
    ----------
    data : mx.sym.Symbol
        Input data
    data_prev : mx.sym.Symbol
        Output of the previous unit's last convolution, used as the shortcut
        when dim_match is True
    num_filter : int
        Number of output channels
    stride : tuple
        Stride used in convolution
    dim_match : bool
        True if input and output have the same number of channels, False otherwise
    name : str
        Base name of the operators
    bottle_neck : bool
        Whether to use the bottleneck (1x1 -> 3x3 -> 1x1) structure; this
        implementation supports only the bottleneck variant
    bn_mom : float
        Momentum of the batch normalization moving statistics
    workspace : int
        Workspace used in convolution operator
    """
    if bottle_neck:
        # Same as https://github.com/facebook/fb.resnet.torch#notes;
        # slightly different from the original paper.
        bn1 = mx.sym.BatchNorm(data=data, fix_gamma=False,
                               eps=2e-5, momentum=bn_mom, name=name + '_bn1')
        act1 = mx.sym.Activation(
            data=bn1, act_type='relu', name=name + '_relu1')
        conv1 = mx.sym.Convolution(data=act1, num_filter=int(num_filter*0.25),
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   no_bias=True, workspace=workspace, name=name + '_conv1')
        bn2 = mx.sym.BatchNorm(data=conv1, fix_gamma=False,
                               eps=2e-5, momentum=bn_mom, name=name + '_bn2')
        act2 = mx.sym.Activation(
            data=bn2, act_type='relu', name=name + '_relu2')
        conv2 = mx.sym.Convolution(data=act2, num_filter=int(num_filter*0.25),
                                   kernel=(3, 3), stride=stride, pad=(1, 1),
                                   no_bias=True, workspace=workspace, name=name + '_conv2')
        bn3 = mx.sym.BatchNorm(data=conv2, fix_gamma=False,
                               eps=2e-5, momentum=bn_mom, name=name + '_bn3')
        act3 = mx.sym.Activation(
            data=bn3, act_type='relu', name=name + '_relu3')
        conv3 = mx.sym.Convolution(data=act3, num_filter=int(num_filter*0.5),
                                   kernel=(1, 1), stride=(1, 1), pad=(0, 0),
                                   no_bias=True, workspace=workspace, name=name + '_conv3')
        if dim_match:
            shortcut = data_prev
        else:
            shortcut = mx.sym.Convolution(data=data, num_filter=int(num_filter*0.5),
                                          kernel=(1, 1), stride=stride, no_bias=True,
                                          workspace=workspace, name=name+'_sc')
        # Unlike the standard ResNet, this variant concatenates the residual
        # branch and the shortcut (num_filter/2 channels each) instead of
        # adding them, and also returns conv3 so the next unit can use it
        # as its shortcut input.
        return mx.sym.Concat(conv3, shortcut), conv3
    else:
        raise ValueError("only the bottleneck structure is supported")
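
# Worked example of the channel arithmetic above (derived from the code, with
# an illustrative num_filter): for num_filter = 64 the residual branch is
# 1x1 (16 ch) -> 3x3 (16 ch) -> 1x1 (32 ch); the shortcut also carries 32
# channels, so the Concat output has the full 64 channels.
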
def net1(units, num_stage, filter_list, num_class, bottle_neck=True, bn_mom=0.9, workspace=512):
    """Return a ResNet symbol for CIFAR-10 and ImageNet

    Parameters
    ----------
    units : list
        Number of units in each stage
    num_stage : int
        Number of stages
    filter_list : list
        Channel size of each stage; the first entry is used by the stem
        convolution, so the list has num_stage + 1 entries
    num_class : int
        Output size of the symbol
    bottle_neck : bool
        Whether to use the bottleneck structure in each unit
    bn_mom : float
        Momentum of the batch normalization moving statistics
    workspace : int
        Workspace used in convolution operator
    """
    num_unit = len(units)
    assert num_unit == num_stage
    data = mx.sym.Variable(name='data')
    data = mx.sym.BatchNorm(data=data, fix_gamma=True,
                            eps=2e-5, momentum=bn_mom, name='bn_data')
    body = mx.sym.Convolution(data=data, num_filter=filter_list[0],
                              kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                              no_bias=True, name="conv0", workspace=workspace)
    body_prev = None
    for i in range(num_stage):
        # The first unit of each stage changes the channel count and, except
        # in the first stage, halves the spatial resolution.
        stride = (1, 1) if i == 0 else (2, 2)
        body, body_prev = residual_unit(body, body_prev, filter_list[i+1], stride, False,
                                        name='stage%d_unit%d' % (i + 1, 1), bottle_neck=bottle_neck,
                                        bn_mom=bn_mom, workspace=workspace)
        for j in range(units[i]-1):
            body, body_prev = residual_unit(body, body_prev, filter_list[i+1], (1, 1), True,
                                            name='stage%d_unit%d' % (i + 1, j + 2),
                                            bottle_neck=bottle_neck, bn_mom=bn_mom, workspace=workspace)
    bn1 = mx.sym.BatchNorm(data=body, fix_gamma=False,
                           eps=2e-5, momentum=bn_mom, name='bn1')
    relu1 = mx.sym.Activation(data=bn1, act_type='relu', name='relu1')
    # The kernel size is ignored when global_pool=True, but the operator
    # still requires one to be specified.
    pool1 = mx.sym.Pooling(data=relu1, global_pool=True, kernel=(7, 7),
                           pool_type='avg', name='pool1')
    flat = mx.sym.Flatten(data=pool1)
    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_class, name='fc1')
    return mx.sym.SoftmaxOutput(data=fc1, name='softmax')
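

# A minimal usage sketch (hypothetical configuration, not from the original
# repository): the units/filter_list values below are illustrative assumptions
# chosen to build a small CIFAR-10-style network and check its output shape.
if __name__ == '__main__':
    sym = net1(units=[3, 3, 3], num_stage=3,
               filter_list=[16, 64, 128, 256], num_class=10)
    # Infer shapes for a batch of one 32x32 RGB image; softmax_label is the
    # label variable implicitly created by SoftmaxOutput.
    _, out_shapes, _ = sym.infer_shape(data=(1, 3, 32, 32),
                                       softmax_label=(1,))
    print(out_shapes)  # expected: [(1, 10)]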