-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for aten::randperm and aten::polar #29585
Open
vijaykr338
wants to merge
15
commits into
openvinotoolkit:master
Choose a base branch
from
vijaykr338:work2
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+195
−0
Open
Changes from 11 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
0656890
added aten::take function in pytorch, but tests are not running
vijaykr338 98c1c61
implemented aten::str, aten::delete but unable to write their tests
vijaykr338 725c474
Merge branch 'master' into master
vijaykr338 4176919
added support for aten::randperm with tests
vijaykr338 e32628f
Delete src/frontends/pytorch/src/op/delete.cpp
vijaykr338 1483255
cleaned
vijaykr338 8a16331
Delete tests/layer_tests/pytorch_tests/test_take.py
vijaykr338 e8b35e8
Delete src/frontends/pytorch/src/op/take.cpp
vijaykr338 16f11fa
Delete src/frontends/pytorch/src/op/str.cpp
vijaykr338 cd18e01
added support for aten::polar
vijaykr338 113d07e
fixed code style for aten::polar and aten::randperm
vijaykr338 c6ec86e
Merge branch 'master' into work2
vijaykr338 bf75e7b
Merge branch 'master' into work2
vijaykr338 c8dac15
Merge branch 'master' into work2
vijaykr338 dcc5d37
fixed coding style with clang-format-9 and added the suggested changes
vijaykr338 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#include "openvino/op/cos.hpp" | ||
#include "openvino/op/sin.hpp" | ||
#include "openvino/op/multiply.hpp" | ||
#include "openvino/op/concat.hpp" | ||
#include "openvino/frontend/complex_type_mark.hpp" | ||
#include "openvino/op/convert.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
using namespace ov::op; | ||
|
||
OutputVector translate_polar(const NodeContext& context) { | ||
num_inputs_check(context, 2, 3); | ||
auto abs = context.get_input(0); | ||
auto angle = context.get_input(1); | ||
auto real = context.mark_node(std::make_shared<v1::Multiply>(abs,context.mark_node(std::make_shared<v0::Cos>(angle)))); | ||
auto imag = context.mark_node(std::make_shared<v1::Multiply>(abs,context.mark_node(std::make_shared<v0::Sin>(angle)))); | ||
auto complex_concat = context.mark_node(std::make_shared<v0::Concat>(OutputVector{real, imag}, -1)); | ||
// wrap the tensor with ComplexTypeMark to flag it as complex for later operations. | ||
auto complex_tensor = context.mark_node(std::make_shared<ComplexTypeMark>(complex_concat)); | ||
return {complex_tensor}; | ||
} | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#include "openvino/op/topk.hpp" | ||
#include "openvino/op/random_uniform.hpp" | ||
#include "openvino/frontend/pytorch/node_context.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "utils.hpp" | ||
#include "openvino/op/shape_of.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace pytorch { | ||
namespace op { | ||
|
||
using namespace ov::op; | ||
|
||
OutputVector translate_randperm(const NodeContext& context) { | ||
auto num_inputs = context.get_input_size(); | ||
int64_t n = context.const_input<int64_t>(0); | ||
int dtype_value = 4; | ||
if (num_inputs == 1) { | ||
} else if (num_inputs == 2) { | ||
if (!context.input_is_none(1)) { | ||
dtype_value = context.const_input<int>(1); | ||
OPENVINO_ASSERT(dtype_value == 4, "Only dtype value 4 (int64) is supported for aten::randperm, got: ", dtype_value); | ||
} | ||
} else if (num_inputs == 5) { | ||
if (!context.input_is_none(1)) { | ||
dtype_value = context.const_input<int>(1); | ||
OPENVINO_ASSERT(dtype_value == 4, "Only dtype value 4 (int64) is supported for aten::randperm, got: ", dtype_value); | ||
} | ||
} else { | ||
PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); | ||
} | ||
if (n == 0) { | ||
return {context.mark_node(v0::Constant::create(element::i64, Shape{0},std::vector<int64_t>{}))};} | ||
auto shape = v0::Constant::create(element::i64, Shape{1}, {n}); | ||
auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f}); | ||
auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f}); | ||
auto random_tensor = context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_val, max_val, element::f32)); | ||
const int64_t axis = 0; | ||
auto k = v0::Constant::create(element::i64, Shape{}, {n}); | ||
auto topk = context.mark_node(std::make_shared<v11::TopK>(random_tensor, k, axis, ov::op::TopKMode::MIN, ov::op::TopKSortType::SORT_VALUES, element::i64, false)); | ||
return {topk->output(1)}; | ||
} | ||
|
||
} // namespace op | ||
} // namespace pytorch | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright (C) 2018-2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
import numpy as np | ||
from pytorch_layer_test_class import PytorchLayerTest | ||
|
||
class TestPolar(PytorchLayerTest): | ||
def _prepare_input(self): | ||
return ( | ||
np.array([1.0, 2.0, 3.0], dtype=np.float32), | ||
np.array([0.1, 0.2, 0.3], dtype=np.float32) | ||
) | ||
|
||
def create_model(self): | ||
class PolarModel(torch.nn.Module): | ||
def forward(self, abs, angle): | ||
real = abs * torch.cos(angle) | ||
imag = abs * torch.sin(angle) | ||
return torch.stack([real, imag], dim=-1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are not using aten::polar here. Please use it and return the value using |
||
return PolarModel(), None, None | ||
|
||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
@pytest.mark.parametrize("input_variant", ["static", "dynamic"]) | ||
def test_polar(self, ie_device, precision, ir_version, input_variant): | ||
atol = 1e-4 if precision == "FP32" else 1e-3 | ||
rtol = 1e-4 | ||
if input_variant == "static": | ||
input_data = self._prepare_input() | ||
else: | ||
static_input = self._prepare_input() | ||
input_data = ( | ||
np.expand_dims(static_input[0], axis=0), | ||
np.expand_dims(static_input[1], axis=0) | ||
) | ||
self._test(*self.create_model(), ie_device, precision, ir_version, | ||
input_data=input_data, model_trace=True, atol=atol, rtol=rtol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright (C) 2018-2025 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
import numpy as np | ||
from pytorch_layer_test_class import PytorchLayerTest, flattenize_inputs | ||
from copy import deepcopy | ||
|
||
class TestRandperm(PytorchLayerTest): | ||
def _prepare_input(self): | ||
return () | ||
|
||
def create_model(self, n): | ||
class AtenRandperm(torch.nn.Module): | ||
def __init__(self, n): | ||
super().__init__() | ||
self.n = n | ||
|
||
def forward(self): | ||
return torch.randperm(self.n, dtype=torch.int64) | ||
|
||
return AtenRandperm(n), None, "aten::randperm" | ||
|
||
def is_valid_permutation(self, output, n): | ||
if hasattr(output, 'detach'): | ||
arr = output.detach().cpu().numpy().astype(np.int64) | ||
else: | ||
arr = np.array(output, dtype=np.int64) | ||
sorted_arr = np.sort(arr.flatten()) | ||
expected = np.arange(n, dtype=np.int64) | ||
return np.array_equal(sorted_arr, expected) | ||
|
||
@pytest.mark.parametrize("n", [1, 5, 10]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit | ||
def test_randperm_custom(self, n, ie_device, precision, ir_version): | ||
model, ref_net, op = self.create_model(n) | ||
inputs = self._prepare_input() | ||
torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs] | ||
ov_inputs = flattenize_inputs(inputs) | ||
trace_model = True | ||
dynamic_shapes = True | ||
freeze_model = True | ||
|
||
with torch.no_grad(): | ||
smodel, converted_model = self.convert_directly_via_frontend( | ||
model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model | ||
) | ||
|
||
from openvino import Core | ||
core = Core() | ||
compiled_model = core.compile_model(converted_model, ie_device). | ||
ov_output_dict = compiled_model(()) | ||
ov_output_tensor = list(ov_output_dict.values())[0] | ||
|
||
assert ov_output_tensor.shape[0] == n, f"Output shape {ov_output_tensor.shape} does not match expected ({n},)" | ||
assert self.is_valid_permutation(ov_output_tensor, n), ( | ||
f"Output {ov_output_tensor} is not a valid permutation of [0, 1, ..., {n-1}]" | ||
) | ||
|
||
@pytest.mark.xfail(reason="OpenVINO doesn't support empty tensors for randperm") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Openvino can create an empty tensor:
|
||
def test_randperm_zero(self, ie_device, precision, ir_version): | ||
model, ref_net, op = self.create_model(0) | ||
inputs = self._prepare_input() | ||
torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs] | ||
ov_inputs = flattenize_inputs(inputs) | ||
trace_model = True | ||
dynamic_shapes = True | ||
freeze_model = True | ||
|
||
with torch.no_grad(): | ||
smodel, converted_model = self.convert_directly_via_frontend( | ||
model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model | ||
) | ||
from openvino import Core | ||
core = Core() | ||
compiled_model = core.compile_model(converted_model, ie_device) | ||
_ = compiled_model(()) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of concat you can pass the real and imaginary parts to complex mark.