diff --git a/src/super_gradients/training/datasets/data_formats/obb/cxcywhr.py b/src/super_gradients/training/datasets/data_formats/obb/cxcywhr.py
new file mode 100644
index 0000000000..e3af66c0ad
--- /dev/null
+++ b/src/super_gradients/training/datasets/data_formats/obb/cxcywhr.py
@@ -0,0 +1,65 @@
+import cv2
+import numpy as np
+
+
+def cxcywhr_to_poly(boxes: np.ndarray) -> np.ndarray:
+    """
+    Convert oriented bounding boxes in CX-CY-W-H-R format to polygon format.
+    :param boxes: [N,...,5] oriented bounding boxes in CX-CY-W-H-R format
+    :return: [N,...,4,2] oriented bounding boxes in polygon format
+    """
+    shape = boxes.shape
+    if shape[-1] != 5:
+        raise ValueError(f"Expected last dimension to be 5, got {shape[-1]}")
+
+    flat_rboxes = boxes.reshape(-1, 5).astype(np.float32)
+    polys = np.zeros((flat_rboxes.shape[0], 4, 2), dtype=np.float32)
+    for i, box in enumerate(flat_rboxes):
+        cx, cy, w, h, r = box
+        rect = ((cx, cy), (w, h), np.rad2deg(r))
+        poly = cv2.boxPoints(rect)
+        polys[i] = poly
+
+    return polys.reshape(*shape[:-1], 4, 2)
+
+
+def poly_to_cxcywhr(poly: np.ndarray) -> np.ndarray:
+    """
+    Convert oriented bounding boxes in polygon format to CX-CY-W-H-R format.
+    Each polygon is replaced by its minimum-area enclosing rotated rectangle,
+    with sides swapped when needed so that w >= h (the angle is returned in radians).
+    :param poly: [N,...,4,2] oriented bounding boxes in polygon format
+    :return: [N,...,5] oriented bounding boxes in CX-CY-W-H-R format
+    """
+    shape = poly.shape
+    if shape[-2:] != (4, 2):
+        raise ValueError(f"Expected last two dimensions to be (4, 2), got {shape[-2:]}")
+
+    flat_polys = poly.reshape(-1, 4, 2)
+    rboxes = np.zeros((flat_polys.shape[0], 5), dtype=np.float32)
+    for i, p in enumerate(flat_polys):
+        hull = cv2.convexHull(np.reshape(p, [-1, 2]).astype(np.float32))
+        rect = cv2.minAreaRect(hull)
+        cx, cy = rect[0]
+        w, h = rect[1]
+        angle = rect[2]
+        if h > w:
+            w, h = h, w
+            angle += 90
+        angle = np.deg2rad(angle)
+        rboxes[i] = [cx, cy, w, h, angle]
+
+    return rboxes.reshape(*shape[:-2], 5)
+
+
+def poly_to_xyxy(poly: np.ndarray) -> np.ndarray:
+    """
+    Convert oriented bounding boxes in polygon format to XYXY format.
+    :param poly: [..., 4, 2]
+    :return: [..., 4]
+    """
+    x1 = poly[..., :, 0].min(axis=-1)
+    y1 = poly[..., :, 1].min(axis=-1)
+    x2 = poly[..., :, 0].max(axis=-1)
+    y2 = poly[..., :, 1].max(axis=-1)
+    return np.stack([x1, y1, x2, y2], axis=-1)
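A quick round-trip sketch of these helpers (box values chosen arbitrarily for illustration):

```python
import numpy as np
from super_gradients.training.datasets.data_formats.obb.cxcywhr import cxcywhr_to_poly, poly_to_cxcywhr, poly_to_xyxy

# One box centered at (50, 40), 20x10 pixels, rotated 30 degrees (the R component is in radians).
rboxes = np.array([[50.0, 40.0, 20.0, 10.0, np.deg2rad(30)]], dtype=np.float32)

poly = cxcywhr_to_poly(rboxes)  # [1, 4, 2] - the four corners of the box
xyxy = poly_to_xyxy(poly)       # [1, 4]    - axis-aligned box enclosing the polygon
back = poly_to_cxcywhr(poly)    # [1, 5]    - same box, up to the w >= h / angle normalization
```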
diff --git a/src/super_gradients/training/samples/__init__.py b/src/super_gradients/training/samples/__init__.py
index 93f00253ae..6516eeff5d 100644
--- a/src/super_gradients/training/samples/__init__.py
+++ b/src/super_gradients/training/samples/__init__.py
@@ -2,5 +2,6 @@
 from .pose_estimation_sample import PoseEstimationSample
 from .detection_sample import DetectionSample
 from .segmentation_sample import SegmentationSample
+from .obb_sample import OBBSample
 
-__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample"]
+__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample", "OBBSample"]
diff --git a/src/super_gradients/training/samples/obb_sample.py b/src/super_gradients/training/samples/obb_sample.py
new file mode 100644
index 0000000000..935b0fd206
--- /dev/null
+++ b/src/super_gradients/training/samples/obb_sample.py
@@ -0,0 +1,110 @@
+import dataclasses
+from typing import Union, List, Optional
+
+import numpy as np
+import torch
+from super_gradients.training.datasets.data_formats.obb.cxcywhr import cxcywhr_to_poly, poly_to_cxcywhr
+
+
+@dataclasses.dataclass
+class OBBSample:
+    """
+    A data class describing a single oriented object detection sample that comes from a dataset.
+    It contains both the input image and the target information to train an oriented object detection model.
+
+    :param image:              Associated image with a sample. Can be in [H,W,C] or [C,H,W] format
+    :param rboxes_cxcywhr:     Numpy array of [N,5] shape with oriented bounding box of each instance (CX,CY,W,H,R)
+    :param labels:             Numpy array of [N] shape with class label for each instance
+    :param is_crowd:           (Optional) Numpy array of [N] shape with is_crowd flag for each instance
+    :param additional_samples: (Optional) List of additional samples for the same image.
+    """
+
+    __slots__ = ["image", "rboxes_cxcywhr", "labels", "is_crowd", "additional_samples"]
+
+    image: Union[np.ndarray, torch.Tensor]
+    rboxes_cxcywhr: np.ndarray
+    labels: np.ndarray
+    is_crowd: np.ndarray
+    additional_samples: Optional[List["OBBSample"]]
+
+    def __init__(
+        self,
+        image: Union[np.ndarray, torch.Tensor],
+        rboxes_cxcywhr: np.ndarray,
+        labels: np.ndarray,
+        is_crowd: Optional[np.ndarray] = None,
+        additional_samples: Optional[List["OBBSample"]] = None,
+    ):
+        if is_crowd is None:
+            is_crowd = np.zeros(len(labels), dtype=bool)
+
+        if len(rboxes_cxcywhr) != len(labels):
+            raise ValueError(f"Number of bounding boxes and labels must be equal. Got {len(rboxes_cxcywhr)} and {len(labels)} respectively")
+
+        if len(rboxes_cxcywhr) != len(is_crowd):
+            raise ValueError(f"Number of bounding boxes and is_crowd flags must be equal. Got {len(rboxes_cxcywhr)} and {len(is_crowd)} respectively")
+
+        if len(rboxes_cxcywhr.shape) != 2 or rboxes_cxcywhr.shape[1] != 5:
+            raise ValueError(f"Oriented boxes must be in [N,5] format. Shape of input bboxes is {rboxes_cxcywhr.shape}")
+
+        if len(is_crowd.shape) != 1:
+            raise ValueError(f"is_crowd flags must be in [N] format. Shape of input is_crowd is {is_crowd.shape}")
+
+        if len(labels.shape) != 1:
+            raise ValueError(f"Labels must be in [N] format. Shape of input labels is {labels.shape}")
+
+        self.image = image
+        self.rboxes_cxcywhr = rboxes_cxcywhr
+        self.labels = labels
+        self.is_crowd = is_crowd
+        self.additional_samples = additional_samples
+
+    def sanitize_sample(self) -> "OBBSample":
+        """
+        Apply sanity checks on the detection sample, which includes clamping of rotated boxes to image boundaries
+        and removing boxes with non-positive area.
+        This method returns a new OBBSample instance with sanitized data.
+        :return: An OBBSample after filtering.
+        """
+        polys = cxcywhr_to_poly(self.rboxes_cxcywhr)
+        # Clamp polygons to image boundaries
+        polys[..., 0] = np.clip(polys[..., 0], 0, self.image.shape[1])
+        polys[..., 1] = np.clip(polys[..., 1], 0, self.image.shape[0])
+        rboxes_cxcywhr = poly_to_cxcywhr(polys)
+        return OBBSample(
+            image=self.image,
+            rboxes_cxcywhr=rboxes_cxcywhr,
+            labels=self.labels,
+            is_crowd=self.is_crowd,
+            additional_samples=self.additional_samples,
+        ).filter_by_bbox_area(0)
+
+    def filter_by_mask(self, mask: np.ndarray) -> "OBBSample":
+        """
+        Remove boxes & labels with respect to a given mask.
+        This method returns a new OBBSample instance with filtered data.
+
+        :param mask: A boolean or integer mask of samples to keep for given sample.
+        :return: An OBBSample after filtering.
+        """
+        return OBBSample(
+            image=self.image,
+            rboxes_cxcywhr=self.rboxes_cxcywhr[mask],
+            labels=self.labels[mask],
+            is_crowd=self.is_crowd[mask] if self.is_crowd is not None else None,
+            additional_samples=self.additional_samples,
+        )
+
+    def filter_by_bbox_area(self, min_rbox_area: Union[int, float]) -> "OBBSample":
+        """
+        Remove instances whose rotated bounding box area is not greater than a certain threshold.
+
+        :param min_rbox_area: Minimal rotated box area of the box to keep.
+        :return: An OBBSample after filtering.
+        """
+        area = self.rboxes_cxcywhr[..., 2:4].prod(axis=-1)
+        keep_mask = area > min_rbox_area
+        return self.filter_by_mask(keep_mask)
+
+    def __len__(self):
+        return len(self.labels)
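A short usage sketch (shapes and values are illustrative). `sanitize_sample` clips each box polygon to the image and re-fits it with `cv2.minAreaRect`, so a box that ends up degenerate (zero area) is dropped by the chained `filter_by_bbox_area(0)` call:

```python
import numpy as np
from super_gradients.training.samples import OBBSample

sample = OBBSample(
    image=np.zeros((480, 640, 3), dtype=np.uint8),
    rboxes_cxcywhr=np.array([[320, 240, 100, 50, 0.3], [700, 240, 10, 10, 0.0]], dtype=np.float32),
    labels=np.array([0, 1]),
    is_crowd=None,  # defaults to an all-False array
)
sample = sample.sanitize_sample()
print(len(sample))  # 1 - the second box lies entirely outside the 640px-wide image and collapses to zero area
```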
diff --git a/src/super_gradients/training/transforms/__init__.py b/src/super_gradients/training/transforms/__init__.py
index 1f52079464..bc69a0d456 100644
--- a/src/super_gradients/training/transforms/__init__.py
+++ b/src/super_gradients/training/transforms/__init__.py
@@ -38,6 +38,16 @@
 from super_gradients.common.registry.albumentation import ALBUMENTATIONS_TRANSFORMS, ALBUMENTATIONS_COMP_TRANSFORMS, imported_albumentations_failure
 from super_gradients.training.transforms.detection import AbstractDetectionTransform, DetectionPadIfNeeded, DetectionLongestMaxSize
 
+from .obb import (
+    AbstractOBBDetectionTransform,
+    OBBDetectionPadIfNeeded,
+    OBBDetectionLongestMaxSize,
+    OBBDetectionStandardize,
+    OBBDetectionMixup,
+    OBBDetectionCompose,
+    OBBRemoveSmallObjects,
+)
+
 __all__ = [
     "TRANSFORMS",
     "ALBUMENTATIONS_TRANSFORMS",
@@ -76,6 +86,13 @@
     "DetectionPadIfNeeded",
     "DetectionLongestMaxSize",
     "AbstractDetectionTransform",
+    "AbstractOBBDetectionTransform",
+    "OBBDetectionPadIfNeeded",
+    "OBBDetectionLongestMaxSize",
+    "OBBDetectionStandardize",
+    "OBBDetectionMixup",
+    "OBBDetectionCompose",
+    "OBBRemoveSmallObjects",
 ]
 
 cv2.setNumThreads(0)
diff --git a/src/super_gradients/training/transforms/obb/__init__.py b/src/super_gradients/training/transforms/obb/__init__.py
new file mode 100644
index 0000000000..3eabc1ef29
--- /dev/null
+++ b/src/super_gradients/training/transforms/obb/__init__.py
@@ -0,0 +1,19 @@
+from super_gradients.training.samples.obb_sample import OBBSample
+from .abstract_obb_transform import AbstractOBBDetectionTransform
+from .obb_pad_if_needed import OBBDetectionPadIfNeeded
+from .obb_longest_max_size import OBBDetectionLongestMaxSize
+from .obb_standardize import OBBDetectionStandardize
+from .obb_mixup import OBBDetectionMixup
+from .obb_compose import OBBDetectionCompose
+from .obb_remove_small_objects import OBBRemoveSmallObjects
+
+__all__ = [
+    "OBBSample",
+    "AbstractOBBDetectionTransform",
+    "OBBDetectionPadIfNeeded",
+    "OBBDetectionLongestMaxSize",
+    "OBBDetectionStandardize",
+    "OBBDetectionMixup",
+    "OBBDetectionCompose",
+    "OBBRemoveSmallObjects",
+]
diff --git a/src/super_gradients/training/transforms/obb/abstract_obb_transform.py b/src/super_gradients/training/transforms/obb/abstract_obb_transform.py
new file mode 100644
index 0000000000..537efd931c
--- /dev/null
+++ b/src/super_gradients/training/transforms/obb/abstract_obb_transform.py
@@ -0,0 +1,59 @@
+import abc
+import warnings
+from abc import abstractmethod
+from typing import List
+
+from super_gradients.training.samples.obb_sample import OBBSample
+
+__all__ = ["AbstractOBBDetectionTransform"]
+
+
+class AbstractOBBDetectionTransform(abc.ABC):
+    """
+    Base class for all transforms for oriented object detection sample augmentation.
+    """
+
+    def __init__(self, additional_samples_count: int = 0):
+        """
+        :param additional_samples_count: (int) Number of additional samples to fetch from the dataset for this transform. Default value is 0.
+        """
+        self._additional_samples_count = additional_samples_count
+
+    @abstractmethod
+    def apply_to_sample(self, sample: OBBSample) -> OBBSample:
+        """
+        Apply transformation to given OBB detection sample.
+        Important note - this function may return a new object or may modify the input sample in-place.
+        This is implementation dependent, and if you need to keep the original sample intact it
+        is recommended to make a copy of it BEFORE passing it to the transform.
+
+        :param sample: Input sample to transform.
+        :return: Modified sample (it can be the same instance as the input or a new object).
+        """
+        raise NotImplementedError
+
+    @property
+    def additional_samples_count(self) -> int:
+        warnings.warn(
+            "This property is deprecated and will be removed in the future. Please use `get_number_of_additional_samples` instead.", DeprecationWarning
+        )
+        return self.get_number_of_additional_samples()
+
+    def get_number_of_additional_samples(self) -> int:
+        """
+        Returns the number of additional samples required. The default implementation assumes that this number is fixed and deterministic.
+        Override in case this is not the case, e.g., you randomly choose whether to apply MixUp, etc.
+        """
+        return self._additional_samples_count
+
+    @property
+    def may_require_additional_samples(self) -> bool:
+        """
+        Indicates whether additional samples are required. The default implementation assumes that this indicator is fixed and deterministic.
+        Override in case this is not the case, e.g., you randomly choose whether to apply MixUp, etc.
+        """
+        return self._additional_samples_count > 0
+
+    @abstractmethod
+    def get_equivalent_preprocessing(self) -> List:
+        raise NotImplementedError
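A minimal sketch of a concrete subclass, assuming only the base-class contract above (the horizontal-flip math is illustrative and not part of this PR):

```python
import numpy as np
from super_gradients.training.samples.obb_sample import OBBSample
from super_gradients.training.transforms.obb import AbstractOBBDetectionTransform


class OBBHorizontalFlip(AbstractOBBDetectionTransform):
    """Hypothetical transform: mirror the image and boxes around the vertical axis."""

    def apply_to_sample(self, sample: OBBSample) -> OBBSample:
        width = sample.image.shape[1]  # assumes [H, W, C] layout
        rboxes = sample.rboxes_cxcywhr.copy()
        rboxes[:, 0] = width - rboxes[:, 0]  # mirror box centers
        rboxes[:, 4] = -rboxes[:, 4]         # mirroring negates the rotation angle
        return OBBSample(
            image=np.ascontiguousarray(sample.image[:, ::-1]),
            rboxes_cxcywhr=rboxes,
            labels=sample.labels,
            is_crowd=sample.is_crowd,
            additional_samples=None,
        )

    def get_equivalent_preprocessing(self) -> list:
        return []  # a random flip is train-time only, nothing to replay at inference
```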
diff --git a/src/super_gradients/training/transforms/obb/obb_compose.py b/src/super_gradients/training/transforms/obb/obb_compose.py
new file mode 100644
index 0000000000..821509c591
--- /dev/null
+++ b/src/super_gradients/training/transforms/obb/obb_compose.py
@@ -0,0 +1,93 @@
+from typing import List
+
+from .abstract_obb_transform import AbstractOBBDetectionTransform
+from super_gradients.training.samples.obb_sample import OBBSample
+
+
+class OBBDetectionCompose(AbstractOBBDetectionTransform):
+    """
+    Composes several transforms together
+    """
+
+    def __init__(self, transforms: List[AbstractOBBDetectionTransform], load_sample_fn=None):
+        """
+
+        :param transforms:     List of OBB detection transformations to compose
+        :param load_sample_fn: A method to load additional samples if needed (for mixup & mosaic augmentations).
+                               Default value is None, which would raise an error if additional samples are needed.
+        """
+        for transform in transforms:
+            if hasattr(transform, "may_require_additional_samples") and transform.may_require_additional_samples and load_sample_fn is None:
+                raise RuntimeError(f"Transform {transform.__class__.__name__} requires additional samples but `load_sample_fn` is None")
+
+        super().__init__()
+        self.transforms = transforms
+        self.load_sample_fn = load_sample_fn
+
+    def apply_to_sample(self, sample: OBBSample) -> OBBSample:
+        """
+        Applies the series of transformations to the input sample.
+        The function may modify the input sample in-place, so the input sample should not be used after the call.
+
+        :param sample: Input sample
+        :return: Transformed sample.
+        """
+        sample = sample.sanitize_sample()
+        sample = self._apply_transforms(sample, transforms=self.transforms, load_sample_fn=self.load_sample_fn)
+        return sample
+
+    @classmethod
+    def _apply_transforms(cls, sample: OBBSample, transforms: List[AbstractOBBDetectionTransform], load_sample_fn) -> OBBSample:
+        """
+        This helper method allows us to query additional samples for mixup & mosaic augmentations
+        that are also passed through the augmentation pipeline. Example:
+
+        ```
+        transforms:
+          - OBBDetectionLongestMaxSize:
+              max_height: ${dataset_params.image_size}
+              max_width: ${dataset_params.image_size}
+          - OBBDetectionMixup:
+              prob: ${dataset_params.mixup_prob}
+        ```
+
+        In the example above all additional samples needed for mixup will be forwarded through OBBDetectionLongestMaxSize,
+        and only then mixed up.
+
+        :param sample:         Input data sample
+        :param transforms:     List of transformations to apply
+        :param load_sample_fn: A method to load additional samples if needed
+        :return:               A data sample after applying transformations
+        """
+        applied_transforms_so_far = []
+        for t in transforms:
+            if not hasattr(t, "may_require_additional_samples") or not t.may_require_additional_samples:
+                sample = t.apply_to_sample(sample)
+                applied_transforms_so_far.append(t)
+            else:
+                additional_samples = [load_sample_fn() for _ in range(t.get_number_of_additional_samples())]
+                additional_samples = [
+                    cls._apply_transforms(
+                        sample,
+                        applied_transforms_so_far,
+                        load_sample_fn=load_sample_fn,
+                    )
+                    for sample in additional_samples
+                ]
+                sample.additional_samples = additional_samples
+                sample = t.apply_to_sample(sample)
+
+        return sample
+
+    def get_equivalent_preprocessing(self) -> List:
+        preprocessing = []
+        for t in self.transforms:
+            preprocessing += t.get_equivalent_preprocessing()
+        return preprocessing
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + "("
+        for t in self.transforms:
+            format_string += f"\n\t{repr(t)}"
+        format_string += "\n)"
+        return format_string
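Roughly how the pieces are expected to compose in code (parameter values are arbitrary):

```python
from super_gradients.training.transforms import (
    OBBDetectionCompose,
    OBBDetectionLongestMaxSize,
    OBBDetectionPadIfNeeded,
    OBBDetectionStandardize,
)

pipeline = OBBDetectionCompose(
    transforms=[
        OBBDetectionLongestMaxSize(max_height=640, max_width=640),
        OBBDetectionPadIfNeeded(min_height=640, min_width=640, pad_value=114, padding_mode="center"),
        OBBDetectionStandardize(max_value=255.0),
    ],
    load_sample_fn=None,  # acceptable here since none of these transforms requests additional samples
)
transformed = pipeline.apply_to_sample(sample)  # `sample` is an OBBSample coming from the dataset
```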
diff --git a/src/super_gradients/training/transforms/obb/obb_longest_max_size.py b/src/super_gradients/training/transforms/obb/obb_longest_max_size.py
new file mode 100644
index 0000000000..4e829eb5a6
--- /dev/null
+++ b/src/super_gradients/training/transforms/obb/obb_longest_max_size.py
@@ -0,0 +1,69 @@
+import random
+from typing import List
+
+import cv2
+import numpy as np
+from super_gradients.common.object_names import Processings
+from super_gradients.common.registry import register_transform
+from super_gradients.training.transforms.utils import _rescale_bboxes
+
+from super_gradients.training.samples.obb_sample import OBBSample
+from .abstract_obb_transform import AbstractOBBDetectionTransform
+
+
+@register_transform()
+class OBBDetectionLongestMaxSize(AbstractOBBDetectionTransform):
+    """
+    Resize data sample to guarantee that input image dimensions do not exceed maximum width & height
+    """
+
+    def __init__(self, max_height: int, max_width: int, interpolation: int = cv2.INTER_LINEAR, prob: float = 1.0):
+        """
+
+        :param max_height:    (int) Maximum image height
+        :param max_width:     (int) Maximum image width
+        :param interpolation: Interpolation method to use when resizing the image
+        :param prob:          Probability of applying this transform. Default: 1.0
+        """
+        super().__init__()
+        self.max_height = int(max_height)
+        self.max_width = int(max_width)
+        self.interpolation = int(interpolation)
+        self.prob = float(prob)
+
+    def apply_to_sample(self, sample: OBBSample) -> OBBSample:
+        if random.random() < self.prob:
+            height, width = sample.image.shape[:2]
+            scale = min(self.max_height / height, self.max_width / width)
+
+            sample = OBBSample(
+                image=self.apply_to_image(sample.image, scale, self.interpolation),
+                rboxes_cxcywhr=self.apply_to_bboxes(sample.rboxes_cxcywhr, scale),
+                labels=sample.labels,
+                is_crowd=sample.is_crowd,
+                additional_samples=None,
+            )
+
+            if sample.image.shape[0] != self.max_height and sample.image.shape[1] != self.max_width:
+                raise RuntimeError(f"Image shape is not as expected (scale={scale}, input_shape={height, width}, resized_shape={sample.image.shape[:2]})")
+
+            if sample.image.shape[0] > self.max_height or sample.image.shape[1] > self.max_width:
+                raise RuntimeError(f"Image shape is not as expected (scale={scale}, input_shape={height, width}, resized_shape={sample.image.shape[:2]})")
+
+        return sample
+
+    @classmethod
+    def apply_to_image(cls, image: np.ndarray, scale: float, interpolation: int) -> np.ndarray:
+        height, width = image.shape[:2]
+
+        if scale != 1.0:
+            new_height, new_width = tuple(int(dim * scale + 0.5) for dim in (height, width))
+            image = cv2.resize(image, dsize=(new_width, new_height), interpolation=interpolation)
+        return image
+
+    @classmethod
+    def apply_to_bboxes(cls, bboxes: np.ndarray, scale: float) -> np.ndarray:
+        return _rescale_bboxes(bboxes, (scale, scale))
+
+    def get_equivalent_preprocessing(self) -> List:
+        return [{Processings.OBBDetectionLongestMaxSizeRescale: {"output_shape": (self.max_height, self.max_width)}}]
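For intuition, the rescale a 1024x768 (WxH) input receives with 640x640 limits:

```python
# height, width = 768, 1024 and max_height = max_width = 640:
scale = min(640 / 768, 640 / 1024)  # = 0.625, driven by the longest side
new_h = int(768 * scale + 0.5)      # 480
new_w = int(1024 * scale + 0.5)     # 640 - the longest side hits the limit exactly
```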
diff --git a/src/super_gradients/training/transforms/obb/obb_mixup.py b/src/super_gradients/training/transforms/obb/obb_mixup.py
new file mode 100644
index 0000000000..c20d15cbcd
--- /dev/null
+++ b/src/super_gradients/training/transforms/obb/obb_mixup.py
@@ -0,0 +1,91 @@
+import random
+
+import numpy as np
+from super_gradients.common.registry import register_transform
+
+from .abstract_obb_transform import AbstractOBBDetectionTransform
+from super_gradients.training.samples.obb_sample import OBBSample
+
+
+@register_transform()
+class OBBDetectionMixup(AbstractOBBDetectionTransform):
+    """
+    Apply mixup augmentation and combine two samples into one.
+    Images are averaged with equal weights. Targets are concatenated without any changes.
+    This transform requires both samples to have the same image size. The easiest way to achieve this is to use resize + padding before this transform:
+
+    NOTE: For efficiency, the decision whether to apply the transformation is made (per call) in `get_number_of_additional_samples`
+
+    ```yaml
+    # This will apply OBBDetectionLongestMaxSize and OBBDetectionPadIfNeeded to two samples individually
+    # and then apply OBBDetectionMixup to get a single sample.
+    train_dataset_params:
+      transforms:
+        - OBBDetectionLongestMaxSize:
+            max_height: ${dataset_params.image_size}
+            max_width: ${dataset_params.image_size}
+
+        - OBBDetectionPadIfNeeded:
+            min_height: ${dataset_params.image_size}
+            min_width: ${dataset_params.image_size}
+            pad_value: 127
+            padding_mode: center
+
+        - OBBDetectionMixup:
+            prob: 0.5
+    ```
+
+    :param prob: Probability to apply the transform.
+    """
+
+    def __init__(self, prob: float):
+        """
+
+        :param prob: Probability to apply the transform.
+        """
+        super().__init__()
+        self.prob = prob
+
+    def get_number_of_additional_samples(self) -> int:
+        do_mixup = random.random() < self.prob
+        return int(do_mixup)
+
+    @property
+    def may_require_additional_samples(self) -> bool:
+        return True
+
+    def apply_to_sample(self, sample: OBBSample) -> OBBSample:
+        """
+        Apply the transform to a single sample.
+
+        :param sample: An input sample. It should have one additional sample in `additional_samples` field.
+        :return: A new OBB detection sample that represents the mixup sample.
+        """
+        if sample.additional_samples is not None and len(sample.additional_samples) > 0:
+            other = sample.additional_samples[0]
+            if sample.image.shape != other.image.shape:
+                raise RuntimeError(
+                    f"OBBDetectionMixup requires both samples to have the same image shape. "
+                    f"Got {sample.image.shape} and {other.image.shape}. "
+                    f"Use OBBDetectionLongestMaxSize and OBBDetectionPadIfNeeded to resize and pad images before this transform."
+                )
+            sample = self._apply_mixup(sample, other)
+        return sample
+
+    def _apply_mixup(self, sample: OBBSample, other: OBBSample) -> OBBSample:
+        """
+        Apply mixup augmentation to a single sample.
+
+        :param sample: First sample.
+        :param other:  Second sample.
+        :return:       Mixup sample.
+        """
+        image = (sample.image * 0.5 + other.image * 0.5).astype(sample.image.dtype)
+        rboxes_cxcywhr = np.concatenate([sample.rboxes_cxcywhr, other.rboxes_cxcywhr], axis=0)
+        labels = np.concatenate([sample.labels, other.labels], axis=0)
+        is_crowd = np.concatenate([sample.is_crowd, other.is_crowd], axis=0)
+
+        return OBBSample(image=image, rboxes_cxcywhr=rboxes_cxcywhr, labels=labels, is_crowd=is_crowd, additional_samples=None)
+
+    def get_equivalent_preprocessing(self):
+        raise RuntimeError(f"{self.__class__} does not have equivalent preprocessing because it is non-deterministic.")
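Inside OBBDetectionCompose the extra sample is fetched via `load_sample_fn`; outside of it, the transform can be exercised manually by filling `additional_samples` yourself (a sketch, assuming both samples were already resized/padded to the same shape):

```python
mixup = OBBDetectionMixup(prob=1.0)  # prob=1.0, so one additional sample is always requested
sample.additional_samples = [other_sample]
mixed = mixup.apply_to_sample(sample)  # averaged image, concatenated boxes/labels/is_crowd
```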
+ """ + super().__init__() + self.prob = prob + + def get_number_of_additional_samples(self) -> int: + do_mixup = random.random() < self.prob + return int(do_mixup) + + @property + def may_require_additional_samples(self) -> bool: + return True + + def apply_to_sample(self, sample: OBBSample) -> OBBSample: + """ + Apply the transform to a single sample. + + :param sample: An input sample. It should have one additional sample in `additional_samples` field. + :return: A new pose estimation sample that represents the mixup sample. + """ + if sample.additional_samples is not None and len(sample.additional_samples) > 0: + other = sample.additional_samples[0] + if sample.image.shape != other.image.shape: + raise RuntimeError( + f"OBBDetectionMixup requires both samples to have the same image shape. " + f"Got {sample.image.shape} and {other.image.shape}. " + f"Use OBBDetectionLongestMaxSize and OBBDetectionPadIfNeeded to resize and pad images before this transform." + ) + sample = self._apply_mixup(sample, other) + return sample + + def _apply_mixup(self, sample: OBBSample, other: OBBSample) -> OBBSample: + """ + Apply mixup augmentation to a single sample. + :param sample: First sample. + :param other: Second sample. + :return: Mixup sample. + """ + image = (sample.image * 0.5 + other.image * 0.5).astype(sample.image.dtype) + rboxes_cxcywhr = np.concatenate([sample.rboxes_cxcywhr, other.rboxes_cxcywhr], axis=0) + labels = np.concatenate([sample.labels, other.labels], axis=0) + is_crowd = np.concatenate([sample.is_crowd, other.is_crowd], axis=0) + + return OBBSample(image=image, rboxes_cxcywhr=rboxes_cxcywhr, labels=labels, is_crowd=is_crowd, additional_samples=None) + + def get_equivalent_preprocessing(self): + raise RuntimeError(f"{self.__class__} does not have equivalent preprocessing because it is non-deterministic.") diff --git a/src/super_gradients/training/transforms/obb/obb_pad_if_needed.py b/src/super_gradients/training/transforms/obb/obb_pad_if_needed.py new file mode 100644 index 0000000000..b4e2748f1e --- /dev/null +++ b/src/super_gradients/training/transforms/obb/obb_pad_if_needed.py @@ -0,0 +1,69 @@ +from typing import List + +from super_gradients.common.object_names import Processings +from super_gradients.common.registry.registry import register_transform +from super_gradients.training.samples.obb_sample import OBBSample +from super_gradients.training.transforms.utils import _pad_image, PaddingCoordinates, _shift_bboxes_cxcywhr + +from .abstract_obb_transform import AbstractOBBDetectionTransform + + +@register_transform() +class OBBDetectionPadIfNeeded(AbstractOBBDetectionTransform): + """ + Pad image and targets to ensure that resulting image size is not less than (min_width, min_height). + """ + + def __init__(self, min_height: int, min_width: int, pad_value: int, padding_mode: str = "bottom_right"): + """ + :param min_height: Minimal height of the image. + :param min_width: Minimal width of the image. + :param pad_value: Padding value of image + :param padding_mode: Padding mode. Supported modes: 'bottom_right', 'center'. + """ + if padding_mode not in ("bottom_right", "center"): + raise ValueError(f"Unknown padding mode: {padding_mode}. Supported modes: 'bottom_right', 'center'") + super().__init__() + self.min_height = min_height + self.min_width = min_width + self.image_pad_value = pad_value + self.padding_mode = padding_mode + + def apply_to_sample(self, sample: OBBSample) -> OBBSample: + """ + Apply transform to a single sample. + :param sample: Input detection sample. 
diff --git a/src/super_gradients/training/transforms/obb/obb_remove_small_objects.py b/src/super_gradients/training/transforms/obb/obb_remove_small_objects.py
new file mode 100644
index 0000000000..cd0334360e
--- /dev/null
+++ b/src/super_gradients/training/transforms/obb/obb_remove_small_objects.py
@@ -0,0 +1,45 @@
+from typing import List
+
+import numpy as np
+from super_gradients.common.registry import register_transform
+
+from .abstract_obb_transform import AbstractOBBDetectionTransform
+from super_gradients.training.samples.obb_sample import OBBSample
+
+
+@register_transform()
+class OBBRemoveSmallObjects(AbstractOBBDetectionTransform):
+    """
+    Remove instances from the data sample whose oriented boxes are too small in size or area.
+    """
+
+    def __init__(self, min_size: int, min_area: int):
+        """
+        :param min_size: Minimum size (width or height) of oriented box to keep in the sample
+        :param min_area: Minimum area of oriented box to keep in the sample
+        """
+        super().__init__()
+        self.min_size = min_size
+        self.min_area = min_area
+
+    def apply_to_sample(self, sample: OBBSample) -> OBBSample:
+        """
+        Apply transformation to given OBB detection sample.
+
+        :param sample: Input sample to transform.
+        :return: Filtered sample.
+        """
+        mask = np.ones(len(sample), dtype=bool)
+        if self.min_size:
+            min_size_mask = sample.rboxes_cxcywhr[:, 2:4].min(axis=1) >= self.min_size
+            mask &= min_size_mask
+        if self.min_area:
+            min_area_mask = sample.rboxes_cxcywhr[:, 2] * sample.rboxes_cxcywhr[:, 3] >= self.min_area
+            mask &= min_area_mask
+        return sample.filter_by_mask(mask)
+
+    def __repr__(self):
+        return self.__class__.__name__ + f"(min_size={self.min_size}, min_area={self.min_area})"
+
+    def get_equivalent_preprocessing(self) -> List:
+        return []
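Usage sketch (thresholds are arbitrary):

```python
remove_small = OBBRemoveSmallObjects(min_size=4, min_area=32)
sample = remove_small.apply_to_sample(sample)  # keeps boxes with min(w, h) >= 4 and w * h >= 32
```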
+ """ + mask = np.ones(len(sample), dtype=bool) + if self.min_size: + min_size_mask = sample.rboxes_cxcywhr[:, 2:4].min(axis=1) >= self.min_size + mask &= min_size_mask + if self.min_area: + min_area_mask = sample.rboxes_cxcywhr[:, 2] * sample.rboxes_cxcywhr[:, 3] >= self.min_area + mask &= min_area_mask + return sample.filter_by_mask(mask) + + def __repr__(self): + return self.__class__.__name__ + (f"(min_size={self.min_size}, " f"min_area={self.min_area})") + + def get_equivalent_preprocessing(self) -> List: + return [] diff --git a/src/super_gradients/training/transforms/obb/obb_standardize.py b/src/super_gradients/training/transforms/obb/obb_standardize.py new file mode 100644 index 0000000000..775e5ac81f --- /dev/null +++ b/src/super_gradients/training/transforms/obb/obb_standardize.py @@ -0,0 +1,31 @@ +from typing import List, Dict + +import numpy as np +from super_gradients.common.object_names import Processings +from super_gradients.common.registry import register_transform +from super_gradients.training.samples.obb_sample import OBBSample +from .abstract_obb_transform import AbstractOBBDetectionTransform + + +@register_transform() +class OBBDetectionStandardize(AbstractOBBDetectionTransform): + """ + Standardize image pixel values with img/max_val + + :param max_val: Current maximum value of the image pixels. (usually 255) + """ + + def __init__(self, max_value: float = 255.0): + super().__init__() + self.max_value = float(max_value) + + @classmethod + def apply_to_image(self, image: np.ndarray, max_value: float) -> np.ndarray: + return (image / max_value).astype(np.float32) + + def apply_to_sample(self, sample: OBBSample) -> OBBSample: + sample.image = self.apply_to_image(sample.image, max_value=self.max_value) + return sample + + def get_equivalent_preprocessing(self) -> List[Dict]: + return [{Processings.StandardizeImage: {"max_value": self.max_value}}] diff --git a/src/super_gradients/training/transforms/pipeline_adaptors.py b/src/super_gradients/training/transforms/pipeline_adaptors.py index b6a68ce46d..3fc1502023 100644 --- a/src/super_gradients/training/transforms/pipeline_adaptors.py +++ b/src/super_gradients/training/transforms/pipeline_adaptors.py @@ -3,9 +3,11 @@ from abc import abstractmethod, ABC import numpy as np from PIL import Image +from super_gradients.training.datasets.data_formats.obb.cxcywhr import cxcywhr_to_poly, poly_to_cxcywhr from super_gradients.training.samples import DetectionSample, SegmentationSample, PoseEstimationSample, DepthEstimationSample from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy, xyxy_to_xywh +from super_gradients.training.transforms.obb import OBBSample class SampleType(Enum): @@ -14,6 +16,7 @@ class SampleType(Enum): POSE_ESTIMATION = "POSE_ESTIMATION" DEPTH_ESTIMATION = "DEPTH_ESTIMATION" IMAGE_ONLY = "IMAGE_ONLY" + OBB_DETECTION = "OBB_DETECTION" class TransformsPipelineAdaptorBase(ABC): @@ -42,6 +45,8 @@ def __init__(self, composed_transforms: Callable): def __call__(self, sample, *args, **kwargs): if isinstance(sample, DetectionSample): self.sample_type = SampleType.DETECTION + elif isinstance(sample, OBBSample): + self.sample_type = SampleType.OBB_DETECTION elif isinstance(sample, SegmentationSample): self.sample_type = SampleType.SEGMENTATION elif isinstance(sample, DepthEstimationSample): @@ -109,6 +114,14 @@ def apply_to_sample(self, sample): def prep_for_transforms(self, sample): if self.sample_type == SampleType.DETECTION: sample = {"image": sample.image, "bboxes": 
diff --git a/src/super_gradients/training/utils/visualization/obb.py b/src/super_gradients/training/utils/visualization/obb.py
new file mode 100644
index 0000000000..c890c5f3ac
--- /dev/null
+++ b/src/super_gradients/training/utils/visualization/obb.py
@@ -0,0 +1,79 @@
+from typing import Optional, Union, List, Tuple
+
+import cv2
+import numpy as np
+from super_gradients.training.datasets.data_formats.obb.cxcywhr import cxcywhr_to_poly
+
+
+class OBBVisualization:
+    @classmethod
+    def draw_obb(
+        cls,
+        image: np.ndarray,
+        rboxes_cxcywhr: np.ndarray,
+        scores: Optional[np.ndarray],
+        labels: np.ndarray,
+        class_names: List[str],
+        class_colors: Union[List[Tuple], np.ndarray],
+        show_labels: bool = True,
+        show_confidence: bool = True,
+        thickness: int = 2,
+        opacity: float = 0.75,
+        label_prefix: str = "",
+    ):
+        """
+        Draw rotated bounding boxes on the image
+
+        :param image:           [H, W, 3] - Image to draw bounding boxes on
+        :param rboxes_cxcywhr:  [N, 5] - List of rotated bounding boxes in format [cx, cy, w, h, r]
+        :param scores:          [N] - List of confidence scores. Can be None, in which case confidence is not shown
+        :param labels:          [N] - List of class indices
+        :param class_names:     [C] - List of class names
+        :param class_colors:    [C, 3] - List of class colors
+        :param show_labels:     Boolean flag that indicates if labels should be shown (Default: True)
+        :param show_confidence: Boolean flag that indicates if confidence should be shown (Default: True)
+        :param thickness:       Thickness of the bounding box
+        :param opacity:         Opacity of the overlay (Default: 0.75)
+        :param label_prefix:    Prefix for the label (Default: "")
+
+        :return: [H, W, 3] - Image with bounding boxes drawn
+        """
+        if len(class_names) != len(class_colors):
+            raise ValueError("Number of class labels and colors should match")
+
+        overlay = image.copy()
+        num_boxes = len(rboxes_cxcywhr)
+
+        font_face = cv2.FONT_HERSHEY_PLAIN
+        font_scale = 1.0
+
+        show_confidence = show_confidence and scores is not None
+
+        if scores is not None:
+            # Reorder the boxes to start with boxes of the lowest confidence
+            order = np.argsort(scores)
+            rboxes_cxcywhr = rboxes_cxcywhr[order]
+            scores = scores[order]
+            labels = labels[order]
+
+        polygons = cxcywhr_to_poly(rboxes_cxcywhr)
+
+        for i in range(num_boxes):
+            box = polygons[i]
+            class_index = int(labels[i])
+            color = tuple(class_colors[class_index])
+            cv2.polylines(overlay, box[None, :, :].astype(int), True, color, thickness=thickness, lineType=cv2.LINE_AA)
+
+            if show_labels:
+                class_label = class_names[class_index]
+                label_title = f"{label_prefix}{class_label}"
+                if show_confidence:
+                    conf = scores[i]
+                    label_title = f"{label_title} {conf:.2f}"
+
+                text_size, baseline = cv2.getTextSize(label_title, font_face, font_scale, thickness)
+                # Place the text origin near the top-right corner of the box, using the top-right corner of `box`
+                org = (int(box[1][0]), int(box[1][1] - text_size[1]))
+                cv2.putText(overlay, label_title, org=org, fontFace=font_face, fontScale=font_scale, color=color, lineType=cv2.LINE_AA)
+
+        return cv2.addWeighted(overlay, opacity, image, 1 - opacity, 0)
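A usage sketch, assuming a `predictions` array laid out as [cx, cy, w, h, r, score, class]; that layout and the class list are assumptions for illustration, not an API of this PR:

```python
import numpy as np
from super_gradients.training.utils.visualization.obb import OBBVisualization

annotated = OBBVisualization.draw_obb(
    image=image,  # [H, W, 3] uint8 image
    rboxes_cxcywhr=predictions[:, 0:5],
    scores=predictions[:, 5],
    labels=predictions[:, 6].astype(int),
    class_names=["plane", "ship"],
    class_colors=[(255, 0, 0), (0, 255, 0)],
)
```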
diff --git a/tests/integration_tests/albumentations_test.py b/tests/integration_tests/albumentations_test.py
index b0c7a13f06..052773615d 100644
--- a/tests/integration_tests/albumentations_test.py
+++ b/tests/integration_tests/albumentations_test.py
@@ -10,6 +10,8 @@
 from albumentations import Compose, HorizontalFlip, InvertImg
 
 from super_gradients.training.datasets import Cifar10, Cifar100, ImageNetDataset, COCODetectionDataset, CoCoSegmentationDataSet, COCOPoseEstimationDataset
+from super_gradients.training.samples import OBBSample
+from super_gradients.training.transforms.pipeline_adaptors import AlbumentationsAdaptor
 from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
 from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy
 from super_gradients.training.datasets.depth_estimation_datasets import NYUv2DepthEstimationDataset
@@ -338,6 +340,27 @@ def test_coco_pose_albumentations_intergration(self):
 
         _ = next(iter(unsupported_ds))
 
+    def test_obb_support_albumentations(self):
+        import albumentations as A
+
+        adaptor = AlbumentationsAdaptor(
+            composed_transforms=A.Compose(
+                transforms=[A.ShiftScaleRotate(p=1), A.RandomBrightness(p=1), A.Transpose(p=1)], keypoint_params=A.KeypointParams(format="xy")
+            )
+        )
+
+        sample = OBBSample(
+            image=np.ones((256, 256, 3), dtype=np.uint8),
+            rboxes_cxcywhr=np.array([[128, 128, 100, 50, 0]]),
+            labels=np.array([1]),
+            is_crowd=np.array([0]),
+            additional_samples=None,
+        )
+        sample = adaptor.apply_to_sample(sample)
+
+        self.assertEqual(sample.image.shape, (256, 256, 3))
+        self.assertEqual(sample.rboxes_cxcywhr.shape, (1, 5))
+        self.assertEqual(sample.labels.shape, (1,))
 
 
 if __name__ == "__main__":
     unittest.main()