Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 1488 yolo nas r integration sample and transforms #1995

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
271f948
OBB loss
BloodAxe Apr 19, 2024
602a033
DOTA dataset
BloodAxe Apr 23, 2024
fd94586
Adding loss, postprocessing & visualization
BloodAxe Apr 23, 2024
7e25364
Visualization callback
BloodAxe Apr 23, 2024
d30790a
Visualization callback
BloodAxe Apr 23, 2024
b1ce298
YoloNAS-R
BloodAxe Apr 24, 2024
fc1cc93
YoloNAS-R
BloodAxe Apr 25, 2024
77ae38e
YoloNAS-R
BloodAxe Apr 25, 2024
13db92a
optimized_rboxes_nms
BloodAxe Apr 25, 2024
cf07405
optimized_rboxes_nms
BloodAxe Apr 25, 2024
5500717
optimized_rboxes_nms
BloodAxe Apr 25, 2024
8fda825
optimized_rboxes_nms
BloodAxe Apr 26, 2024
ed9f98f
optimized_rboxes_nms
BloodAxe Apr 26, 2024
d425476
Rename variables for better clarity
BloodAxe Apr 26, 2024
14a3501
Optimize loss weights
BloodAxe Apr 26, 2024
118271a
Prepare data script
BloodAxe Apr 26, 2024
2df4172
Prepare data script
BloodAxe Apr 26, 2024
93092b3
Reduce batch size & topk
BloodAxe Apr 27, 2024
88e3fe5
Remove unused arg
BloodAxe Apr 27, 2024
79c86d9
Remove unused arg
BloodAxe Apr 27, 2024
693121f
Remove unused arg
BloodAxe Apr 27, 2024
08c1b1f
Remove unused arg
BloodAxe Apr 27, 2024
c9de816
Increase topk
BloodAxe Apr 27, 2024
a758278
Increase topk
BloodAxe Apr 27, 2024
6bbba2f
Increase topk
BloodAxe Apr 27, 2024
ae4aa2f
Increase topk
BloodAxe Apr 27, 2024
5723734
Increase topk
BloodAxe Apr 27, 2024
52896bf
Increase topk
BloodAxe Apr 27, 2024
5f55552
Increase topk
BloodAxe Apr 27, 2024
e8ca4f6
Increase topk
BloodAxe Apr 27, 2024
2076952
Increase topk
BloodAxe Apr 27, 2024
62fa15e
Increase topk
BloodAxe Apr 27, 2024
02ef836
Increase topk
BloodAxe Apr 27, 2024
5c7dc87
Increase topk
BloodAxe Apr 27, 2024
1cc9eb5
Increase topk
BloodAxe Apr 27, 2024
b28f938
Increase topk
BloodAxe Apr 27, 2024
ca99a2e
RAdam
BloodAxe Apr 27, 2024
d68eab6
RAdam
BloodAxe Apr 27, 2024
eae4dbd
RAdam
BloodAxe Apr 27, 2024
ac0a0aa
CIou=False
BloodAxe Apr 27, 2024
aed9756
Disable check_points_inside_rboxes
BloodAxe Apr 27, 2024
9ad5455
Disable check_points_inside_rboxes
BloodAxe Apr 27, 2024
b51cd5f
Disable check_points_inside_rboxes
BloodAxe Apr 27, 2024
9f682c7
Tune weights
BloodAxe Apr 27, 2024
563a180
set_anomaly_enabled
BloodAxe Apr 27, 2024
5875f47
set_anomaly_enabled
BloodAxe Apr 27, 2024
75ce8d8
Increase eps
BloodAxe Apr 27, 2024
57df5e2
Multiply by scalar
BloodAxe Apr 27, 2024
e664135
Multiply by scalar
BloodAxe Apr 27, 2024
b938d4b
average_losses_in_ddp
BloodAxe Apr 28, 2024
21a4e53
yolo_nas_r_balanced
BloodAxe Apr 28, 2024
e42f586
Limit angle
BloodAxe Apr 28, 2024
ddab40d
Optimize image slicing
BloodAxe Apr 29, 2024
5c54b77
Do not reduce losses, reduce only scores sum
BloodAxe Apr 29, 2024
b07e3be
Do not reduce losses, reduce only scores sum
BloodAxe Apr 29, 2024
b37b735
Do not reduce losses, reduce only scores sum
BloodAxe Apr 29, 2024
48ae752
Clip grad
BloodAxe Apr 29, 2024
c2e972a
Disable L1 components
BloodAxe Apr 30, 2024
6937d29
Disable L1 components
BloodAxe Apr 30, 2024
1a0aaf0
Disable L1 components
BloodAxe Apr 30, 2024
e1c6687
Disable L1 components
BloodAxe Apr 30, 2024
1b7dfef
predict() support for YoloNAS-R
BloodAxe Apr 30, 2024
7fb7980
Comment saving of visualization
BloodAxe Apr 30, 2024
6a4daa2
yolo_nas_r_tzag
BloodAxe Apr 30, 2024
7a84b5b
Remove parameter
BloodAxe Apr 30, 2024
17f63d4
Fixed newline
BloodAxe May 1, 2024
d7c27a0
Merge remote-tracking branch 'origin/feature/SG-1448-OBB' into featur…
BloodAxe May 1, 2024
3670679
Increase min confidence
BloodAxe May 1, 2024
07b50d4
Added missing ReverseImageChannels
BloodAxe May 1, 2024
de6ad74
yolo_nas_r_tzag_balanced
BloodAxe May 1, 2024
23f2510
Merge remote-tracking branch 'origin/feature/SG-1448-OBB' into featur…
BloodAxe May 1, 2024
d1360c7
Added OBB transforms
BloodAxe May 1, 2024
a01a1ff
Added OBB transforms
BloodAxe May 1, 2024
503d586
get_dataset_preprocessing_params
BloodAxe May 1, 2024
dca0577
Added removal of small boxes during training
BloodAxe May 2, 2024
d86481f
Added removal of small boxes during training
BloodAxe May 2, 2024
d79ebc8
Added removal of invalid boxes
BloodAxe May 2, 2024
844e790
Added augmentations
BloodAxe May 2, 2024
5be4fe0
Increase bs
BloodAxe May 2, 2024
91bc4aa
Ensure that poly_to_cxcywhr always return boxes with w > h
BloodAxe May 2, 2024
7fa35df
Enabled anomaly detection
BloodAxe May 2, 2024
2602027
Added eps
BloodAxe May 2, 2024
38f6bbd
Undo anomaly
BloodAxe May 2, 2024
9c45e20
dota_yolo_nas_r_balanced_no_mixup
BloodAxe May 3, 2024
f9e01c9
Added logging of non-finite IoU results
BloodAxe May 4, 2024
68452b9
Merge remote-tracking branch 'origin/feature/SG-1448-OBB' into featur…
BloodAxe May 4, 2024
723db01
Added sanitize sample call after applying albumentations transforms
BloodAxe May 4, 2024
f4b427c
Replaced atan2() with atan() in CIoU loss because gt boxes can be zeros
BloodAxe May 4, 2024
015f5d7
MOAR Augs
BloodAxe May 5, 2024
0ab3347
MOAR Augs
BloodAxe May 5, 2024
3adac8f
RAdam -> AdamW
BloodAxe May 6, 2024
b05d07c
dota_yolo_nas_r_balanced_pretrain
BloodAxe May 6, 2024
eaf7fee
Tune augs
BloodAxe May 6, 2024
222e7c0
Fixed issue of logging wrong config
BloodAxe May 7, 2024
ebea29f
Revert file
BloodAxe May 7, 2024
67c2008
Cleanup
BloodAxe May 7, 2024
1d252a2
Added docs
BloodAxe May 7, 2024
754c02b
Remove non-existing params
BloodAxe May 8, 2024
874642f
YoloNAS-R M&L variants
BloodAxe May 9, 2024
aa5350d
Merge remote-tracking branch 'origin/feature/SG-1448-OBB' into featur…
BloodAxe May 9, 2024
8adb0a6
Export support for YoloNAS R
BloodAxe May 10, 2024
b91551f
Temporary comment matrix nms
BloodAxe May 10, 2024
2c8fafe
PrefetchIterator
BloodAxe May 10, 2024
e1855a2
Remove gs import
BloodAxe May 10, 2024
ab986d5
dota_yolo_nas_r_s
BloodAxe May 11, 2024
43deb07
Increase BS
BloodAxe May 11, 2024
fed740b
Enable anomaly detection
BloodAxe May 11, 2024
c82cba6
Rewrite t3
BloodAxe May 11, 2024
5d6063d
Rewrite t3
BloodAxe May 11, 2024
c55c698
Rewrite t3
BloodAxe May 11, 2024
2b3a705
dota_yolo_nas_r_s_1_gpu
BloodAxe May 11, 2024
9add044
M & L
BloodAxe May 11, 2024
0e10b1a
Increase numerical stability
BloodAxe May 11, 2024
134e008
Remove prefetch
BloodAxe May 11, 2024
81e05cc
Increase BS for L
BloodAxe May 11, 2024
99a676c
Increase BS for L
BloodAxe May 11, 2024
28cc08b
Disable fp16
BloodAxe May 11, 2024
7edf5c0
RM anomaly
BloodAxe May 11, 2024
72afa29
RM anomaly
BloodAxe May 11, 2024
f963043
fp32
BloodAxe May 11, 2024
1623579
Increase batch size
BloodAxe May 11, 2024
4db50c6
Increase batch size
BloodAxe May 11, 2024
85e23ff
Increase batch size
BloodAxe May 11, 2024
4a30fdc
dota_yolo_nas_r_l_fp32
BloodAxe May 12, 2024
b2961b8
Switch to use of matrix nms in post-prediction callback and exact pol…
BloodAxe May 13, 2024
d73224e
Switch to use of matrix nms in post-prediction callback and exact pol…
BloodAxe May 13, 2024
d0d8975
Added docs for rboxes_matrix_nms
BloodAxe May 14, 2024
395097e
Update script
BloodAxe May 14, 2024
51599a9
Improve auto-generated submission name
BloodAxe May 14, 2024
f6ecc5f
Added positional args
BloodAxe May 14, 2024
6985113
Added positional args
BloodAxe May 14, 2024
4999f83
dota_yolo_nas_r_l_fp32_mixup
BloodAxe May 14, 2024
6c5d912
dota_yolo_nas_r_l_fp32_mixup
BloodAxe May 14, 2024
fd68c7a
dota_yolo_nas_r_l_fp32_mixup
BloodAxe May 14, 2024
8d7db99
dota_yolo_nas_r_l_fp32_mixup
BloodAxe May 14, 2024
98a9fc2
Update recipes
BloodAxe May 16, 2024
3ff845e
Update recipes
BloodAxe May 16, 2024
37665a2
Added license support for YoloNAS-R
BloodAxe May 16, 2024
d81d939
Merge branch 'refs/heads/master' into feature/SG-1448-OBB
BloodAxe May 16, 2024
9783861
Improve docstrings
BloodAxe May 16, 2024
1f69534
Remove max gradient debugging
BloodAxe May 16, 2024
03c876b
Improve docstrings
BloodAxe May 16, 2024
5729722
Merge branch 'refs/heads/feature/SG-1488-YoloNAS-R-integration' into …
BloodAxe May 17, 2024
2882b1f
Cherry-pick commits related to OBBSample & Transforms
BloodAxe May 17, 2024
abc4e8d
Remove random rotate90
BloodAxe May 20, 2024
874119c
Move OBBSample to samples
BloodAxe May 20, 2024
d4a0e07
Added test
BloodAxe May 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import cv2
import numpy as np


