-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_bradeepv3_ce.py
150 lines (134 loc) · 4.9 KB
/
main_bradeepv3_ce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#! /usr/bin/python3
from common_header import *
import branchy_seg_losses as BSL
from my_pixelwise_xentropy import BrXEntropyLoss
import glob
from exp_setup import *
from get_seg_datasets import LoadDataset
import errno
import concurrent.futures as concurrent
from itertools import repeat
import argparse
#from aux_functions import move_file
from train_test import eval_net
from allocate_cuda_device import allocate_cuda
from deepv3_funcs import eval_deepv3
#########################################
#from torch import device
from torch.nn import ModuleList
from torchvision import transforms as tr
#########################################
# Command-line configuration and experiment setup for branched DeepLabV3
# evaluation on VOC segmentation.
parser = argparse.ArgumentParser(description='Evaluate branched deepv3.')
parser.add_argument('-t', '--type', type=str, default='resnet101')
parser.add_argument('-n', '--n_branches', type=int, default=0)
parser.add_argument('-N', '--Name', type=str, default='deep_v3_resnet101')
parser.add_argument('-p', '--print_file', type=str, default=None)
parser.add_argument('-e', '--num_epochs', type=int, default=0)
parser.add_argument('-l', '--lr', type=float, default=.01)
parser.add_argument('-m', '--min_lr', type=float, default=.0)
parser.add_argument('-L', '--base_lr', type=float, default=0)
# store_true already defaults to False, so no set_defaults() call is needed.
parser.add_argument('-c', '--count_branches', action='store_true')
parser.add_argument('-s', '--skip', type=int, default=0)
parser.add_argument('-f', '--fine_tune', type=str, default='')
args = parser.parse_args()

types = args.type
n_branches = args.n_branches
name = args.Name
num_epochs = args.num_epochs
lr = args.lr
min_lr = args.min_lr
base_lr = args.base_lr
count_branches = args.count_branches
# When training branches without an explicit backbone LR, reuse the branch LR.
if n_branches and not base_lr:
    base_lr = lr
skip = args.skip
fine_tune = args.fine_tune

dataset = 'voc_seg'
# Default message file when -p/--print_file is not given (always truthy).
use_file = args.print_file or f'{dataset}_deepv3_msgs.txt'
og_dir = os.getcwd()
r_dir = os.path.join(og_dir, f"{dataset}_results")
if fine_tune:
    # Resolve the fine-tune checkpoint path relative to the launch directory.
    fine_tune = os.path.join(og_dir, fine_tune)
# exist_ok=True replaces the old try/except-EEXIST idiom and still raises
# if the path exists as a non-directory.
os.makedirs(r_dir, exist_ok=True)

# Dataset: splits come from the project loader (LoadDataset / get_dataset).
data_path = os.path.join(og_dir, f"datasets/{dataset.split('_')[0]}")
input_dim = 256  # was 320 in earlier experiments
target_dim = None  # previously input_dim // 8
hand_data = LoadDataset(input_dim, target_dim, None, None)
train_set, val_set, test_set = hand_data.get_dataset(data_path, dataset)
#Can't pickle <function <lambda> at 0x7f7903994430>: attribute lookup <lambda> on __main__ failed
#def _def_prefetch(x):
# return min(10, max(2, 100 // x))
#def _def_nworkers(x):
# return max(1, min(6, 200 // x))
#max. suggested for this system is 6 workers
def _def_prefetch(x):
return 2#min(10, 40 // x)
def _def_nworkers(x):
return 4#min(6, 40 // x)
#max. suggested for this system is 6 workers
def _lr_law(x):
return 0 if x < 20 else 1
def _ch_es(x):#change the metric that we will follow in earlystopping when bs == 40
return x == 40
# Experiment configuration consumed by eval_deepv3: device placement,
# dataset splits, dataloader sizing, optimizer/scheduler hyper-parameters
# and branch-specific options.
dts_info = {
    'device': allocate_cuda(),  # CUDA device chosen by the project helper
    'name': name,
    'main_dir': og_dir,
    'n_procs': 1,
    'n_rep': 1,
    'res_dir': r_dir,
    # dataset info
    'input_dim': input_dim,
    'train_set': train_set,
    'val_set': val_set,
    'test_set': test_set,
    'use_file': use_file,
    # Named functions (not lambdas): lambdas cannot be pickled for
    # multiprocessing dataloader workers.
    'def_prefetch': _def_prefetch,
    'def_nworkers': _def_nworkers,
    'metrics': ['mIoU'],
    'ch_es': None,  # _ch_es would change the early-stopping metric at bs == 40
    'minimize': False,  # mIoU is maximized, not minimized
    'n_branches': n_branches,
    'count_branches': count_branches,
    'lr': lr,
    'min_lr': min_lr,
    'base_lr': base_lr,
    'num_epochs': num_epochs,
    'batch_sizes': 32,
    # Previous alternative: BSL.LovaszSoftmax(classes='present', ignore=21,
    # n_branches=n_branches, prev_out=True)
    'loss': BrXEntropyLoss(ignore_index=21, b_reduction='sum', n_exits=n_branches+1),
    'use_scheduler': True,
    'nout_channels': 21,  # VOC: 20 classes + background; index 21 is ignored
    'skip': skip,
    'fine_tune': fine_tune,
    'freeze_backbone': bool(fine_tune),  # freeze backbone only when fine-tuning
    'freeze_from': None,  # e.g. 4 to freeze only up to a given stage
    'weighted_lr': False,
    # Per-branch overrides, e.g. {'atrous_rates': [12, 24, 36],
    # 'nout_channels': 128}; maybe set individually per branch in the future.
    'branch_params': None,
    # 'lr_law': _lr_law,
}
# Multiprocessing variant kept for reference:
# kwargss = list(map(get_info, models, repeat(dts_info)))
# with concurrent.ProcessPoolExecutor(max_workers=2) as executor:
#     res_dir = executor.map(eval_net, kwargss)
ret = eval_deepv3(dts_info)
msg = f'Finished training. model is saved @ {ret}'
# use_file always holds a path here (a default filename is supplied when
# -p is omitted), so the print() branch is effectively dead.
if not use_file:
    print(msg)
else:
    with open(use_file, 'a') as f:
        f.write(f"{msg}\n{'-' * 20}\n")