Commit 2660f85

[PT FE] Support aten::masked_select for pytorch models (#26162)
### Details:
- support `aten::masked_select` operator

### Tickets:
- [None](#23325)
1 parent b80a179 commit 2660f85
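For reference, `aten::masked_select` returns a new 1-D tensor holding the elements of the input at positions where the boolean mask is true. A minimal PyTorch snippet (illustrative only, not part of this change) showing the operator this commit teaches the frontend to convert:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
mask = x > 3                         # boolean mask with the same shape as x
# masked_select always flattens: the result is 1-D with data-dependent length
print(torch.masked_select(x, mask))  # tensor([4, 5, 6])
```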

File tree: 5 files changed, +100 -0 lines changed
New file
@@ -0,0 +1,23 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_masked_select(const NodeContext& context) {
    // aten::masked_select(Tensor self, Tensor mask) -> Tensor
    num_inputs_check(context, 2, 2);
    auto data = context.get_input(0);
    auto mask = context.get_input(1);
    auto res = masked_select(context, data, mask);
    return {res};
};

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov

src/frontends/pytorch/src/op_table.cpp

+2
@@ -136,6 +136,7 @@ OP_CONVERTER(translate_loop);
 OP_CONVERTER(translate_lstm);
 OP_CONVERTER(translate_masked_fill);
 OP_CONVERTER(translate_masked_scatter);
+OP_CONVERTER(translate_masked_select);
 OP_CONVERTER(translate_max);
 OP_CONVERTER(translate_maximum);
 OP_CONVERTER(translate_max_poolnd);
@@ -528,6 +529,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
     {"aten::masked_fill", op::translate_masked_fill},
     {"aten::masked_scatter", op::translate_masked_scatter},
+    {"aten::masked_select", op::translate_masked_select},
     {"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
     {"aten::max", op::translate_max},
     {"aten::mv", op::translate_1to1_match_2_inputs<opset10::MatMul>},

src/frontends/pytorch/src/utils.cpp

+11
@@ -10,10 +10,13 @@
 #include "openvino/op/add.hpp"
 #include "openvino/op/broadcast.hpp"
 #include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
 #include "openvino/op/convert_promote_types.hpp"
 #include "openvino/op/divide.hpp"
 #include "openvino/op/gather.hpp"
+#include "openvino/op/gather_nd.hpp"
 #include "openvino/op/mod.hpp"
+#include "openvino/op/non_zero.hpp"
 #include "openvino/op/range.hpp"
 #include "openvino/op/reduce_prod.hpp"
 #include "openvino/op/reshape.hpp"
@@ -22,6 +25,7 @@
 #include "openvino/op/slice.hpp"
 #include "openvino/op/squeeze.hpp"
 #include "openvino/op/subtract.hpp"
+#include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/util/log.hpp"
 #include "pt_framework_node.hpp"
@@ -592,6 +596,13 @@ Output<Node> concat_list_from_inputs(const NodeContext& context, size_t begin, s
     return concat;
 }
 
+Output<Node> masked_select(const NodeContext& context, const Output<Node>& data, const Output<Node>& mask) {
+    auto input_order = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
+    auto nonzero = context.mark_node(std::make_shared<v3::NonZero>(mask));
+    auto masked_id = context.mark_node(std::make_shared<v1::Transpose>(nonzero, input_order));
+    return context.mark_node(std::make_shared<v8::GatherND>(data, masked_id));
+}
+
 }  // namespace pytorch
 }  // namespace frontend
 }  // namespace ov

src/frontends/pytorch/src/utils.hpp

+2
@@ -121,6 +121,8 @@ Output<Node> masked_fill(ov::pass::NodeRegistry& rg,
 
 Output<Node> concat_list_from_inputs(const NodeContext& context, size_t begin, size_t end);
 
+Output<Node> masked_select(const NodeContext& context, const Output<Node>& data, const Output<Node>& mask);
+
 namespace op {
 template <OutputVector (*T)(const NodeContext&), size_t idx = 0>
 OutputVector inplace_op(const NodeContext& context) {

New file
@@ -0,0 +1,62 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
from packaging.version import parse as parse_version
import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestMaskedSelect(PytorchLayerTest):
    def _prepare_input(self, mask_select='ones', mask_dtype=bool, input_dtype=float):
        input_shape = [1, 10]
        mask = np.zeros(input_shape).astype(mask_dtype)
        if mask_select == 'ones':
            mask = np.ones(input_shape).astype(mask_dtype)
        if mask_select == 'random':
            idx = np.random.choice(10, 5)
            mask[:, idx] = 1
        return (np.random.randn(1, 10).astype(input_dtype), mask)

    def create_model(self):
        import torch

        class aten_masked_select(torch.nn.Module):
            def __init__(self):
                super(aten_masked_select, self).__init__()

            def forward(self, x, mask):
                return x.masked_select(mask)

        ref_net = None

        return aten_masked_select(), ref_net, "aten::masked_select"

    @pytest.mark.parametrize(
        "mask_select", ['zeros', 'ones', 'random'])
    @pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_masked_select(self, mask_select, input_dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(),
                   ie_device, precision, ir_version,
                   dynamic_shapes=False,
                   trace_model=True,
                   kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': bool, "input_dtype": input_dtype})

    @pytest.mark.skipif(parse_version(torch.__version__) >= parse_version("2.1.0"), reason="pytorch 2.1 and above does not support nonboolean mask")
    @pytest.mark.parametrize(
        "mask_select", ['zeros', 'ones', 'random'])
    @pytest.mark.parametrize("input_dtype", [np.float32, np.float64, int, np.int32])
    @pytest.mark.parametrize("mask_dtype", [np.uint8, np.int32, np.float32])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_masked_select_non_bool_mask(self, mask_select, mask_dtype, input_dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(),
                   ie_device, precision, ir_version,
                   dynamic_shapes=False,
                   trace_model=True,
                   kwargs_to_prepare_input={'mask_select': mask_select, 'mask_dtype': mask_dtype, "input_dtype": input_dtype})
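With the converter registered, a traced model that uses `masked_select` can be converted end to end. A rough usage sketch (assumes an OpenVINO build that includes this commit; not part of the change itself):

```python
import numpy as np
import openvino as ov
import torch

class MaskedSelect(torch.nn.Module):
    def forward(self, x, mask):
        return x.masked_select(mask)

x = torch.randn(1, 10)
mask = torch.rand(1, 10) > 0.5

# Tracing with an example input exercises the new aten::masked_select converter.
ov_model = ov.convert_model(MaskedSelect(), example_input=(x, mask))
compiled = ov.compile_model(ov_model, "CPU")
result = compiled((x.numpy(), mask.numpy()))[0]

np.testing.assert_allclose(result, MaskedSelect()(x, mask).numpy())
```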