def cxcywhr_to_poly(boxes: np.ndarray) -> np.ndarray:
"""
Convert oriented bounding boxes in CX-CY-W-H-R format to a polygon format
:param boxes: [N,...,5] oriented bounding boxes in CX-CY-W-H-R format
:return: [N,...,4, 2] oriented bounding boxes in polygon format
"""
shape = boxes.shape
if shape[-1] != 5:
raise ValueError(f"Expected last dimension to be 5, got {shape[-1]}")

flat_rboxes = boxes.reshape(-1, 5).astype(np.float32)
polys = np.zeros((flat_rboxes.shape[0], 4, 2), dtype=np.float32)
for i, box in enumerate(flat_rboxes):
cx, cy, w, h, r = box
rect = ((cx, cy), (w, h), np.rad2deg(r))
poly = cv2.boxPoints(rect)
polys[i] = poly

return polys.reshape(*shape[:-1], 4, 2)


def poly_to_cxcywhr(poly: np.ndarray) -> np.ndarray:
shape = poly.shape
if shape[-2:] != (4, 2):
raise ValueError(f"Expected last two dimensions to be (4, 2), got {shape[-2:]}")

flat_polys = poly.reshape(-1, 4, 2)
rboxes = np.zeros((flat_polys.shape[0], 5), dtype=np.float32)
for i, poly in enumerate(flat_polys):
hull = cv2.convexHull(np.reshape(poly, [-1, 2]).astype(np.float32))
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
rect = cv2.minAreaRect(hull)
cx, cy = rect[0]
w, h = rect[1]
angle = rect[2]
if h > w:
w, h = h, w
angle += 90
angle = np.deg2rad(angle)
rboxes[i] = [cx, cy, w, h, angle]

