Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for aten::randperm and aten::polar #29585

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/frontends/pytorch/src/op/polar.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "openvino/frontend/complex_type_mark.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/cos.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/sin.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_polar(const NodeContext& context) {
num_inputs_check(context, 2, 3);
auto abs = context.get_input(0);
auto angle = context.get_input(1);
auto real =
context.mark_node(std::make_shared<v1::Multiply>(abs, context.mark_node(std::make_shared<v0::Cos>(angle))));
auto imag =
context.mark_node(std::make_shared<v1::Multiply>(abs, context.mark_node(std::make_shared<v0::Sin>(angle))));
auto complex_tensor = context.mark_node(std::make_shared<ComplexTypeMark>(real, imag));
return {complex_tensor};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
60 changes: 60 additions & 0 deletions src/frontends/pytorch/src/op/randperm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/random_uniform.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/topk.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_randperm(const NodeContext& context) {
auto num_inputs = context.get_input_size();
int64_t n = context.const_input<int64_t>(0);
int dtype_value = 4;
if (num_inputs == 1) {
} else if (num_inputs == 2) {
if (!context.input_is_none(1)) {
dtype_value = context.const_input<int>(1);
OPENVINO_ASSERT(dtype_value == 4,
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
dtype_value);
}
} else if (num_inputs == 5) {
if (!context.input_is_none(1)) {
dtype_value = context.const_input<int>(1);
OPENVINO_ASSERT(dtype_value == 4,
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
dtype_value);
}
} else {
PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs);
}
if (n == 0) {
auto const_empty = std::make_shared<v0::Constant>(element::i64, Shape{0}, std::vector<int64_t>{});
return {context.mark_node(const_empty)};
}
auto shape = v0::Constant::create(element::i64, Shape{1}, {n});
auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f});
auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f});
auto random_tensor = context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_val, max_val, element::f32));
const int64_t axis = 0;
auto k = v0::Constant::create(element::i64, Shape{}, {n});
auto topk = context.mark_node(std::make_shared<v11::TopK>(random_tensor,
k,
axis,
ov::op::TopKMode::MIN,
ov::op::TopKSortType::SORT_VALUES,
element::i64,
false));
return {topk->output(1)};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
4 changes: 4 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ OP_CONVERTER(translate_permute);
OP_CONVERTER(translate_pairwise_distance);
OP_CONVERTER(translate_pixel_shuffle);
OP_CONVERTER(translate_pixel_unshuffle);
OP_CONVERTER(translate_polar);
OP_CONVERTER(translate_pow);
OP_CONVERTER(translate_prod);
OP_CONVERTER(translate_pythonop);
Expand All @@ -208,6 +209,7 @@ OP_CONVERTER(translate_quantized_hardswish);
OP_CONVERTER(translate_quantized_mul);
OP_CONVERTER(translate_range_length);
OP_CONVERTER(translate_rand);
OP_CONVERTER(translate_randperm);
OP_CONVERTER(translate_randn);
OP_CONVERTER(translate_randint);
OP_CONVERTER(translate_rand_like);
Expand Down Expand Up @@ -635,12 +637,14 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::pixel_shuffle", op::translate_pixel_shuffle},
{"aten::pixel_unshuffle", op::translate_pixel_unshuffle},
{"aten::prelu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
{"aten::polar", op::translate_polar},
{"aten::pow", op::translate_pow},
{"aten::pow_", op::translate_pow},
{"aten::prod", op::translate_prod},
{"aten::quantize_per_channel", op::translate_quantize_per_channel},
{"aten::quantize_per_tensor", op::translate_quantize_per_tensor},
{"aten::rand", op::translate_rand},
{"aten::randperm", op::translate_randperm},
{"aten::rand_like", op::translate_rand_like},
{"aten::randint", op::translate_randint},
{"aten::randn", op::translate_randn},
Expand Down
42 changes: 42 additions & 0 deletions tests/layer_tests/pytorch_tests/test_polar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch
import openvino as ov
from pytorch_layer_test_class import PytorchLayerTest

class TestPolar(PytorchLayerTest):
def _prepare_input(self, input_shape=(1, 1000), dtype=np.float32):
return (
np.random.uniform(0, 10, input_shape).astype(dtype),
np.random.uniform(-np.pi, np.pi, input_shape).astype(dtype)
)

def create_model(self):
class PolarModel(torch.nn.Module):
def forward(self, abs, angle):
complex_tensor = torch.polar(abs, angle)
return torch.view_as_real(complex_tensor)
ref_net = None
return PolarModel(), None, "aten::polar"

@pytest.mark.parametrize("input_case", [
(1, 1000),
(2, 500),
(5, 200),
(10, 100),
])
@pytest.mark.parametrize("dtype", [
np.float32,
np.float64
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_polar(self, input_case, dtype, ie_device, precision, ir_version):
atol = 1e-4 if precision == "FP32" else 1e-3
rtol = 1e-4
self._test(*self.create_model(), ie_device, precision, ir_version,
kwargs_to_prepare_input={"input_shape": input_case, "dtype": dtype},
trace_model=True, use_convert_model=True, custom_eps=atol, dynamic_shapes=False)
57 changes: 57 additions & 0 deletions tests/layer_tests/pytorch_tests/test_randperm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import torch
import numpy as np
from pytorch_layer_test_class import PytorchLayerTest, flattenize_inputs

class TestRandperm(PytorchLayerTest):
def _prepare_input(self):
return (np.array([self.n], dtype=np.int64),)

def create_model(self, n, num_inputs, dtype_value=None):
class aten_randperm(torch.nn.Module):
def __init__(self, n, num_inputs, dtype_value):
super().__init__()
self.n = torch.tensor(n, dtype=torch.int64)
self.num_inputs = num_inputs
self.dtype = torch.int64 if dtype_value == 4 else None

def forward(self, x):
if self.num_inputs == 1:
return torch.randperm(self.n)
elif self.num_inputs == 2:
return torch.randperm(self.n, dtype=self.dtype)
elif self.num_inputs == 5:
return torch.randperm(self.n, dtype=self.dtype, layout=torch.strided,
device=x.device, pin_memory=False)
raise ValueError("Invalid num_inputs")

return aten_randperm(n, num_inputs, dtype_value), None, "aten::randperm"

@pytest.mark.parametrize(("n", "num_inputs", "dtype_value"), [
(0, 1, None),
(1, 1, None),
(5, 1, None),
(5, 2, 4),
(5, 5, 4),
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_randperm(self, n, num_inputs, dtype_value, ie_device, precision, ir_version):
self.n = n
model, ref_net, op = self.create_model(n, num_inputs, dtype_value)
inputs = self._prepare_input()
torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs]
ov_inputs = flattenize_inputs(inputs)
smodel, converted_model = self.convert_directly_via_frontend(
model, torch_inputs, trace_model=True, dynamic_shapes=False, ov_inputs=ov_inputs, freeze_model=True
)
from openvino import Core
core = Core()
compiled_model = core.compile_model(converted_model, ie_device)

ov_output = compiled_model(ov_inputs)[0]
if n > 0:
assert ov_output.shape[0] == n, f"Output shape {ov_output.shape} does not match expected ({n},)"
assert np.array_equal(np.sort(ov_output), np.arange(n)), f"Output is not a valid permutation of [0, ..., {n-1}]"
else:
assert ov_output.shape[0] == 0, f"Output shape for n=0 should be (0,), got {ov_output.shape}"