Skip to content

Commit

Permalink
removed old test for forward wrapper, fixed defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
shaydeci committed May 20, 2024
1 parent aa7d0cb commit 7f3a0d4
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def get_dataset_processing_params(self):
conf=self._default_nms_iou,
nms_top_k=self._default_nms_top_k,
max_predictions=self._default_max_predictions,
multi_label_per_box=self._multi_label_per_box,
multi_label_per_box=self._default_multi_label_per_box,
class_agnostic_nms=self._default_class_agnostic_nms,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def __init__(
self.tile_step = tile_step
self.min_tile_threshold = min_tile_threshold

self._class_names: Optional[List[str]] = None
self._image_processor: Optional[Processing] = None
self._default_nms_iou: float = 0.7
self._default_nms_conf: float = 0.5
self._default_nms_top_k: int = 1024
self._default_max_predictions = 300
self._default_multi_label_per_box = True
self._default_class_agnostic_nms = False

# Processing params
self.model = model
self.set_dataset_processing_params(**self.model.get_dataset_processing_params())
Expand Down
24 changes: 0 additions & 24 deletions tests/unit_tests/detection_sliding_window_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,12 @@
from super_gradients.training import Trainer
from super_gradients.training.models.detection_models.sliding_window_detection_forward_wrapper import SlidingWindowInferenceDetectionWrapper
from super_gradients.training.metrics import DetectionMetrics
from super_gradients.training import training_hyperparams


class SlidingWindowWrapperTest(unittest.TestCase):
def setUp(self):
self.mini_coco_data_dir = str(Path(__file__).parent.parent / "data" / "tinycoco")

def test_train_with_sliding_window_wrapper_validation(self):
train_params = training_hyperparams.get("coco2017_yolo_nas_s")

train_params["valid_metrics_list"] = [
DetectionMetrics(
normalize_targets=True,
post_prediction_callback=None,
num_cls=80,
)
]
train_params["max_epochs"] = 2
train_params["lr_warmup_epochs"] = 0
train_params["lr_cooldown_epochs"] = 0
train_params["average_best_models"] = False
train_params["mixed_precision"] = False
train_params["validation_forward_wrapper"] = SlidingWindowInferenceDetectionWrapper(tile_size=320, tile_step=160, tile_nms_iou=0.65, tile_nms_conf=0.03)

dl = coco2017_val_yolo_nas(dataset_params=dict(data_dir=self.mini_coco_data_dir))

trainer = Trainer("test_yolo_nas_s_coco_with_sliding_window")
model = models.get("yolo_nas_s", num_classes=80, pretrained_weights="coco")
trainer.train(model=model, training_params=train_params, train_loader=dl, valid_loader=dl)

def test_yolo_nas_s_coco_with_sliding_window(self):
trainer = Trainer("test_yolo_nas_s_coco_with_sliding_window")
model = models.get("yolo_nas_s", num_classes=80, pretrained_weights="coco")
Expand Down

0 comments on commit 7f3a0d4

Please sign in to comment.