return rboxes.reshape(*shape[:-2], 5)


def poly_to_xyxy(poly: np.ndarray) -> np.ndarray:
"""
Convert oriented bounding boxes in polygon format to XYXY format
:param poly: [..., 4, 2]
:return: [..., 4]
"""
x1 = poly[..., :, 0].min(axis=-1)
y1 = poly[..., :, 1].min(axis=-1)
x2 = poly[..., :, 0].max(axis=-1)
y2 = poly[..., :, 1].max(axis=-1)
return np.stack([x1, y1, x2, y2], axis=-1)
3 changes: 2 additions & 1 deletion src/super_gradients/training/samples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .pose_estimation_sample import PoseEstimationSample
from .detection_sample import DetectionSample
from .segmentation_sample import SegmentationSample
from .obb_sample import OBBSample

__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample"]
__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample", "OBBSample"]
110 changes: 110 additions & 0 deletions src/super_gradients/training/samples/obb_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import dataclasses
from typing import Union, List, Optional

import numpy as np
import torch
from super_gradients.training.datasets.data_formats.obb.cxcywhr import cxcywhr_to_poly, poly_to_cxcywhr


@dataclasses.dataclass
class OBBSample:
"""
A data class describing a single object detection sample that comes from a dataset.
It contains both input image and target information to train an object detection model.

:param image: Associated image with a sample. Can be in [H,W,C] or [C,H,W] format
:param rboxes_cxcywhr: Numpy array of [N,5] shape with oriented bounding box of each instance (CX,CY,W,H,R)
:param labels: Numpy array of [N] shape with class label for each instance
:param is_crowd: (Optional) Numpy array of [N] shape with is_crowd flag for each instance
:param additional_samples: (Optional) List of additional samples for the same image.
"""

