Skip to content

Commit b208f7f

Browse files
diaz-esparzasergiusNicolasHug
authored
Implementation of the GaussianNoise transform for uint8 inputs (#9169)
Co-authored-by: sergius <UO293837@uniovi.es> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com> Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
1 parent ce5b26a commit b208f7f

File tree

3 files changed

+76
-17
lines changed

3 files changed

+76
-17
lines changed

test/test_transforms_v2.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3993,7 +3993,7 @@ class TestGaussianNoise:
39933993
"make_input",
39943994
[make_image_tensor, make_image, make_video],
39953995
)
3996-
def test_kernel(self, make_input):
3996+
def test_kernel_float(self, make_input):
39973997
check_kernel(
39983998
F.gaussian_noise,
39993999
make_input(dtype=torch.float32),
@@ -4005,9 +4005,28 @@ def test_kernel(self, make_input):
40054005
"make_input",
40064006
[make_image_tensor, make_image, make_video],
40074007
)
4008-
def test_functional(self, make_input):
4008+
def test_kernel_uint8(self, make_input):
4009+
check_kernel(
4010+
F.gaussian_noise,
4011+
make_input(dtype=torch.uint8),
4012+
# This cannot pass because the noise on a batch in not per-image
4013+
check_batched_vs_unbatched=False,
4014+
)
4015+
4016+
@pytest.mark.parametrize(
4017+
"make_input",
4018+
[make_image_tensor, make_image, make_video],
4019+
)
4020+
def test_functional_float(self, make_input):
40094021
check_functional(F.gaussian_noise, make_input(dtype=torch.float32))
40104022

4023+
@pytest.mark.parametrize(
4024+
"make_input",
4025+
[make_image_tensor, make_image, make_video],
4026+
)
4027+
def test_functional_uint8(self, make_input):
4028+
check_functional(F.gaussian_noise, make_input(dtype=torch.uint8))
4029+
40114030
@pytest.mark.parametrize(
40124031
("kernel", "input_type"),
40134032
[
@@ -4023,10 +4042,11 @@ def test_functional_signature(self, kernel, input_type):
40234042
"make_input",
40244043
[make_image_tensor, make_image, make_video],
40254044
)
4026-
def test_transform(self, make_input):
4045+
def test_transform_float(self, make_input):
40274046
def adapter(_, input, __):
4028-
# This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
4029-
# Same for PIL images
4047+
# We have two different implementations for floats and uint8
4048+
# To test this implementation we'll convert the auto-generated uint8 tensors to float32
4049+
# We don't support other int dtypes nor pil images
40304050
for key, value in input.items():
40314051
if isinstance(value, torch.Tensor) and not value.is_floating_point():
40324052
input[key] = value.to(torch.float32)
@@ -4036,11 +4056,29 @@ def adapter(_, input, __):
40364056

40374057
check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)
40384058

4059+
@pytest.mark.parametrize(
4060+
"make_input",
4061+
[make_image_tensor, make_image, make_video],
4062+
)
4063+
def test_transform_uint8(self, make_input):
4064+
def adapter(_, input, __):
4065+
# We have two different implementations for floats and uint8
4066+
# To test this implementation we'll convert every tensor to uint8
4067+
# We don't support other int dtypes nor pil images
4068+
for key, value in input.items():
4069+
if isinstance(value, torch.Tensor) and not value.dtype != torch.uint8:
4070+
input[key] = value.to(torch.uint8)
4071+
if isinstance(value, PIL.Image.Image):
4072+
input[key] = F.pil_to_tensor(value).to(torch.uint8)
4073+
return input
4074+
4075+
check_transform(transforms.GaussianNoise(), make_input(dtype=torch.uint8), check_sample_input=adapter)
4076+
40394077
def test_bad_input(self):
40404078
with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
40414079
F.gaussian_noise(make_image_pil())
4042-
with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
4043-
F.gaussian_noise(make_image(dtype=torch.uint8))
4080+
with pytest.raises(ValueError, match="Input tensor is expected to be in uint8 or float dtype"):
4081+
F.gaussian_noise(make_image(dtype=torch.int32))
40444082
with pytest.raises(ValueError, match="sigma shouldn't be negative"):
40454083
F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)
40464084

torchvision/transforms/v2/_misc.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,22 @@ class GaussianNoise(Transform):
214214
Each image or frame in a batch will be transformed independently i.e. the
215215
noise added to each image will be different.
216216
217-
The input tensor is also expected to be of float dtype in ``[0, 1]``.
218-
This transform does not support PIL images.
217+
The input tensor is also expected to be of float dtype in ``[0, 1]``,
218+
or of ``uint8`` dtype in ``[0, 255]``. This transform does not support PIL
219+
images.
220+
221+
Regardless of the dtype used, the parameters of the function use the same
222+
scale, so a ``mean`` parameter of 0.5 will result in an average value
223+
increase of 0.5 units for float images, and an average increase of 127.5
224+
units for ``uint8`` images.
219225
220226
Args:
221227
mean (float): Mean of the sampled normal distribution. Default is 0.
222228
sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
223-
clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
229+
clip (bool, optional): Whether to clip the values after adding noise, be it to
230+
``[0, 1]`` for floats or to ``[0, 255]`` for ``uint8``. Setting this parameter to
231+
``False`` may cause unsigned integer overflows with uint8 inputs.
232+
Default is True.
224233
"""
225234

226235
def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip=True) -> None:

torchvision/transforms/v2/functional/_misc.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,16 +195,28 @@ def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, cl
195195
@_register_kernel_internal(gaussian_noise, torch.Tensor)
196196
@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
197197
def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
198-
if not image.is_floating_point():
199-
raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
200198
if sigma < 0:
201199
raise ValueError(f"sigma shouldn't be negative. Got {sigma}")
202200

203-
noise = mean + torch.randn_like(image) * sigma
204-
out = image + noise
205-
if clip:
206-
out = torch.clamp(out, 0, 1)
207-
return out
201+
if image.is_floating_point():
202+
noise = mean + torch.randn_like(image) * sigma
203+
out = image + noise
204+
if clip:
205+
out = torch.clamp(out, 0, 1)
206+
return out
207+
208+
elif image.dtype == torch.uint8:
209+
# Convert to intermediate dtype int16 to add to input more efficiently
210+
# See https://github.com/pytorch/vision/pull/9169 for alternative implementations and benchmark
211+
noise = ((mean * 255) + torch.randn_like(image, dtype=torch.float32) * (sigma * 255)).to(torch.int16)
212+
out = image + noise
213+
214+
if clip:
215+
out = torch.clamp(out, 0, 255)
216+
return out.to(torch.uint8)
217+
218+
else:
219+
raise ValueError(f"Input tensor is expected to be in uint8 or float dtype, got dtype={image.dtype}")
208220

209221

210222
@_register_kernel_internal(gaussian_noise, tv_tensors.Video)

0 commit comments

Comments
 (0)