@@ -3993,7 +3993,7 @@ class TestGaussianNoise:
3993
3993
"make_input" ,
3994
3994
[make_image_tensor , make_image , make_video ],
3995
3995
)
3996
- def test_kernel (self , make_input ):
3996
+ def test_kernel_float (self , make_input ):
3997
3997
check_kernel (
3998
3998
F .gaussian_noise ,
3999
3999
make_input (dtype = torch .float32 ),
@@ -4005,9 +4005,28 @@ def test_kernel(self, make_input):
4005
4005
"make_input" ,
4006
4006
[make_image_tensor , make_image , make_video ],
4007
4007
)
4008
- def test_functional (self , make_input ):
4008
+ def test_kernel_uint8 (self , make_input ):
4009
+ check_kernel (
4010
+ F .gaussian_noise ,
4011
+ make_input (dtype = torch .uint8 ),
4012
+ # This cannot pass because the noise on a batch in not per-image
4013
+ check_batched_vs_unbatched = False ,
4014
+ )
4015
+
4016
+ @pytest .mark .parametrize (
4017
+ "make_input" ,
4018
+ [make_image_tensor , make_image , make_video ],
4019
+ )
4020
+ def test_functional_float (self , make_input ):
4009
4021
check_functional (F .gaussian_noise , make_input (dtype = torch .float32 ))
4010
4022
4023
+ @pytest .mark .parametrize (
4024
+ "make_input" ,
4025
+ [make_image_tensor , make_image , make_video ],
4026
+ )
4027
+ def test_functional_uint8 (self , make_input ):
4028
+ check_functional (F .gaussian_noise , make_input (dtype = torch .uint8 ))
4029
+
4011
4030
@pytest .mark .parametrize (
4012
4031
("kernel" , "input_type" ),
4013
4032
[
@@ -4023,10 +4042,11 @@ def test_functional_signature(self, kernel, input_type):
4023
4042
"make_input" ,
4024
4043
[make_image_tensor , make_image , make_video ],
4025
4044
)
4026
- def test_transform (self , make_input ):
4045
+ def test_transform_float (self , make_input ):
4027
4046
def adapter (_ , input , __ ):
4028
- # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
4029
- # Same for PIL images
4047
+ # We have two different implementations for floats and uint8
4048
+ # To test this implementation we'll convert the auto-generated uint8 tensors to float32
4049
+ # We don't support other int dtypes nor pil images
4030
4050
for key , value in input .items ():
4031
4051
if isinstance (value , torch .Tensor ) and not value .is_floating_point ():
4032
4052
input [key ] = value .to (torch .float32 )
@@ -4036,11 +4056,29 @@ def adapter(_, input, __):
4036
4056
4037
4057
check_transform (transforms .GaussianNoise (), make_input (dtype = torch .float32 ), check_sample_input = adapter )
4038
4058
4059
+ @pytest .mark .parametrize (
4060
+ "make_input" ,
4061
+ [make_image_tensor , make_image , make_video ],
4062
+ )
4063
+ def test_transform_uint8 (self , make_input ):
4064
+ def adapter (_ , input , __ ):
4065
+ # We have two different implementations for floats and uint8
4066
+ # To test this implementation we'll convert every tensor to uint8
4067
+ # We don't support other int dtypes nor pil images
4068
+ for key , value in input .items ():
4069
+ if isinstance (value , torch .Tensor ) and not value .dtype != torch .uint8 :
4070
+ input [key ] = value .to (torch .uint8 )
4071
+ if isinstance (value , PIL .Image .Image ):
4072
+ input [key ] = F .pil_to_tensor (value ).to (torch .uint8 )
4073
+ return input
4074
+
4075
+ check_transform (transforms .GaussianNoise (), make_input (dtype = torch .uint8 ), check_sample_input = adapter )
4076
+
4039
4077
def test_bad_input (self ):
4040
4078
with pytest .raises (ValueError , match = "Gaussian Noise is not implemented for PIL images." ):
4041
4079
F .gaussian_noise (make_image_pil ())
4042
- with pytest .raises (ValueError , match = "Input tensor is expected to be in float dtype" ):
4043
- F .gaussian_noise (make_image (dtype = torch .uint8 ))
4080
+ with pytest .raises (ValueError , match = "Input tensor is expected to be in uint8 or float dtype" ):
4081
+ F .gaussian_noise (make_image (dtype = torch .int32 ))
4044
4082
with pytest .raises (ValueError , match = "sigma shouldn't be negative" ):
4045
4083
F .gaussian_noise (make_image (dtype = torch .float32 ), sigma = - 1 )
4046
4084
0 commit comments