__slots__ = ["image", "rboxes_cxcywhr", "labels", "is_crowd", "additional_samples"]

image: Union[np.ndarray, torch.Tensor]
rboxes_cxcywhr: np.ndarray
labels: np.ndarray
is_crowd: np.ndarray
additional_samples: Optional[List["OBBSample"]]

def __init__(
self,
image: Union[np.ndarray, torch.Tensor],
rboxes_cxcywhr: np.ndarray,
labels: np.ndarray,
is_crowd: Optional[np.ndarray] = None,
additional_samples: Optional[List["OBBSample"]] = None,
):
if is_crowd is None:
is_crowd = np.zeros(len(labels), dtype=bool)

if len(rboxes_cxcywhr) != len(labels):
raise ValueError("Number of bounding boxes and labels must be equal. Got {len(bboxes_xyxy)} and {len(labels)} respectively")

if len(rboxes_cxcywhr) != len(is_crowd):
raise ValueError("Number of bounding boxes and is_crowd flags must be equal. Got {len(bboxes_xyxy)} and {len(is_crowd)} respectively")

if len(rboxes_cxcywhr.shape) != 2 or rboxes_cxcywhr.shape[1] != 5:
raise ValueError(f"Oriented boxes must be in [N,5] format. Shape of input bboxes is {rboxes_cxcywhr.shape}")

if len(is_crowd.shape) != 1:
raise ValueError(f"Number of is_crowd flags must be in [N] format. Shape of input is_crowd is {is_crowd.shape}")

if len(labels.shape) != 1:
raise ValueError("Labels must be in [N] format. Shape of input labels is {labels.shape}")

self.image = image
self.rboxes_cxcywhr = rboxes_cxcywhr
self.labels = labels
self.is_crowd = is_crowd
self.additional_samples = additional_samples

