1
+ # Copyright (C) 2018-2024 Intel Corporation
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import torch
6
+ from packaging .version import parse as parse_version
7
+ import pytest
8
+
9
+ from pytorch_layer_test_class import PytorchLayerTest
10
+
11
+
12
+ class TestMaskedSelect (PytorchLayerTest ):
13
+ def _prepare_input (self , mask_select = 'ones' , mask_dtype = bool , input_dtype = float ):
14
+ input_shape = [1 , 10 ]
15
+ mask = np .zeros (input_shape ).astype (mask_dtype )
16
+ if mask_select == 'ones' :
17
+ mask = np .ones (input_shape ).astype (mask_dtype )
18
+ if mask_select == 'random' :
19
+ idx = np .random .choice (10 , 5 )
20
+ mask [:, idx ] = 1
21
+ return (np .random .randn (1 , 10 ).astype (input_dtype ), mask )
22
+
23
+ def create_model (self ):
24
+ import torch
25
+
26
+ class aten_masked_select (torch .nn .Module ):
27
+ def __init__ (self ):
28
+ super (aten_masked_select , self ).__init__ ()
29
+
30
+ def forward (self , x , mask ):
31
+ return x .masked_select (mask )
32
+
33
+ ref_net = None
34
+
35
+ return aten_masked_select (), ref_net , "aten::masked_select"
36
+
37
+ @pytest .mark .parametrize (
38
+ "mask_select" , ['zeros' , 'ones' , 'random' ])
39
+ @pytest .mark .parametrize ("input_dtype" , [np .float32 , np .float64 , int , np .int32 ])
40
+ @pytest .mark .nightly
41
+ @pytest .mark .precommit
42
+ def test_masked_select (self , mask_select , input_dtype , ie_device , precision , ir_version ):
43
+ self ._test (* self .create_model (),
44
+ ie_device , precision , ir_version ,
45
+ dynamic_shapes = False ,
46
+ trace_model = True ,
47
+ kwargs_to_prepare_input = {'mask_select' : mask_select , 'mask_dtype' : bool , "input_dtype" : input_dtype })
48
+
49
+ @pytest .mark .skipif (parse_version (torch .__version__ ) >= parse_version ("2.1.0" ), reason = "pytorch 2.1 and above does not support nonboolean mask" )
50
+ @pytest .mark .parametrize (
51
+ "mask_select" , ['zeros' , 'ones' , 'random' ])
52
+ @pytest .mark .parametrize ("input_dtype" , [np .float32 , np .float64 , int , np .int32 ])
53
+ @pytest .mark .parametrize ("mask_dtype" , [np .uint8 , np .int32 , np .float32 ])
54
+ @pytest .mark .nightly
55
+ @pytest .mark .precommit
56
+ def test_masked_select_non_bool_mask (self , mask_select , mask_dtype , input_dtype , ie_device , precision , ir_version ):
57
+ self ._test (* self .create_model (),
58
+ ie_device , precision , ir_version ,
59
+ dynamic_shapes = False ,
60
+ trace_model = True ,
61
+ kwargs_to_prepare_input = {'mask_select' : mask_select , 'mask_dtype' : mask_dtype , "input_dtype" : input_dtype })
62
+
0 commit comments