Skip to content

Commit f2182b3

Browse files
vijaykr338mvafin
andauthored
Added support for aten::randperm and aten::polar (#29585)
### Details: added support for aten::randperm and tests currently working on aten::polar and others ### Tickets: - #29547 --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
1 parent 74cc63b commit f2182b3

File tree

5 files changed

+187
-0
lines changed

5 files changed

+187
-0
lines changed
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/frontend/complex_type_mark.hpp"
6+
#include "openvino/op/convert.hpp"
7+
#include "openvino/op/cos.hpp"
8+
#include "openvino/op/multiply.hpp"
9+
#include "openvino/op/sin.hpp"
10+
#include "utils.hpp"
11+
12+
namespace ov {
13+
namespace frontend {
14+
namespace pytorch {
15+
namespace op {
16+
17+
using namespace ov::op;
18+
19+
OutputVector translate_polar(const NodeContext& context) {
20+
num_inputs_check(context, 2, 3);
21+
auto abs = context.get_input(0);
22+
auto angle = context.get_input(1);
23+
auto cos_node = context.mark_node(std::make_shared<v0::Cos>(angle));
24+
auto real = context.mark_node(std::make_shared<v1::Multiply>(abs, cos_node));
25+
auto sin_node = context.mark_node(std::make_shared<v0::Sin>(angle));
26+
auto imag = context.mark_node(std::make_shared<v1::Multiply>(abs, sin_node));
27+
auto complex_tensor = context.mark_node(std::make_shared<ComplexTypeMark>(real, imag));
28+
return {complex_tensor};
29+
}
30+
31+
} // namespace op
32+
} // namespace pytorch
33+
} // namespace frontend
34+
} // namespace ov
+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/frontend/pytorch/node_context.hpp"
6+
#include "openvino/op/constant.hpp"
7+
#include "openvino/op/random_uniform.hpp"
8+
#include "openvino/op/topk.hpp"
9+
#include "openvino/op/unsqueeze.hpp"
10+
#include "utils.hpp"
11+
12+
namespace ov {
13+
namespace frontend {
14+
namespace pytorch {
15+
namespace op {
16+
17+
using namespace ov::op;
18+
19+
OutputVector translate_randperm(const NodeContext& context) {
20+
auto num_inputs = context.get_input_size();
21+
auto n_node = context.get_input(0);
22+
int dtype_value = 4;
23+
if (num_inputs == 1) {
24+
} else if (num_inputs == 2) {
25+
if (!context.input_is_none(1)) {
26+
dtype_value = context.const_input<int>(1);
27+
PYTORCH_OP_CONVERSION_CHECK(dtype_value == 4,
28+
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
29+
dtype_value);
30+
}
31+
} else if (num_inputs == 5) {
32+
if (!context.input_is_none(1)) {
33+
dtype_value = context.const_input<int>(1);
34+
PYTORCH_OP_CONVERSION_CHECK(dtype_value == 4,
35+
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
36+
dtype_value);
37+
}
38+
} else {
39+
PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs);
40+
}
41+
auto axis_zero = v0::Constant::create(element::i64, Shape{1}, {0});
42+
auto shape = context.mark_node(std::make_shared<v0::Unsqueeze>(n_node, axis_zero));
43+
auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f});
44+
auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f});
45+
auto random_tensor = context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_val, max_val, element::f32));
46+
const int64_t axis = 0;
47+
auto topk = context.mark_node(std::make_shared<v11::TopK>(random_tensor,
48+
n_node,
49+
axis,
50+
ov::op::TopKMode::MIN,
51+
ov::op::TopKSortType::SORT_VALUES,
52+
element::i64,
53+
false));
54+
return {topk->output(1)};
55+
}
56+
57+
} // namespace op
58+
} // namespace pytorch
59+
} // namespace frontend
60+
} // namespace ov