def sanitize_sample(self) -> "OBBSample":
"""
Apply sanity checks on the detection sample, which includes clamping of rotate boxes to image boundaries
and removing boxes with non-positive area.
This method returns a new DetectionSample instance with sanitized data.
:return: A DetectionSample after filtering.
"""
polys = cxcywhr_to_poly(self.rboxes_cxcywhr)
# Clamp polygons to image boundaries
polys[..., 0] = np.clip(polys[..., 0], 0, self.image.shape[1])
polys[..., 1] = np.clip(polys[..., 1], 0, self.image.shape[0])
rboxes_cxcywhr = poly_to_cxcywhr(polys)
return OBBSample(
image=self.image,
rboxes_cxcywhr=rboxes_cxcywhr,
labels=self.labels,
is_crowd=self.is_crowd,
additional_samples=self.additional_samples,
).filter_by_bbox_area(0)

def filter_by_mask(self, mask: np.ndarray) -> "OBBSample":
"""
Remove boxes & labels with respect to a given mask.
This method returns a new DetectionSample instance with filtered data.

:param mask: A boolean or integer mask of samples to keep for given sample.
:return: A DetectionSample after filtering.
"""
return OBBSample(
image=self.image,
rboxes_cxcywhr=self.rboxes_cxcywhr[mask],
labels=self.labels[mask],
is_crowd=self.is_crowd[mask] if self.is_crowd is not None else None,
additional_samples=self.additional_samples,
)

def filter_by_bbox_area(self, min_rbox_area: Union[int, float]) -> "OBBSample":
"""
Remove pose instances that has area of the corresponding bounding box less than a certain threshold.

:param min_rbox_area: Minimal rotated box area of the box to keep.
:return: A OBBSample after filtering.
"""
area = self.rboxes_cxcywhr[..., 2:4].prod(axis=-1)
keep_mask = area > min_rbox_area
return self.filter_by_mask(keep_mask)

def __len__(self):
return len(self.labels)
17 changes: 17 additions & 0 deletions src/super_gradients/training/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@
from super_gradients.common.registry.albumentation import ALBUMENTATIONS_TRANSFORMS, ALBUMENTATIONS_COMP_TRANSFORMS, imported_albumentations_failure
from super_gradients.training.transforms.detection import AbstractDetectionTransform, DetectionPadIfNeeded, DetectionLongestMaxSize

from .obb import (
AbstractOBBDetectionTransform,
OBBDetectionPadIfNeeded,
OBBDetectionLongestMaxSize,
OBBDetectionStandardize,
OBBDetectionMixup,
OBBDetectionCompose,
OBBRemoveSmallObjects,
)

__all__ = [
"TRANSFORMS",
"ALBUMENTATIONS_TRANSFORMS",
Expand Down Expand Up @@ -76,6 +86,13 @@
"DetectionPadIfNeeded",
"DetectionLongestMaxSize",
"AbstractDetectionTransform",
"AbstractOBBDetectionTransform",
"OBBDetectionPadIfNeeded",
"OBBDetectionLongestMaxSize",
"OBBDetectionStandardize",
"OBBDetectionMixup",
"OBBDetectionCompose",
"OBBRemoveSmallObjects",
]

cv2.setNumThreads(0)
19 changes: 19 additions & 0 deletions src/super_gradients/training/transforms/obb/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from super_gradients.training.samples.obb_sample import OBBSample
from .abstract_obb_transform import AbstractOBBDetectionTransform
from .obb_pad_if_needed import OBBDetectionPadIfNeeded
from .obb_longest_max_size import OBBDetectionLongestMaxSize
from .obb_standardize import OBBDetectionStandardize
from .obb_mixup import OBBDetectionMixup
from .obb_compose import OBBDetectionCompose
from .obb_remove_small_objects import OBBRemoveSmallObjects

