-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
122 lines (113 loc) · 4.95 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from options.options import Options
import os
import torch
from build_dataset_model import build_loaders, build_model
from utils import get_model_attr, calculate_model_losses, tensor_aug
from collections import defaultdict
import math
def _find_restore_path(args):
    """Return the checkpoint file to resume training from, or None.

    Only consulted when ``args.restore_from_checkpoint`` is truthy. The
    explicit ``<checkpoint_name>_with_model.pt`` file is preferred (original
    behaviour). If it is missing, fall back to
    ``latest_<checkpoint_name>_with_model.pt``, which is the name ``main``
    saves its periodic checkpoints under, so a run of this script can be
    resumed from its own output.
    """
    if not args.restore_from_checkpoint:
        return None
    candidates = (
        '%s_with_model.pt' % args.checkpoint_name,
        'latest_%s_with_model.pt' % args.checkpoint_name,
    )
    for filename in candidates:
        path = os.path.join(args.output_dir, filename)
        if os.path.isfile(path):
            return path
    return None


def _new_checkpoint(args, vocab, model_kwargs):
    """Build the empty bookkeeping dict used for a fresh training run."""
    return {
        'args': args.__dict__,
        'vocab': vocab,
        'model_kwargs': model_kwargs,
        'losses_ts': [],
        'losses': defaultdict(list),
        'd_losses': defaultdict(list),
        'checkpoint_ts': [],
        'train_batch_data': [],
        'train_samples': [],
        'train_iou': [],
        'val_batch_data': [],
        'val_samples': [],
        'val_losses': defaultdict(list),
        'val_iou': [],
        'counters': {
            't': None,
            'epoch': None,
        },
        'model_state': None,
        'optim_state': None,
    }


def _store_training_state(checkpoint, model, optimizer, t, epoch):
    """Copy current model/optimizer state and the t/epoch counters into *checkpoint*."""
    checkpoint['model_state'] = get_model_attr(model, 'state_dict')()
    checkpoint['optim_state'] = optimizer.state_dict()
    checkpoint['counters']['t'] = t
    checkpoint['counters']['epoch'] = epoch


def _save_small_checkpoint(checkpoint, output_dir, checkpoint_name):
    """Save ``<checkpoint_name>_no_model.pt``: *checkpoint* minus model/optimizer state."""
    checkpoint_path = os.path.join(output_dir, '%s_no_model.pt' % checkpoint_name)
    key_blacklist = ('model_state', 'optim_state')
    small_checkpoint = {k: v for k, v in checkpoint.items() if k not in key_blacklist}
    torch.save(small_checkpoint, checkpoint_path)


def main(args):
    """Train the model built by ``build_model`` on the data from ``build_loaders``.

    Optionally resumes from a checkpoint, then runs Adam updates epoch by
    epoch until at least ``args.num_iterations`` iterations have been done.
    The epoch in progress is always run to completion. Loss histories,
    periodic checkpoints and snapshots are written under ``args.output_dir``.

    Args:
        args: parsed options (from ``Options().parse()``). Attributes read
            here include learning_rate, restore_from_checkpoint,
            checkpoint_name, output_dir, eval_mode_after, num_iterations,
            KL_linear_decay, KL_loss_weight, print_every, checkpoint_every
            and snapshot_every.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches in an epoch, which
            would otherwise make the outer loop spin forever without progress.
    """
    vocab, train_loader, _val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    print(model)
    model.float().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    restore_path = _find_restore_path(args)
    if restore_path is not None:
        print('Restoring from checkpoint:')
        print(restore_path)
        # NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True),
        # which may refuse the pickled vocab/args stored in this dict.
        # Confirm against the installed torch version.
        checkpoint = torch.load(restore_path)
        get_model_attr(model, 'load_state_dict')(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])
        t = checkpoint['counters']['t']
        epoch = checkpoint['counters']['epoch']
        # Resume in the mode the loop would be in at step t. A negative
        # eval_mode_after never matches t in the loop, so it disables eval mode.
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
    else:
        t, epoch = 0, 0
        checkpoint = _new_checkpoint(args, vocab, model_kwargs)

    while t < args.num_iterations:
        epoch += 1
        print('Starting epoch %d' % epoch)
        num_batches = 0
        for batch in train_loader:
            num_batches += 1
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
            t += 1
            if t % 50 == 0:
                print("Currently on batch {}".format(t))

            (_ids, objs, boxes, triples, angles, attributes,
             obj_to_img, _triple_to_img) = tensor_aug(batch)
            mu, logvar, boxes_pred, angles_pred = model(
                objs, triples, boxes, angles, attributes, obj_to_img)

            if args.KL_linear_decay:
                # Step schedule: 1e-6 while t < 1e5, then x10 every further 1e5 iterations.
                KL_weight = 10 ** (t // 1e5 - 6)
            else:
                KL_weight = args.KL_loss_weight

            total_loss, losses = calculate_model_losses(
                args, model, boxes, boxes_pred, angles, angles_pred,
                mu=mu, logvar=logvar, KL_weight=KL_weight)
            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                # Skip the update so a single non-finite loss cannot corrupt the weights.
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if t % args.print_every == 0:
                print("On batch {} out of {}".format(t, args.num_iterations))
                for name, val in losses.items():
                    print(' [%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                checkpoint['losses_ts'].append(t)

            save_checkpoint = t % args.checkpoint_every == 0
            save_snapshot = t % args.snapshot_every == 0
            if save_checkpoint or save_snapshot:
                # Refresh before *any* save. Otherwise a snapshot taken between
                # checkpoints carries stale (or still None) weights and counters.
                _store_training_state(checkpoint, model, optimizer, t, epoch)

            if save_checkpoint:
                checkpoint_path = os.path.join(
                    args.output_dir, 'latest_%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)
                # Lightweight companion file with the logs but no weights.
                _save_small_checkpoint(checkpoint, args.output_dir, args.checkpoint_name)

            if save_snapshot:
                snapshot_name = args.checkpoint_name + 'snapshot_%06dK' % (t // 1000)
                snapshot_path = os.path.join(args.output_dir, snapshot_name)
                print('Saving snapshot to ', snapshot_path)
                torch.save(checkpoint, snapshot_path)

        if num_batches == 0:
            raise RuntimeError(
                'train_loader produced no batches; training cannot make progress')
if __name__ == '__main__':
    args = Options().parse()
    # os.makedirs (unlike os.mkdir) also creates missing parent directories.
    # exist_ok=True avoids a crash if the directory appears between a check
    # and its creation. A non-directory at the path still raises.
    for directory in (args.output_dir, args.test_dir):
        if directory is not None:
            os.makedirs(directory, exist_ok=True)
    main(args)