forked from MhLiao/DB
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment.py
103 lines (77 loc) · 2.55 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from concern.config import Configurable, State
from concern.log import Logger
from structure.builder import Builder
from structure.representers import *
from structure.measurers import *
from structure.visualizers import *
from data.data_loader import *
from data import *
from training.model_saver import ModelSaver
from training.checkpoint import Checkpoint
from training.optimizer_scheduler import OptimizerScheduler
class Structure(Configurable):
    """Configurable bundle of the model-defining components.

    Holds four config-driven members -- ``builder``, ``representer``,
    ``measurer`` and ``visualizer`` -- all populated by ``load_all``.
    """

    builder = State()
    representer = State()
    measurer = State()
    visualizer = State()

    def __init__(self, **config):
        # Every declared state is filled from the incoming config mapping.
        self.load_all(**config)

    @property
    def model_name(self):
        """Model name as reported by the configured builder."""
        builder = self.builder
        return builder.model_name
class TrainSettings(Configurable):
    """Training configuration: data loader, model saver, checkpoint,
    scheduler and number of epochs (default 10).

    Before loading its members it sets ``is_train=True`` on the shared
    command-line dict ``cmd`` (presumably so members constructed from that
    dict can tell they are being built for training -- confirm in
    ``concern.config``). A command-line ``epochs`` value overrides the
    configured one.
    """

    data_loader = State()
    model_saver = State()
    checkpoint = State()
    scheduler = State()
    epochs = State(default=10)

    def __init__(self, **kwargs):
        # setdefault: when the caller passed ``cmd`` we get that very dict
        # back and mutate it in place (as before); when it is missing we no
        # longer crash with KeyError but hand a fresh dict to load_all().
        cmd = kwargs.setdefault('cmd', {})
        cmd.update(is_train=True)
        self.load_all(**kwargs)
        # Only an actual value overrides; an unset CLI option (None) must
        # not wipe out the configured/default epoch count.
        if cmd.get('epochs') is not None:
            self.epochs = cmd['epochs']
class ValidationSettings(Configurable):
    """Validation configuration: data loaders, visualization switch,
    ``interval`` (default 100) and ``exempt`` (default -1).

    Before loading its members it sets ``is_train=False`` on the shared
    command-line dict ``cmd``. The command line's ``visualize`` flag, when
    present, overrides the configured value.
    """

    data_loaders = State()
    visualize = State()
    interval = State(default=100)
    exempt = State(default=-1)

    def __init__(self, **kwargs):
        # setdefault keeps in-place mutation of a caller-supplied cmd dict
        # while tolerating construction without one (no KeyError).
        cmd = kwargs.setdefault('cmd', {})
        cmd.update(is_train=False)
        self.load_all(**kwargs)
        # Command line wins when it specifies the flag; otherwise keep the
        # value loaded above instead of raising KeyError.
        self.visualize = cmd.get('visualize', self.visualize)
class EvaluationSettings(Configurable):
    """Evaluation configuration.

    Members: ``data_loaders``, ``visualize`` (defaults to True) and
    ``resume`` (presumably a checkpoint to resume from -- confirm with the
    evaluation driver).
    """

    data_loaders = State()
    visualize = State(default=True)
    resume = State()

    def __init__(self, **config):
        # Populate every declared state from the config mapping.
        self.load_all(**config)
class EvaluationSettings2(Configurable):
    """Evaluation configuration variant carrying its own ``structure``
    alongside ``data_loaders``."""

    structure = State()
    data_loaders = State()

    def __init__(self, **config):
        # Populate every declared state from the config mapping.
        self.load_all(**config)
class ShowSettings(Configurable):
    """Configuration for displaying samples: a ``data_loader`` together
    with the ``representer`` and ``visualizer`` used on its output
    (intended usage inferred from names -- confirm with the caller)."""

    data_loader = State()
    representer = State()
    visualizer = State()

    def __init__(self, **config):
        # Populate every declared state from the config mapping.
        self.load_all(**config)
class Experiment(Configurable):
    """Top-level experiment assembled from a configuration tree.

    ``structure`` is loaded first so its ``model_name`` can serve as the
    default run name in ``cmd['name']``. The remaining states are then
    loaded through ``load_all`` (those not marked ``autoload=False`` --
    presumably ``train`` and ``logger``). ``validation`` is built only when
    ``cmd['validate']`` is truthy; otherwise it is ``None``. ``evaluation``
    is never loaded here.

    Attributes set from ``cmd``:
        distributed: ``cmd['distributed']``, default False.
        local_rank: ``cmd['local_rank']``, default 0.
    """

    structure = State(autoload=False)
    train = State()
    validation = State(autoload=False)
    evaluation = State(autoload=False)
    logger = State(autoload=True)

    def __init__(self, **kwargs):
        self.load('structure', **kwargs)
        # setdefault (not get) so that, when no cmd was supplied, the dict
        # receiving the default 'name' is the same one load_all() and
        # load('validation') pass on. With get('cmd', {}) the name was
        # written to a throwaway dict and never reached those members.
        cmd = kwargs.setdefault('cmd', {})
        if 'name' not in cmd:
            cmd['name'] = self.structure.model_name
        self.load_all(**kwargs)
        self.distributed = cmd.get('distributed', False)
        self.local_rank = cmd.get('local_rank', 0)
        if cmd.get('validate', False):
            self.load('validation', **kwargs)
        else:
            self.validation = None