__all__ = [
"OBBSample",
"AbstractOBBDetectionTransform",
"OBBDetectionPadIfNeeded",
"OBBDetectionLongestMaxSize",
"OBBDetectionStandardize",
"OBBDetectionMixup",
"OBBDetectionCompose",
"OBBRemoveSmallObjects",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import abc
import warnings
from abc import abstractmethod
from typing import List

from super_gradients.training.samples.obb_sample import OBBSample

__all__ = ["AbstractOBBDetectionTransform"]


class AbstractOBBDetectionTransform(abc.ABC):
"""
Base class for all transforms for object detection sample augmentation.
"""

def __init__(self, additional_samples_count: int = 0):
"""
:param additional_samples_count: (int) number of samples that must be extra samples from dataset. Default value is 0.
"""
self._additional_samples_count = additional_samples_count

@abstractmethod
def apply_to_sample(self, sample: OBBSample) -> OBBSample:
"""
Apply transformation to given pose estimation sample.
Important note - function call may return new object, may modify it in-place.
This is implementation dependent and if you need to keep original sample intact it
is recommended to make a copy of it BEFORE passing it to transform.

:param sample: Input sample to transform.
:return: Modified sample (It can be the same instance as input or a new object).
"""
raise NotImplementedError

@property
def additional_samples_count(self) -> int:
warnings.warn(
"This property is deprecated and will be removed in the future." "Please use `get_number_of_additional_samples` instead.", DeprecationWarning
)
return self.get_number_of_additional_samples()

def get_number_of_additional_samples(self) -> int:
"""
Returns number of additional samples required. The default implementation assumes that this number is fixed and deterministic.
Override in case this is not the case, e.g., you randomly choose to apply MixUp, etc
"""
return self._additional_samples_count

@property
def may_require_additional_samples(self) -> bool:
"""
Indicates whether additional samples are required. The default implementation assumes that this indicator is fixed and deterministic.
Override in case this is not the case, e.g., you randomly choose to apply MixUp, etc
"""
return self._additional_samples_count > 0

@abstractmethod
def get_equivalent_preprocessing(self) -> List:
raise NotImplementedError
93 changes: 93 additions & 0 deletions src/super_gradients/training/transforms/obb/obb_compose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import List

from .abstract_obb_transform import AbstractOBBDetectionTransform
from super_gradients.training.samples.obb_sample import OBBSample


class OBBDetectionCompose(AbstractOBBDetectionTransform):
"""
Composes several transforms together
"""

def __init__(self, transforms: List[AbstractOBBDetectionTransform], load_sample_fn=None):
"""

:param transforms: List of keypoint-based transformations
:param load_sample_fn: A method to load additional samples if needed (for mixup & mosaic augmentations).
Default value is None, which would raise an error if additional samples are needed.
"""
for transform in transforms:
if hasattr(transform, "may_require_additional_samples") and transform.may_require_additional_samples and load_sample_fn is None:
raise RuntimeError(f"Transform {transform.__class__.__name__} that requires additional samples but `load_sample_fn` is None")

super().__init__()
self.transforms = transforms
self.load_sample_fn = load_sample_fn

def apply_to_sample(self, sample: OBBSample) -> OBBSample:
"""
Applies the series of transformations to the input sample.
The function may modify the input sample inplace, so input sample should not be used after the call.

:param sample: Input sample
:return: Transformed sample.
"""
sample = sample.sanitize_sample()
sample = self._apply_transforms(sample, transforms=self.transforms, load_sample_fn=self.load_sample_fn)
return sample

@classmethod
def _apply_transforms(cls, sample: OBBSample, transforms: List[AbstractOBBDetectionTransform], load_sample_fn) -> OBBSample:
"""
This helper method allows us to query additional samples for mixup & mosaic augmentations
that would be also passed through augmentation pipeline. Example:

```
transforms:
- OBBDetectionLongestMaxSize:
max_height: ${dataset_params.image_size}
max_width: ${dataset_params.image_size}
- OBBDetectionMixup:
prob: ${dataset_params.mixup_prob}
```

In the example above all samples in mixup will be forwarded through OBBDetectionLongestMaxSize,
and only then mixed up.

:param sample: Input data sample
:param transforms: List of transformations to apply
:param load_sample_fn: A method to load additional samples if needed
:return: A data sample after applying transformations
"""
applied_transforms_so_far = []
for t in transforms:
if not hasattr(t, "may_require_additional_samples") or not t.may_require_additional_samples:
sample = t.apply_to_sample(sample)
applied_transforms_so_far.append(t)
else:
additional_samples = [load_sample_fn() for _ in range(t.get_number_of_additional_samples())]
additional_samples = [
cls._apply_transforms(
sample,
applied_transforms_so_far,
load_sample_fn=load_sample_fn,
)
for sample in additional_samples
]
sample.additional_samples = additional_samples
sample = t.apply_to_sample(sample)

return sample

def get_equivalent_preprocessing(self) -> List:
preprocessing = []
for t in self.transforms:
preprocessing += t.get_equivalent_preprocessing()
return preprocessing

def __repr__(self):
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += f"\t{repr(t)}"
format_string += "\n)"
return format_string
Loading
Loading