src/frontends/pytorch/src/op_table.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ OP_CONVERTER(translate_permute);
197197
OP_CONVERTER(translate_pairwise_distance);
198198
OP_CONVERTER(translate_pixel_shuffle);
199199
OP_CONVERTER(translate_pixel_unshuffle);
200+
OP_CONVERTER(translate_polar);
200201
OP_CONVERTER(translate_pow);
201202
OP_CONVERTER(translate_prod);
202203
OP_CONVERTER(translate_pythonop);
@@ -208,6 +209,7 @@ OP_CONVERTER(translate_quantized_hardswish);
208209
OP_CONVERTER(translate_quantized_mul);
209210
OP_CONVERTER(translate_range_length);
210211
OP_CONVERTER(translate_rand);
212+
OP_CONVERTER(translate_randperm);
211213
OP_CONVERTER(translate_randn);
212214
OP_CONVERTER(translate_randint);
213215
OP_CONVERTER(translate_rand_like);
@@ -635,12 +637,14 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
635637
{"aten::pixel_shuffle", op::translate_pixel_shuffle},
636638
{"aten::pixel_unshuffle", op::translate_pixel_unshuffle},
637639
{"aten::prelu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
640+
{"aten::polar", op::translate_polar},
638641
{"aten::pow", op::translate_pow},
639642
{"aten::pow_", op::translate_pow},
640643
{"aten::prod", op::translate_prod},
641644
{"aten::quantize_per_channel", op::translate_quantize_per_channel},
642645
{"aten::quantize_per_tensor", op::translate_quantize_per_tensor},
643646
{"aten::rand", op::translate_rand},
647+
{"aten::randperm", op::translate_randperm},
644648
{"aten::rand_like", op::translate_rand_like},
645649
{"aten::randint", op::translate_randint},
646650
{"aten::randn", op::translate_randn},
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (C) 2018-2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import numpy as np
5+
import pytest
6+
import torch
7+
import openvino as ov
8+
from pytorch_layer_test_class import PytorchLayerTest
9+
10+
class TestPolar(PytorchLayerTest):
11+
def _prepare_input(self, input_shape=(1, 1000), dtype=np.float32):
12+
return (
13+
np.random.uniform(0, 10, input_shape).astype(dtype),
14+
np.random.uniform(-np.pi, np.pi, input_shape).astype(dtype)
15+
)
16+
17+
def create_model(self):
18+
class PolarModel(torch.nn.Module):
19+
def forward(self, abs, angle):
20+
complex_tensor = torch.polar(abs, angle)
21+
return torch.view_as_real(complex_tensor)
22+
ref_net = None
23+
return PolarModel(), None, "aten::polar"
24+
25+
@pytest.mark.parametrize("input_case", [
26+
(1, 1000),
27+
(2, 500),
28+
(5, 200),
29+
(10, 100),
30+
])
31+
@pytest.mark.parametrize("dtype", [
32+
np.float32,
33+
np.float64
34+
])
35+
@pytest.mark.nightly
36+
@pytest.mark.precommit
37+
def test_polar(self, input_case, dtype, ie_device, precision, ir_version):
38+
self._test(*self.create_model(), ie_device, precision, ir_version,
39+
kwargs_to_prepare_input={"input_shape": input_case, "dtype": dtype},
40+
trace_model=True, use_convert_model=True, dynamic_shapes=False)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (C) 2018-2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
import torch
6+
import numpy as np
7+
from pytorch_layer_test_class import PytorchLayerTest
8+
9+
class TestSortedRandperm(PytorchLayerTest):
10+
def _prepare_input(self):
11+
return (np.arange(self.n, dtype=np.int64),)
12+
13+
def create_model(self, n, num_inputs, dtype_value=None):
14+
class AtenSortedRandperm(torch.nn.Module):
15+
def __init__(self, n, num_inputs, dtype_value):
16+
super().__init__()
17+
self.n = n
18+
self.num_inputs = num_inputs
19+
self.dtype = torch.int64 if dtype_value == 4 else None
20+
21+
def forward(self, x):
22+
if self.num_inputs == 1:
23+
p = torch.randperm(self.n)
24+
elif self.num_inputs == 2:
25+
p = torch.randperm(self.n, dtype=self.dtype)
26+
elif self.num_inputs == 5:
27+
p = torch.randperm(self.n, dtype=self.dtype, layout=torch.strided,
28+
device=x.device, pin_memory=False)
29+
else:
30+
raise ValueError("Invalid num_inputs")
31+
# sort to get a deterministic order for verifying the permutation.
32+
x_permuted = x[p]
33+
sorted_tensor, _ = torch.sort(x_permuted)
34+
return sorted_tensor
35+
36+
return AtenSortedRandperm(n, num_inputs, dtype_value), None, "aten::randperm"
37+
38+
@pytest.mark.parametrize(("n", "num_inputs", "dtype_value"), [
39+
(0, 1, None),
40+
(1, 1, None),
41+
(5, 1, None),
42+
(5, 2, 4),
43+
(5, 5, 4),
44+
])
45+
@pytest.mark.nightly
46+
@pytest.mark.precommit
47+
def test_sorted_randperm(self, n, num_inputs, dtype_value, ie_device, precision, ir_version):
48+
self.n = n
49+
self._test(*self.create_model(n, num_inputs, dtype_value), ie_device, precision, ir_version)

0 commit comments

Comments
 (0)