# train_3dssd_zero2.py
#
# NOTE: The repository this script comes from was archived by its owner on
# Oct 16, 2023 and is now read-only.
import argparse
import colossalai
import torch
import torch.distributed as dist
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from mmcv import Config
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.models import build_model
from tqdm import tqdm
# Keyword options forwarded as ``optim_config`` to ``zero_optim_wrapper``
# when the optimizer is wrapped in ``main``.
ZERO_OPTIM_CFG = {
    'reduce_bucket_size': 12 * 1024 ** 2,  # 12 * 2**20
    'overlap_communication': True,
    'cpu_offload': False,
}
def get_data():
    """Build one synthetic two-sample training batch on the CUDA device.

    Returns:
        tuple: ``(points, img_metas, gt_bboxes, gt_labels)``. Each element is
        a list with one entry per sample:

        - ``points``: random float tensor of shape ``[2000, 4]``.
        - ``img_metas``: dict mapping ``'box_type_3d'`` to
          ``DepthInstance3DBoxes``.
        - ``gt_bboxes``: ``DepthInstance3DBoxes`` wrapping a random
          ``[10, 7]`` tensor.
        - ``gt_labels``: all-zero int64 tensor of shape ``[10]``.
    """
    num_samples = 2

    # Random draws happen in a fixed order: every point cloud first, then
    # every box tensor.
    points = [
        torch.rand([2000, 4], device='cuda') for _ in range(num_samples)
    ]
    img_metas = [
        {'box_type_3d': DepthInstance3DBoxes} for _ in range(num_samples)
    ]
    gt_bboxes = [
        DepthInstance3DBoxes(torch.rand([10, 7], device='cuda'))
        for _ in range(num_samples)
    ]
    gt_labels = [
        torch.zeros([10], device='cuda').long() for _ in range(num_samples)
    ]
    return points, img_metas, gt_bboxes, gt_labels
def main(cfg, steps: int):
    """Run ``steps`` ZeRO stage-2 training iterations on synthetic batches.

    Launches ColossalAI from the torch distributed environment, builds the
    model described by ``cfg``, wraps model and optimizer with the ZeRO
    helpers, then repeatedly feeds batches from ``get_data`` through a
    forward / backward / step cycle while reporting the loss on a progress
    bar.

    Args:
        cfg: Config object; ``cfg.model`` plus the optional ``train_cfg`` and
            ``test_cfg`` entries are passed to ``build_model``.
        steps: Number of training iterations to run.
    """
    colossalai.launch_from_torch({})

    detector = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    detector = detector.cuda()
    # Explicit weight initialisation is left disabled:
    # model.init_weights()

    # The raw optimizer is created from the unwrapped parameters first, then
    # both model and optimizer are wrapped.
    base_optimizer = HybridAdam(detector.parameters())
    detector = zero_model_wrapper(detector, zero_stage=2)
    zero_optimizer = zero_optim_wrapper(
        detector, base_optimizer, optim_config=ZERO_OPTIM_CFG)

    step_range = range(steps)
    # Only rank 0 draws the progress bar.
    hide_bar = dist.get_rank() != 0
    with tqdm(step_range, desc='Train', disable=hide_bar) as progress:
        for _ in progress:
            batch = get_data()
            loss_dict = detector.forward_train(*batch)
            total_loss, _ = detector._parse_losses(loss_dict)

            zero_optimizer.backward(total_loss)
            zero_optimizer.step()
            zero_optimizer.zero_grad()

            progress.set_postfix(loss=total_loss.item())
if __name__ == "__main__":
    # Command line: a positional config path and an optional step count.
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--steps',
        type=int,
        default=10,
        help='number of training iterations to run')
    args = parser.parse_args()

    # Pass only the path: the second positional parameter of Config.fromfile
    # is the boolean ``use_predefined_variables`` flag, so handing it the
    # step count (e.g. ``--steps 0``) would silently change config parsing.
    cfg = Config.fromfile(args.config)
    main(cfg, args.steps)