Skip to content

Commit 5146c3b

Browse files
authored
[FIX] Enable YOLOX training on different devices (#7912)
* Enable yolox training on different devices * Enable yolox resize test on cpu
1 parent 2a643e4 commit 5146c3b

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

mmdet/models/detectors/yolox.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def forward_train(self,
9797

9898
# random resizing
9999
if (self._progress_in_iter + 1) % self._random_size_interval == 0:
100-
self._input_size = self._random_resize()
100+
self._input_size = self._random_resize(device=img.device)
101101
self._progress_in_iter += 1
102102

103103
return losses
@@ -116,8 +116,8 @@ def _preprocess(self, img, gt_bboxes):
116116
gt_bbox[..., 1::2] = gt_bbox[..., 1::2] * scale_y
117117
return img, gt_bboxes
118118

119-
def _random_resize(self):
120-
tensor = torch.LongTensor(2).cuda()
119+
def _random_resize(self, device):
120+
tensor = torch.LongTensor(2).to(device)
121121

122122
if self.rank == 0:
123123
size = random.randint(*self._random_size_range)

tests/test_models/test_forward.py

-2
Original file line numberDiff line numberDiff line change
@@ -673,8 +673,6 @@ def test_inference_detector():
673673
assert len(result) == 2 and len(result[0]) == num_class
674674

675675

676-
@pytest.mark.skipif(
677-
not torch.cuda.is_available(), reason='requires CUDA support')
678676
def test_yolox_random_size():
679677
from mmdet.models import build_detector
680678
model = _get_detector_cfg('yolox/yolox_tiny_8x8_300e_coco.py')

0 commit comments

Comments
 (0)