From 06568905760d1947e90203feb81fd2e50a4bdfc5 Mon Sep 17 00:00:00 2001 From: Vijay Kumar Date: Sat, 15 Mar 2025 10:24:34 +0530 Subject: [PATCH 01/13] added aten::take function in pytorch, but tests are not running --- src/frontends/pytorch/src/op/take.cpp | 72 ++++++++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 2 + tests/layer_tests/pytorch_tests/test_take.py | 35 ++++++++++ 3 files changed, 109 insertions(+) create mode 100644 src/frontends/pytorch/src/op/take.cpp create mode 100644 tests/layer_tests/pytorch_tests/test_take.py diff --git a/src/frontends/pytorch/src/op/take.cpp b/src/frontends/pytorch/src/op/take.cpp new file mode 100644 index 00000000000000..d4ded4bee9032b --- /dev/null +++ b/src/frontends/pytorch/src/op/take.cpp @@ -0,0 +1,72 @@ +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/shape_of.hpp" +#include "utils.hpp" + + + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; +using namespace std; + + + +OutputVector translate_take_op(const NodeContext& context){ + + num_inputs_check(context, 2, 2, true); + //Get input tensor and tensor indices + + Output input = context.get_input(0); + Output indices = context.get_input(1); + //get their inputs + + //We get information about the input tensor + auto input_shape = input.get_partial_shape(); + + if (input_shape.rank().is_static() && input_shape.rank().get_length() == 0) { + FRONT_END_OP_CONVERSION_CHECK(false, "input tensor MUST be non-scalar"); + } + //always flatten the tensor to 1D by using -1 + auto new_shape = context.mark_node( + v0::Constant::create(element::i64, Shape{1}, {-1}) + ); + + input = context.mark_node( + std::make_shared(input, new_shape, false) + ); + + //the openVINO needs the indices always in i64 + indices = context.mark_node( + std::make_shared(indices, element::i64) + ); + + //handle negative indices + auto input_size = context.mark_node( + std::make_shared(input, element::i64) + ); + indices = normalize_axis(context, indices, input_size); + //create a axis_constant = 0 + auto axis_constant = context.mark_node( + v0::Constant::create(element::i64, Shape{}, {0}) + ); + + //now apply the gather function + auto gather = context.mark_node( + std::make_shared(input, indices, axis_constant) + ); + + + return {gather}; + +} + + + + +}}}} \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 018812354a23b6..2664e213307783 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -237,6 +237,7 @@ OP_CONVERTER(translate_sub_); OP_CONVERTER(translate_sum); OP_CONVERTER(translate_t); OP_CONVERTER(translate_take_along_dim); +OP_CONVERTER(translate_take_op); OP_CONVERTER(translate_to); OP_CONVERTER(translate_topk); OP_CONVERTER(translate_transpose); @@ -685,6 +686,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::swapaxes", op::quantizable_op}, {"aten::t", op::translate_t}, {"aten::take_along_dim", op::translate_take_along_dim}, + {"aten::take", op::translate_take_op}, {"aten::tan", op::optional_out, 1>}, {"aten::tan_", op::inplace_op>}, {"aten::tanh", op::optional_out, 1>}, diff --git a/tests/layer_tests/pytorch_tests/test_take.py b/tests/layer_tests/pytorch_tests/test_take.py new file mode 100644 index 00000000000000..9d559d2a2f3c51 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_take.py @@ -0,0 +1,35 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import numpy as np +import torch +from pytorch_layer_test_class import PytorchLayerTest + +class TestTake(PytorchLayerTest): + def _prepare_input(self, input_shape, indices_shape, max_val): + input_tensor = np.random.randn(*input_shape).astype(np.float32) + indices = np.random.randint(-max_val, max_val, indices_shape).astype(np.int64) + return (input_tensor, indices) + + def create_model(self): + class aten_take(torch.nn.Module): + def forward(self, x, indices): + return torch.take(x, indices) + + ref_net = None + return aten_take(), ref_net, "aten::take" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.precommit_torch_export + @pytest.mark.parametrize("input_shape", [(10,), (3, 4), (2, 3, 4), (100,)]) + @pytest.mark.parametrize("indices_shape", [(5,), (2, 2), (3, 2), (50,)]) + def test_take(self, input_shape, indices_shape, ie_device, precision, ir_version): + max_val = np.prod(input_shape) + self._test(*self.create_model(), ie_device, precision, ir_version, + kwargs_to_prepare_input={ + "input_shape": input_shape, + "indices_shape": indices_shape, + "max_val": max_val + }) From 98c1c61292ef91badc912216e41ed56617dae34f Mon Sep 17 00:00:00 2001 From: Vijay Kumar Date: Sun, 16 Mar 2025 03:26:06 +0530 Subject: [PATCH 02/13] implemented aten::str, aten::delete but unable to write their tests --- src/frontends/pytorch/src/op/delete.cpp | 66 ++++++++++++++++++++++++ src/frontends/pytorch/src/op/str.cpp | 67 +++++++++++++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 4 ++ 3 files changed, 137 insertions(+) create mode 100644 src/frontends/pytorch/src/op/delete.cpp create mode 100644 src/frontends/pytorch/src/op/str.cpp diff --git a/src/frontends/pytorch/src/op/delete.cpp b/src/frontends/pytorch/src/op/delete.cpp new file mode 100644 index 00000000000000..d04c0b6d6f6bf0 --- /dev/null +++ b/src/frontends/pytorch/src/op/delete.cpp @@ -0,0 +1,66 @@ +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/strided_slice.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/convert.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_delete(const NodeContext& context) { + // check the input + num_inputs_check(context, 2, 2); + // Retrieve inputs + auto input = context.get_input(0); // container/target tensor + auto indices = context.get_input(1); // Indices for elements to delete + // ensure int32 + if (indices.get_element_type() != ov::element::i32) { + indices = context.mark_node(std::make_shared(indices, ov::element::i32)); + } + // getting the shape + auto input_shape = context.mark_node(std::make_shared(input, ov::element::i32)); + + //implementation plan -> + //slice the tensor before the indices + //slice the tensor after the indices + //concatenate the slices + + // calculate the end index for slicing after the target indices + auto end_index = context.mark_node(std::make_shared( + indices, + v0::Constant::create(ov::element::i32, Shape{1}, {1}) + )); + + // slice elements before the indices + auto before_indices = context.mark_node(std::make_shared( + input, + v0::Constant::create(ov::element::i32, Shape{1}, {0}), indices, // Begin at 0, End at indices + v0::Constant::create(ov::element::i32, Shape{1}, {1}), // Stride of 1 + std::vector{1}, std::vector{0} + )); + // slice elements after the indices + auto after_indices = context.mark_node(std::make_shared( + input,end_index, // Begin after target, End + input_shape, + v0::Constant::create(ov::element::i32, Shape{1}, {1}), // Stride of 1 + std::vector{0}, std::vector{1} + )); + // concat or join the slices + auto result = context.mark_node(std::make_shared( + OutputVector{before_indices, after_indices}, 0 // axis along which to concatenate + )); + + return {result}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op/str.cpp b/src/frontends/pytorch/src/op/str.cpp new file mode 100644 index 00000000000000..fe4efa632b4661 --- /dev/null +++ b/src/frontends/pytorch/src/op/str.cpp @@ -0,0 +1,67 @@ +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/shape_of.hpp" +#include "pt_framework_node.hpp" +#include "openvino/core/validation_util.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_str(const NodeContext& context) { + // Get the input tensor + num_inputs_check(context, 1, 1); + auto input = context.get_input(0); + // Check if the input tensor is of a supported type + if (input.get_element_type() != ov::element::string && + input.get_element_type() != ov::element::f32 && + input.get_element_type() != ov::element::i32) { + OPENVINO_THROW("Unsupported tensor type for str operation"); + } + // if the tensor is constant + // Try to get the constant value of the input tensor + auto constant_value = util::get_constant_from_source(input); + if (constant_value) { + // Convert the constant value to a string + std::string str_value; + if (input.get_element_type() == ov::element::f32) { + str_value = std::to_string(constant_value->cast_vector()[0]); + } else if (input.get_element_type() == ov::element::i32) { + str_value = std::to_string(constant_value->cast_vector()[0]); + } else if (input.get_element_type() == ov::element::string) { + str_value = constant_value->cast_vector()[0]; + } + // Create a new Constant node with the string value + auto str_node = std::make_shared( + ov::element::string, + ov::Shape{1}, + std::vector{str_value}); + return {context.mark_node(str_node)}; + } + // if the tensor is not constant + auto shape = context.mark_node(std::make_shared(input, element::i64)); + auto dtype = input.get_element_type(); + // OpenVINO is able to handle numerical data, and does not support string data natively + // therefore, we capture tensor's shape and datatype instead which can be later used in debugging + // Utilising the PtFrameworkNode for this purpose + + // from the documentation + // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/pytorch/README.md + // PtFrameworkNode is used to represent unconverted operation from the original model. + ov::op::util::FrameworkNodeAttrs attrs; + attrs[PtFrameworkNode::op_type_key] = "aten::str"; + attrs["dtype"] = dtype.get_type_name(); + + auto decoder = context.get_decoder(); + auto ptf_node = std::make_shared(decoder, OutputVector{input}, 1); + ptf_node->set_attrs(attrs); + + return {context.mark_node(ptf_node)}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 2664e213307783..e7a9e98fb4158d 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -73,6 +73,7 @@ OP_CONVERTER(translate_copy_); OP_CONVERTER(translate_cross); OP_CONVERTER(translate_cumsum); OP_CONVERTER(translate_deform_conv); +OP_CONVERTER(translate_delete); OP_CONVERTER(translate_derive_index); OP_CONVERTER(translate_dim); OP_CONVERTER(translate_div); @@ -232,6 +233,7 @@ OP_CONVERTER(translate_squeeze); OP_CONVERTER(translate_std); OP_CONVERTER(translate_std_mean); OP_CONVERTER(translate_stft); +OP_CONVERTER(translate_str); OP_CONVERTER(translate_sub); OP_CONVERTER(translate_sub_); OP_CONVERTER(translate_sum); @@ -458,6 +460,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::cosh_", op::inplace_op>}, {"aten::cross", op::translate_cross}, {"aten::cumsum", op::translate_cumsum}, + {"aten::delete", op::translate_delete}, {"aten::detach", op::skip_node}, {"aten::dequantize", op::skip_node}, // we convert model to fp32 using FQ, so dequantization is not needed {"aten::dim", op::translate_dim}, @@ -680,6 +683,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::std", op::translate_std}, {"aten::std_mean", op::translate_std_mean}, {"aten::stft", op::translate_stft}, + {"aten::stft", op::translate_str}, {"aten::sub", op::translate_sub}, {"aten::sub_", op::translate_sub_}, {"aten::sum", op::translate_sum}, From 417691947bc0ab685b7eb591cfa8cb30a412cc95 Mon Sep 17 00:00:00 2001 From: Vijay Kumar Date: Wed, 19 Mar 2025 22:43:33 +0530 Subject: [PATCH 03/13] added support for aten::randperm with tests --- src/frontends/pytorch/src/op/randperm.cpp | 70 ++++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 12 +-- .../pytorch_tests/test_randperm.py | 79 +++++++++++++++++++ 3 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 src/frontends/pytorch/src/op/randperm.cpp create mode 100644 tests/layer_tests/pytorch_tests/test_randperm.py diff --git a/src/frontends/pytorch/src/op/randperm.cpp b/src/frontends/pytorch/src/op/randperm.cpp new file mode 100644 index 00000000000000..52214c24069b63 --- /dev/null +++ b/src/frontends/pytorch/src/op/randperm.cpp @@ -0,0 +1,70 @@ +#include "openvino/op/topk.hpp" +#include "openvino/op/random_uniform.hpp" +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/constant.hpp" +#include "utils.hpp" +#include "openvino/op/shape_of.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_randperm(const NodeContext& context) { + auto num_inputs = context.get_input_size(); + int64_t n = context.const_input(0); + int dtype_value = 4; + if (num_inputs == 1) { + } else if (num_inputs == 2) { + if (!context.input_is_none(1)) { + dtype_value = context.const_input(1); + OPENVINO_ASSERT(dtype_value == 4, + "Only dtype value 4 (int64) is supported for aten::randperm, got: ", + dtype_value); + } + } else if (num_inputs == 5) { + if (!context.input_is_none(1)) { + dtype_value = context.const_input(1); + OPENVINO_ASSERT(dtype_value == 4, + "Only dtype value 4 (int64) is supported for aten::randperm, got: ", + dtype_value); + } + } else { + PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); + } + if (n == 0) { + return {context.mark_node(v0::Constant::create( + element::i64, + Shape{0}, + std::vector{} + ))}; + } + auto shape = v0::Constant::create(element::i64, Shape{1}, {n}); + auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f}); + auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f}); + auto random_tensor = context.mark_node(std::make_shared( + shape, + min_val, + max_val, + element::f32 + )); + const int64_t axis = 0; + auto k = v0::Constant::create(element::i64, Shape{}, {n}); + auto topk = context.mark_node(std::make_shared( + random_tensor, + k, + axis, + ov::op::TopKMode::MIN, + ov::op::TopKSortType::SORT_VALUES, + element::i64, + false + )); + return {topk->output(1)}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index e7a9e98fb4158d..230875c8045f9d 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -73,7 +73,7 @@ OP_CONVERTER(translate_copy_); OP_CONVERTER(translate_cross); OP_CONVERTER(translate_cumsum); OP_CONVERTER(translate_deform_conv); -OP_CONVERTER(translate_delete); + OP_CONVERTER(translate_derive_index); OP_CONVERTER(translate_dim); OP_CONVERTER(translate_div); @@ -198,6 +198,7 @@ OP_CONVERTER(translate_quantized_hardswish); OP_CONVERTER(translate_quantized_mul); OP_CONVERTER(translate_range_length); OP_CONVERTER(translate_rand); +OP_CONVERTER(translate_randperm); OP_CONVERTER(translate_randn); OP_CONVERTER(translate_randint); OP_CONVERTER(translate_rand_like); @@ -233,7 +234,7 @@ OP_CONVERTER(translate_squeeze); OP_CONVERTER(translate_std); OP_CONVERTER(translate_std_mean); OP_CONVERTER(translate_stft); -OP_CONVERTER(translate_str); + OP_CONVERTER(translate_sub); OP_CONVERTER(translate_sub_); OP_CONVERTER(translate_sum); @@ -460,7 +461,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::cosh_", op::inplace_op>}, {"aten::cross", op::translate_cross}, {"aten::cumsum", op::translate_cumsum}, - {"aten::delete", op::translate_delete}, + {"aten::detach", op::skip_node}, {"aten::dequantize", op::skip_node}, // we convert model to fp32 using FQ, so dequantization is not needed {"aten::dim", op::translate_dim}, @@ -622,6 +623,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::quantize_per_channel", op::translate_quantize_per_channel}, {"aten::quantize_per_tensor", op::translate_quantize_per_tensor}, {"aten::rand", op::translate_rand}, + {"aten::rand", op::translate_randperm}, {"aten::rand_like", op::translate_rand_like}, {"aten::randint", op::translate_randint}, {"aten::randn", op::translate_randn}, @@ -683,14 +685,14 @@ const std::unordered_map get_supported_ops_ts() { {"aten::std", op::translate_std}, {"aten::std_mean", op::translate_std_mean}, {"aten::stft", op::translate_stft}, - {"aten::stft", op::translate_str}, + {"aten::sub", op::translate_sub}, {"aten::sub_", op::translate_sub_}, {"aten::sum", op::translate_sum}, {"aten::swapaxes", op::quantizable_op}, {"aten::t", op::translate_t}, {"aten::take_along_dim", op::translate_take_along_dim}, - {"aten::take", op::translate_take_op}, + {"aten::tan", op::optional_out, 1>}, {"aten::tan_", op::inplace_op>}, {"aten::tanh", op::optional_out, 1>}, diff --git a/tests/layer_tests/pytorch_tests/test_randperm.py b/tests/layer_tests/pytorch_tests/test_randperm.py new file mode 100644 index 00000000000000..d3c1ec4b7bac54 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_randperm.py @@ -0,0 +1,79 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import numpy as np +from pytorch_layer_test_class import PytorchLayerTest, flattenize_inputs +from copy import deepcopy + +class TestRandperm(PytorchLayerTest): + def _prepare_input(self): + return () + + def create_model(self, n): + class AtenRandperm(torch.nn.Module): + def __init__(self, n): + super().__init__() + self.n = n + + def forward(self): + return torch.randperm(self.n, dtype=torch.int64) + + return AtenRandperm(n), None, "aten::randperm" + + def is_valid_permutation(self, output, n): + if hasattr(output, 'detach'): + arr = output.detach().cpu().numpy().astype(np.int64) + else: + arr = np.array(output, dtype=np.int64) + sorted_arr = np.sort(arr.flatten()) + expected = np.arange(n, dtype=np.int64) + return np.array_equal(sorted_arr, expected) + + @pytest.mark.parametrize("n", [1, 5, 10]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_randperm_custom(self, n, ie_device, precision, ir_version): + model, ref_net, op = self.create_model(n) + inputs = self._prepare_input() + torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs] + ov_inputs = flattenize_inputs(inputs) + trace_model = True + dynamic_shapes = True + freeze_model = True + + with torch.no_grad(): + smodel, converted_model = self.convert_directly_via_frontend( + model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model + ) + + from openvino import Core + core = Core() + compiled_model = core.compile_model(converted_model, ie_device). + ov_output_dict = compiled_model(()) + ov_output_tensor = list(ov_output_dict.values())[0] + + assert ov_output_tensor.shape[0] == n, f"Output shape {ov_output_tensor.shape} does not match expected ({n},)" + assert self.is_valid_permutation(ov_output_tensor, n), ( + f"Output {ov_output_tensor} is not a valid permutation of [0, 1, ..., {n-1}]" + ) + + @pytest.mark.xfail(reason="OpenVINO doesn't support empty tensors for randperm") + def test_randperm_zero(self, ie_device, precision, ir_version): + model, ref_net, op = self.create_model(0) + inputs = self._prepare_input() + torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs] + ov_inputs = flattenize_inputs(inputs) + trace_model = True + dynamic_shapes = True + freeze_model = True + + with torch.no_grad(): + smodel, converted_model = self.convert_directly_via_frontend( + model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model + ) + from openvino import Core + core = Core() + compiled_model = core.compile_model(converted_model, ie_device) + _ = compiled_model(()) From e32628f1439761fb7435e50acd44d5bf5a180ca0 Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Wed, 19 Mar 2025 22:50:19 +0530 Subject: [PATCH 04/13] Delete src/frontends/pytorch/src/op/delete.cpp --- src/frontends/pytorch/src/op/delete.cpp | 66 ------------------------- 1 file changed, 66 deletions(-) delete mode 100644 src/frontends/pytorch/src/op/delete.cpp diff --git a/src/frontends/pytorch/src/op/delete.cpp b/src/frontends/pytorch/src/op/delete.cpp deleted file mode 100644 index d04c0b6d6f6bf0..00000000000000 --- a/src/frontends/pytorch/src/op/delete.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/op/concat.hpp" -#include "openvino/op/shape_of.hpp" -#include "openvino/op/strided_slice.hpp" -#include "openvino/op/add.hpp" -#include "openvino/op/convert.hpp" -#include "utils.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace op { - -using namespace ov::op; - -OutputVector translate_delete(const NodeContext& context) { - // check the input - num_inputs_check(context, 2, 2); - // Retrieve inputs - auto input = context.get_input(0); // container/target tensor - auto indices = context.get_input(1); // Indices for elements to delete - // ensure int32 - if (indices.get_element_type() != ov::element::i32) { - indices = context.mark_node(std::make_shared(indices, ov::element::i32)); - } - // getting the shape - auto input_shape = context.mark_node(std::make_shared(input, ov::element::i32)); - - //implementation plan -> - //slice the tensor before the indices - //slice the tensor after the indices - //concatenate the slices - - // calculate the end index for slicing after the target indices - auto end_index = context.mark_node(std::make_shared( - indices, - v0::Constant::create(ov::element::i32, Shape{1}, {1}) - )); - - // slice elements before the indices - auto before_indices = context.mark_node(std::make_shared( - input, - v0::Constant::create(ov::element::i32, Shape{1}, {0}), indices, // Begin at 0, End at indices - v0::Constant::create(ov::element::i32, Shape{1}, {1}), // Stride of 1 - std::vector{1}, std::vector{0} - )); - // slice elements after the indices - auto after_indices = context.mark_node(std::make_shared( - input,end_index, // Begin after target, End - input_shape, - v0::Constant::create(ov::element::i32, Shape{1}, {1}), // Stride of 1 - std::vector{0}, std::vector{1} - )); - // concat or join the slices - auto result = context.mark_node(std::make_shared( - OutputVector{before_indices, after_indices}, 0 // axis along which to concatenate - )); - - return {result}; -} - -} // namespace op -} // namespace pytorch -} // namespace frontend -} // namespace ov From 14832552ed5209138855e77a9bd4e39667b55810 Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Wed, 19 Mar 2025 17:40:35 +0000 Subject: [PATCH 05/13] cleaned --- src/frontends/pytorch/src/op_table.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 230875c8045f9d..b791c6d6c51e8a 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -73,7 +73,6 @@ OP_CONVERTER(translate_copy_); OP_CONVERTER(translate_cross); OP_CONVERTER(translate_cumsum); OP_CONVERTER(translate_deform_conv); - OP_CONVERTER(translate_derive_index); OP_CONVERTER(translate_dim); OP_CONVERTER(translate_div); @@ -234,7 +233,6 @@ OP_CONVERTER(translate_squeeze); OP_CONVERTER(translate_std); OP_CONVERTER(translate_std_mean); OP_CONVERTER(translate_stft); - OP_CONVERTER(translate_sub); OP_CONVERTER(translate_sub_); OP_CONVERTER(translate_sum); @@ -461,7 +459,6 @@ const std::unordered_map get_supported_ops_ts() { {"aten::cosh_", op::inplace_op>}, {"aten::cross", op::translate_cross}, {"aten::cumsum", op::translate_cumsum}, - {"aten::detach", op::skip_node}, {"aten::dequantize", op::skip_node}, // we convert model to fp32 using FQ, so dequantization is not needed {"aten::dim", op::translate_dim}, @@ -685,14 +682,12 @@ const std::unordered_map get_supported_ops_ts() { {"aten::std", op::translate_std}, {"aten::std_mean", op::translate_std_mean}, {"aten::stft", op::translate_stft}, - {"aten::sub", op::translate_sub}, {"aten::sub_", op::translate_sub_}, {"aten::sum", op::translate_sum}, {"aten::swapaxes", op::quantizable_op}, {"aten::t", op::translate_t}, {"aten::take_along_dim", op::translate_take_along_dim}, - {"aten::tan", op::optional_out, 1>}, {"aten::tan_", op::inplace_op>}, {"aten::tanh", op::optional_out, 1>}, From 8a16331e023daa3a19b523265e03fa4b77ae3cd2 Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Wed, 19 Mar 2025 23:11:31 +0530 Subject: [PATCH 06/13] Delete tests/layer_tests/pytorch_tests/test_take.py --- tests/layer_tests/pytorch_tests/test_take.py | 35 -------------------- 1 file changed, 35 deletions(-) delete mode 100644 tests/layer_tests/pytorch_tests/test_take.py diff --git a/tests/layer_tests/pytorch_tests/test_take.py b/tests/layer_tests/pytorch_tests/test_take.py deleted file mode 100644 index 9d559d2a2f3c51..00000000000000 --- a/tests/layer_tests/pytorch_tests/test_take.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (C) 2018-2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import numpy as np -import torch -from pytorch_layer_test_class import PytorchLayerTest - -class TestTake(PytorchLayerTest): - def _prepare_input(self, input_shape, indices_shape, max_val): - input_tensor = np.random.randn(*input_shape).astype(np.float32) - indices = np.random.randint(-max_val, max_val, indices_shape).astype(np.int64) - return (input_tensor, indices) - - def create_model(self): - class aten_take(torch.nn.Module): - def forward(self, x, indices): - return torch.take(x, indices) - - ref_net = None - return aten_take(), ref_net, "aten::take" - - @pytest.mark.nightly - @pytest.mark.precommit - @pytest.mark.precommit_torch_export - @pytest.mark.parametrize("input_shape", [(10,), (3, 4), (2, 3, 4), (100,)]) - @pytest.mark.parametrize("indices_shape", [(5,), (2, 2), (3, 2), (50,)]) - def test_take(self, input_shape, indices_shape, ie_device, precision, ir_version): - max_val = np.prod(input_shape) - self._test(*self.create_model(), ie_device, precision, ir_version, - kwargs_to_prepare_input={ - "input_shape": input_shape, - "indices_shape": indices_shape, - "max_val": max_val - }) From e8b35e841c3ef3b8d7eff9af122594e438c61eab Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Wed, 19 Mar 2025 23:11:51 +0530 Subject: [PATCH 07/13] Delete src/frontends/pytorch/src/op/take.cpp --- src/frontends/pytorch/src/op/take.cpp | 72 --------------------------- 1 file changed, 72 deletions(-) delete mode 100644 src/frontends/pytorch/src/op/take.cpp diff --git a/src/frontends/pytorch/src/op/take.cpp b/src/frontends/pytorch/src/op/take.cpp deleted file mode 100644 index d4ded4bee9032b..00000000000000 --- a/src/frontends/pytorch/src/op/take.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/op/gather.hpp" -#include "openvino/op/reshape.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/op/shape_of.hpp" -#include "utils.hpp" - - - -namespace ov { -namespace frontend { -namespace pytorch { -namespace op { - -using namespace ov::op; -using namespace std; - - - -OutputVector translate_take_op(const NodeContext& context){ - - num_inputs_check(context, 2, 2, true); - //Get input tensor and tensor indices - - Output input = context.get_input(0); - Output indices = context.get_input(1); - //get their inputs - - //We get information about the input tensor - auto input_shape = input.get_partial_shape(); - - if (input_shape.rank().is_static() && input_shape.rank().get_length() == 0) { - FRONT_END_OP_CONVERSION_CHECK(false, "input tensor MUST be non-scalar"); - } - //always flatten the tensor to 1D by using -1 - auto new_shape = context.mark_node( - v0::Constant::create(element::i64, Shape{1}, {-1}) - ); - - input = context.mark_node( - std::make_shared(input, new_shape, false) - ); - - //the openVINO needs the indices always in i64 - indices = context.mark_node( - std::make_shared(indices, element::i64) - ); - - //handle negative indices - auto input_size = context.mark_node( - std::make_shared(input, element::i64) - ); - indices = normalize_axis(context, indices, input_size); - //create a axis_constant = 0 - auto axis_constant = context.mark_node( - v0::Constant::create(element::i64, Shape{}, {0}) - ); - - //now apply the gather function - auto gather = context.mark_node( - std::make_shared(input, indices, axis_constant) - ); - - - return {gather}; - -} - - - - -}}}} \ No newline at end of file From 16f11faa6a350aad0a5be693cea1559c67a566fd Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Wed, 19 Mar 2025 23:12:16 +0530 Subject: [PATCH 08/13] Delete src/frontends/pytorch/src/op/str.cpp --- src/frontends/pytorch/src/op/str.cpp | 67 ---------------------------- 1 file changed, 67 deletions(-) delete mode 100644 src/frontends/pytorch/src/op/str.cpp diff --git a/src/frontends/pytorch/src/op/str.cpp b/src/frontends/pytorch/src/op/str.cpp deleted file mode 100644 index fe4efa632b4661..00000000000000 --- a/src/frontends/pytorch/src/op/str.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/op/constant.hpp" -#include "openvino/op/shape_of.hpp" -#include "pt_framework_node.hpp" -#include "openvino/core/validation_util.hpp" -#include "utils.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace op { - -OutputVector translate_str(const NodeContext& context) { - // Get the input tensor - num_inputs_check(context, 1, 1); - auto input = context.get_input(0); - // Check if the input tensor is of a supported type - if (input.get_element_type() != ov::element::string && - input.get_element_type() != ov::element::f32 && - input.get_element_type() != ov::element::i32) { - OPENVINO_THROW("Unsupported tensor type for str operation"); - } - // if the tensor is constant - // Try to get the constant value of the input tensor - auto constant_value = util::get_constant_from_source(input); - if (constant_value) { - // Convert the constant value to a string - std::string str_value; - if (input.get_element_type() == ov::element::f32) { - str_value = std::to_string(constant_value->cast_vector()[0]); - } else if (input.get_element_type() == ov::element::i32) { - str_value = std::to_string(constant_value->cast_vector()[0]); - } else if (input.get_element_type() == ov::element::string) { - str_value = constant_value->cast_vector()[0]; - } - // Create a new Constant node with the string value - auto str_node = std::make_shared( - ov::element::string, - ov::Shape{1}, - std::vector{str_value}); - return {context.mark_node(str_node)}; - } - // if the tensor is not constant - auto shape = context.mark_node(std::make_shared(input, element::i64)); - auto dtype = input.get_element_type(); - // OpenVINO is able to handle numerical data, and does not support string data natively - // therefore, we capture tensor's shape and datatype instead which can be later used in debugging - // Utilising the PtFrameworkNode for this purpose - - // from the documentation - // https://github.com/openvinotoolkit/openvino/blob/master/src/frontends/pytorch/README.md - // PtFrameworkNode is used to represent unconverted operation from the original model. - ov::op::util::FrameworkNodeAttrs attrs; - attrs[PtFrameworkNode::op_type_key] = "aten::str"; - attrs["dtype"] = dtype.get_type_name(); - - auto decoder = context.get_decoder(); - auto ptf_node = std::make_shared(decoder, OutputVector{input}, 1); - ptf_node->set_attrs(attrs); - - return {context.mark_node(ptf_node)}; -} - -} // namespace op -} // namespace pytorch -} // namespace frontend -} // namespace ov \ No newline at end of file From cd18e01424638bbf8b1458f8395ffdf794a8325f Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Sun, 23 Mar 2025 12:38:24 +0000 Subject: [PATCH 09/13] added support for aten::polar --- src/frontends/pytorch/src/op/polar.cpp | 31 +++++++++++++++ src/frontends/pytorch/src/op_table.cpp | 2 + tests/layer_tests/pytorch_tests/test_polar.py | 39 +++++++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 src/frontends/pytorch/src/op/polar.cpp create mode 100644 tests/layer_tests/pytorch_tests/test_polar.py diff --git a/src/frontends/pytorch/src/op/polar.cpp b/src/frontends/pytorch/src/op/polar.cpp new file mode 100644 index 00000000000000..a34644594a4a03 --- /dev/null +++ b/src/frontends/pytorch/src/op/polar.cpp @@ -0,0 +1,31 @@ +#include "openvino/op/cos.hpp" +#include "openvino/op/sin.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/frontend/complex_type_mark.hpp" +#include "openvino/op/convert.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_polar(const NodeContext& context) { + num_inputs_check(context, 2, 3); + auto abs = context.get_input(0); + auto angle = context.get_input(1); + auto real = context.mark_node(std::make_shared(abs,context.mark_node(std::make_shared(angle)))); + auto imag = context.mark_node(std::make_shared(abs,context.mark_node(std::make_shared(angle)))); + auto complex_concat = context.mark_node(std::make_shared(OutputVector{real, imag}, -1)); + // wrap the tensor with ComplexTypeMark to flag it as complex for later operations. + auto complex_tensor = context.mark_node(std::make_shared(complex_concat)); + return {complex_tensor}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index b791c6d6c51e8a..13fa21c5b86284 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -186,6 +186,7 @@ OP_CONVERTER(translate_permute); OP_CONVERTER(translate_pairwise_distance); OP_CONVERTER(translate_pixel_shuffle); OP_CONVERTER(translate_pixel_unshuffle); +OP_CONVERTER(translate_polar); OP_CONVERTER(translate_pow); OP_CONVERTER(translate_prod); OP_CONVERTER(translate_pythonop); @@ -614,6 +615,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::pixel_shuffle", op::translate_pixel_shuffle}, {"aten::pixel_unshuffle", op::translate_pixel_unshuffle}, {"aten::prelu", op::translate_1to1_match_2_inputs}, + {"aten::polar", op::translate_polar}, {"aten::pow", op::translate_pow}, {"aten::pow_", op::translate_pow}, {"aten::prod", op::translate_prod}, diff --git a/tests/layer_tests/pytorch_tests/test_polar.py b/tests/layer_tests/pytorch_tests/test_polar.py new file mode 100644 index 00000000000000..b450f5542de041 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_polar.py @@ -0,0 +1,39 @@ +# Copyright (C) 2018-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import numpy as np +from pytorch_layer_test_class import PytorchLayerTest + +class TestPolar(PytorchLayerTest): + def _prepare_input(self): + return ( + np.array([1.0, 2.0, 3.0], dtype=np.float32), + np.array([0.1, 0.2, 0.3], dtype=np.float32) + ) + + def create_model(self): + class PolarModel(torch.nn.Module): + def forward(self, abs, angle): + real = abs * torch.cos(angle) + imag = abs * torch.sin(angle) + return torch.stack([real, imag], dim=-1) + return PolarModel(), None, None + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("input_variant", ["static", "dynamic"]) + def test_polar(self, ie_device, precision, ir_version, input_variant): + atol = 1e-4 if precision == "FP32" else 1e-3 + rtol = 1e-4 + if input_variant == "static": + input_data = self._prepare_input() + else: + static_input = self._prepare_input() + input_data = ( + np.expand_dims(static_input[0], axis=0), + np.expand_dims(static_input[1], axis=0) + ) + self._test(*self.create_model(), ie_device, precision, ir_version, + input_data=input_data, model_trace=True, atol=atol, rtol=rtol) From 113d07eef57f2d6096cb0b220d6dbeb904048cf4 Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Sun, 23 Mar 2025 12:47:42 +0000 Subject: [PATCH 10/13] fixed code style for aten::polar and aten::randperm --- src/frontends/pytorch/src/op/randperm.cpp | 34 ++++------------------- src/frontends/pytorch/src/op_table.cpp | 1 - 2 files changed, 6 insertions(+), 29 deletions(-) diff --git a/src/frontends/pytorch/src/op/randperm.cpp b/src/frontends/pytorch/src/op/randperm.cpp index 52214c24069b63..2ebce184841b73 100644 --- a/src/frontends/pytorch/src/op/randperm.cpp +++ b/src/frontends/pytorch/src/op/randperm.cpp @@ -20,47 +20,25 @@ OutputVector translate_randperm(const NodeContext& context) { } else if (num_inputs == 2) { if (!context.input_is_none(1)) { dtype_value = context.const_input(1); - OPENVINO_ASSERT(dtype_value == 4, - "Only dtype value 4 (int64) is supported for aten::randperm, got: ", - dtype_value); + OPENVINO_ASSERT(dtype_value == 4, "Only dtype value 4 (int64) is supported for aten::randperm, got: ", dtype_value); } } else if (num_inputs == 5) { if (!context.input_is_none(1)) { dtype_value = context.const_input(1); - OPENVINO_ASSERT(dtype_value == 4, - "Only dtype value 4 (int64) is supported for aten::randperm, got: ", - dtype_value); + OPENVINO_ASSERT(dtype_value == 4, "Only dtype value 4 (int64) is supported for aten::randperm, got: ", dtype_value); } } else { - PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); + PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); } if (n == 0) { - return {context.mark_node(v0::Constant::create( - element::i64, - Shape{0}, - std::vector{} - ))}; - } + return {context.mark_node(v0::Constant::create(element::i64, Shape{0},std::vector{}))};} auto shape = v0::Constant::create(element::i64, Shape{1}, {n}); auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f}); auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f}); - auto random_tensor = context.mark_node(std::make_shared( - shape, - min_val, - max_val, - element::f32 - )); + auto random_tensor = context.mark_node(std::make_shared(shape, min_val, max_val, element::f32)); const int64_t axis = 0; auto k = v0::Constant::create(element::i64, Shape{}, {n}); - auto topk = context.mark_node(std::make_shared( - random_tensor, - k, - axis, - ov::op::TopKMode::MIN, - ov::op::TopKSortType::SORT_VALUES, - element::i64, - false - )); + auto topk = context.mark_node(std::make_shared(random_tensor, k, axis, ov::op::TopKMode::MIN, ov::op::TopKSortType::SORT_VALUES, element::i64, false)); return {topk->output(1)}; } diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 13fa21c5b86284..c5db5b25963960 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -239,7 +239,6 @@ OP_CONVERTER(translate_sub_); OP_CONVERTER(translate_sum); OP_CONVERTER(translate_t); OP_CONVERTER(translate_take_along_dim); -OP_CONVERTER(translate_take_op); OP_CONVERTER(translate_to); OP_CONVERTER(translate_topk); OP_CONVERTER(translate_transpose); From dcc5d37026aa0dffb1411b610ad44314813a091b Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Tue, 25 Mar 2025 04:27:17 +0000 Subject: [PATCH 11/13] fixed coding style with clang-format-9 and added the suggested changes --- src/frontends/pytorch/src/op/polar.cpp | 19 ++-- src/frontends/pytorch/src/op/randperm.cpp | 28 +++-- tests/layer_tests/pytorch_tests/test_polar.py | 47 ++++---- .../pytorch_tests/test_randperm.py | 100 +++++++----------- 4 files changed, 94 insertions(+), 100 deletions(-) diff --git a/src/frontends/pytorch/src/op/polar.cpp b/src/frontends/pytorch/src/op/polar.cpp index a34644594a4a03..96014d82e2244f 100644 --- a/src/frontends/pytorch/src/op/polar.cpp +++ b/src/frontends/pytorch/src/op/polar.cpp @@ -1,9 +1,8 @@ -#include "openvino/op/cos.hpp" -#include "openvino/op/sin.hpp" -#include "openvino/op/multiply.hpp" -#include "openvino/op/concat.hpp" #include "openvino/frontend/complex_type_mark.hpp" #include "openvino/op/convert.hpp" +#include "openvino/op/cos.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/sin.hpp" #include "utils.hpp" namespace ov { @@ -17,15 +16,15 @@ OutputVector translate_polar(const NodeContext& context) { num_inputs_check(context, 2, 3); auto abs = context.get_input(0); auto angle = context.get_input(1); - auto real = context.mark_node(std::make_shared(abs,context.mark_node(std::make_shared(angle)))); - auto imag = context.mark_node(std::make_shared(abs,context.mark_node(std::make_shared(angle)))); - auto complex_concat = context.mark_node(std::make_shared(OutputVector{real, imag}, -1)); - // wrap the tensor with ComplexTypeMark to flag it as complex for later operations. - auto complex_tensor = context.mark_node(std::make_shared(complex_concat)); + auto real = + context.mark_node(std::make_shared(abs, context.mark_node(std::make_shared(angle)))); + auto imag = + context.mark_node(std::make_shared(abs, context.mark_node(std::make_shared(angle)))); + auto complex_tensor = context.mark_node(std::make_shared(real, imag)); return {complex_tensor}; } } // namespace op } // namespace pytorch } // namespace frontend -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/frontends/pytorch/src/op/randperm.cpp b/src/frontends/pytorch/src/op/randperm.cpp index 2ebce184841b73..cb2acbcf48a3d7 100644 --- a/src/frontends/pytorch/src/op/randperm.cpp +++ b/src/frontends/pytorch/src/op/randperm.cpp @@ -1,9 +1,9 @@ -#include "openvino/op/topk.hpp" -#include "openvino/op/random_uniform.hpp" #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/constant.hpp" -#include "utils.hpp" +#include "openvino/op/random_uniform.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/topk.hpp" +#include "utils.hpp" namespace ov { namespace frontend { @@ -20,25 +20,37 @@ OutputVector translate_randperm(const NodeContext& context) { } else if (num_inputs == 2) { if (!context.input_is_none(1)) { dtype_value = context.const_input(1); - OPENVINO_ASSERT(dtype_value == 4, "Only dtype value 4 (int64) is supported for aten::randperm, got: ", dtype_value); + OPENVINO_ASSERT(dtype_value == 4, + "Only dtype value 4 (int64) is supported for aten::randperm, got: ", + dtype_value); } } else if (num_inputs == 5) { if (!context.input_is_none(1)) { dtype_value = context.const_input(1); - OPENVINO_ASSERT(dtype_value == 4, "Only dtype value 4 (int64) is supported for aten::randperm, got: ", dtype_value); + OPENVINO_ASSERT(dtype_value == 4, + "Only dtype value 4 (int64) is supported for aten::randperm, got: ", + dtype_value); } } else { - PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); + PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs); } if (n == 0) { - return {context.mark_node(v0::Constant::create(element::i64, Shape{0},std::vector{}))};} + auto const_empty = std::make_shared(element::i64, Shape{0}, std::vector{}); + return {context.mark_node(const_empty)}; + } auto shape = v0::Constant::create(element::i64, Shape{1}, {n}); auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f}); auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f}); auto random_tensor = context.mark_node(std::make_shared(shape, min_val, max_val, element::f32)); const int64_t axis = 0; auto k = v0::Constant::create(element::i64, Shape{}, {n}); - auto topk = context.mark_node(std::make_shared(random_tensor, k, axis, ov::op::TopKMode::MIN, ov::op::TopKSortType::SORT_VALUES, element::i64, false)); + auto topk = context.mark_node(std::make_shared(random_tensor, + k, + axis, + ov::op::TopKMode::MIN, + ov::op::TopKSortType::SORT_VALUES, + element::i64, + false)); return {topk->output(1)}; } diff --git a/tests/layer_tests/pytorch_tests/test_polar.py b/tests/layer_tests/pytorch_tests/test_polar.py index b450f5542de041..88105edf58a6b6 100644 --- a/tests/layer_tests/pytorch_tests/test_polar.py +++ b/tests/layer_tests/pytorch_tests/test_polar.py @@ -1,39 +1,44 @@ # Copyright (C) 2018-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import numpy as np import pytest import torch -import numpy as np + from pytorch_layer_test_class import PytorchLayerTest class TestPolar(PytorchLayerTest): def _prepare_input(self): return ( - np.array([1.0, 2.0, 3.0], dtype=np.float32), - np.array([0.1, 0.2, 0.3], dtype=np.float32) + np.random.uniform(0, 10, (1, 1000)).astype(np.float32), + np.random.uniform(-np.pi, np.pi, (1, 1000)).astype(np.float32) ) def create_model(self): class PolarModel(torch.nn.Module): def forward(self, abs, angle): - real = abs * torch.cos(angle) - imag = abs * torch.sin(angle) - return torch.stack([real, imag], dim=-1) - return PolarModel(), None, None + complex_tensor = torch.polar(abs, angle) + return torch.view_as_real(complex_tensor) + return PolarModel(), None, "aten::polar" + + @pytest.mark.parametrize("input_case", [ + (1, 1000), + (2, 500), + (5, 200), + (10, 100), + ]) + @pytest.mark.parametrize("dtype", [ + np.float32, + np.float64 + ]) @pytest.mark.nightly @pytest.mark.precommit - @pytest.mark.parametrize("input_variant", ["static", "dynamic"]) - def test_polar(self, ie_device, precision, ir_version, input_variant): - atol = 1e-4 if precision == "FP32" else 1e-3 - rtol = 1e-4 - if input_variant == "static": - input_data = self._prepare_input() - else: - static_input = self._prepare_input() - input_data = ( - np.expand_dims(static_input[0], axis=0), - np.expand_dims(static_input[1], axis=0) - ) - self._test(*self.create_model(), ie_device, precision, ir_version, - input_data=input_data, model_trace=True, atol=atol, rtol=rtol) + def test_polar(self, input_case, dtype, ie_device, precision, ir_version): + self.input_shape = input_case + self._prepare_input = lambda: ( + np.random.uniform(0, 10, input_case).astype(dtype), + np.random.uniform(-np.pi, np.pi, input_case).astype(dtype) + ) + self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, + use_convert_model=True, kwargs_to_prepare_input={}) diff --git a/tests/layer_tests/pytorch_tests/test_randperm.py b/tests/layer_tests/pytorch_tests/test_randperm.py index d3c1ec4b7bac54..66124608bc4179 100644 --- a/tests/layer_tests/pytorch_tests/test_randperm.py +++ b/tests/layer_tests/pytorch_tests/test_randperm.py @@ -1,79 +1,57 @@ -# Copyright (C) 2018-2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - import pytest import torch import numpy as np from pytorch_layer_test_class import PytorchLayerTest, flattenize_inputs -from copy import deepcopy class TestRandperm(PytorchLayerTest): def _prepare_input(self): - return () + return (np.array([self.n], dtype=np.int64),) - def create_model(self, n): - class AtenRandperm(torch.nn.Module): - def __init__(self, n): + def create_model(self, n, num_inputs, dtype_value=None): + class aten_randperm(torch.nn.Module): + def __init__(self, n, num_inputs, dtype_value): super().__init__() - self.n = n - - def forward(self): - return torch.randperm(self.n, dtype=torch.int64) - - return AtenRandperm(n), None, "aten::randperm" - - def is_valid_permutation(self, output, n): - if hasattr(output, 'detach'): - arr = output.detach().cpu().numpy().astype(np.int64) - else: - arr = np.array(output, dtype=np.int64) - sorted_arr = np.sort(arr.flatten()) - expected = np.arange(n, dtype=np.int64) - return np.array_equal(sorted_arr, expected) - - @pytest.mark.parametrize("n", [1, 5, 10]) + self.n = torch.tensor(n, dtype=torch.int64) + self.num_inputs = num_inputs + self.dtype = torch.int64 if dtype_value == 4 else None + + def forward(self, x): + if self.num_inputs == 1: + return torch.randperm(self.n) + elif self.num_inputs == 2: + return torch.randperm(self.n, dtype=self.dtype) + elif self.num_inputs == 5: + return torch.randperm(self.n, dtype=self.dtype, layout=torch.strided, + device=x.device, pin_memory=False) + raise ValueError("Invalid num_inputs") + + return aten_randperm(n, num_inputs, dtype_value), None, "aten::randperm" + + @pytest.mark.parametrize(("n", "num_inputs", "dtype_value"), [ + (0, 1, None), + (1, 1, None), + (5, 1, None), + (5, 2, 4), + (5, 5, 4), + ]) @pytest.mark.nightly @pytest.mark.precommit - def test_randperm_custom(self, n, ie_device, precision, ir_version): - model, ref_net, op = self.create_model(n) + def test_randperm(self, n, num_inputs, dtype_value, ie_device, precision, ir_version): + self.n = n + model, ref_net, op = self.create_model(n, num_inputs, dtype_value) inputs = self._prepare_input() torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs] ov_inputs = flattenize_inputs(inputs) - trace_model = True - dynamic_shapes = True - freeze_model = True - - with torch.no_grad(): - smodel, converted_model = self.convert_directly_via_frontend( - model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model - ) - - from openvino import Core - core = Core() - compiled_model = core.compile_model(converted_model, ie_device). - ov_output_dict = compiled_model(()) - ov_output_tensor = list(ov_output_dict.values())[0] - - assert ov_output_tensor.shape[0] == n, f"Output shape {ov_output_tensor.shape} does not match expected ({n},)" - assert self.is_valid_permutation(ov_output_tensor, n), ( - f"Output {ov_output_tensor} is not a valid permutation of [0, 1, ..., {n-1}]" + smodel, converted_model = self.convert_directly_via_frontend( + model, torch_inputs, trace_model=True, dynamic_shapes=False, ov_inputs=ov_inputs, freeze_model=True ) - - @pytest.mark.xfail(reason="OpenVINO doesn't support empty tensors for randperm") - def test_randperm_zero(self, ie_device, precision, ir_version): - model, ref_net, op = self.create_model(0) - inputs = self._prepare_input() - torch_inputs = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in inputs] - ov_inputs = flattenize_inputs(inputs) - trace_model = True - dynamic_shapes = True - freeze_model = True - - with torch.no_grad(): - smodel, converted_model = self.convert_directly_via_frontend( - model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model - ) from openvino import Core core = Core() compiled_model = core.compile_model(converted_model, ie_device) - _ = compiled_model(()) + + ov_output = compiled_model(ov_inputs)[0] + if n > 0: + assert ov_output.shape[0] == n, f"Output shape {ov_output.shape} does not match expected ({n},)" + assert np.array_equal(np.sort(ov_output), np.arange(n)), f"Output is not a valid permutation of [0, ..., {n-1}]" + else: + assert ov_output.shape[0] == 0, f"Output shape for n=0 should be (0,), got {ov_output.shape}" From 6bd79143a6dcf4775e83e98c8aef46e40b7c89d4 Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Tue, 25 Mar 2025 22:57:07 +0530 Subject: [PATCH 12/13] Update src/frontends/pytorch/src/op_table.cpp Co-authored-by: Maxim Vafin --- src/frontends/pytorch/src/op_table.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 07003dc27b1901..37e7383fcbbdc1 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -644,7 +644,7 @@ const std::unordered_map get_supported_ops_ts() { {"aten::quantize_per_channel", op::translate_quantize_per_channel}, {"aten::quantize_per_tensor", op::translate_quantize_per_tensor}, {"aten::rand", op::translate_rand}, - {"aten::rand", op::translate_randperm}, + {"aten::randperm", op::translate_randperm}, {"aten::rand_like", op::translate_rand_like}, {"aten::randint", op::translate_randint}, {"aten::randn", op::translate_randn}, From 21f72459960a4a73bd022db5b0a4db1a6d88a6ab Mon Sep 17 00:00:00 2001 From: Vijay Kumar <82241873+vijaykr338@users.noreply.github.com> Date: Wed, 26 Mar 2025 01:01:11 +0530 Subject: [PATCH 13/13] Update test_polar.py --- tests/layer_tests/pytorch_tests/test_polar.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/tests/layer_tests/pytorch_tests/test_polar.py b/tests/layer_tests/pytorch_tests/test_polar.py index 88105edf58a6b6..9e862de003095a 100644 --- a/tests/layer_tests/pytorch_tests/test_polar.py +++ b/tests/layer_tests/pytorch_tests/test_polar.py @@ -4,14 +4,14 @@ import numpy as np import pytest import torch - +import openvino as ov from pytorch_layer_test_class import PytorchLayerTest class TestPolar(PytorchLayerTest): - def _prepare_input(self): + def _prepare_input(self, input_shape=(1, 1000), dtype=np.float32): return ( - np.random.uniform(0, 10, (1, 1000)).astype(np.float32), - np.random.uniform(-np.pi, np.pi, (1, 1000)).astype(np.float32) + np.random.uniform(0, 10, input_shape).astype(dtype), + np.random.uniform(-np.pi, np.pi, input_shape).astype(dtype) ) def create_model(self): @@ -19,14 +19,14 @@ class PolarModel(torch.nn.Module): def forward(self, abs, angle): complex_tensor = torch.polar(abs, angle) return torch.view_as_real(complex_tensor) - + ref_net = None return PolarModel(), None, "aten::polar" @pytest.mark.parametrize("input_case", [ - (1, 1000), - (2, 500), - (5, 200), - (10, 100), + (1, 1000), + (2, 500), + (5, 200), + (10, 100), ]) @pytest.mark.parametrize("dtype", [ np.float32, @@ -35,10 +35,8 @@ def forward(self, abs, angle): @pytest.mark.nightly @pytest.mark.precommit def test_polar(self, input_case, dtype, ie_device, precision, ir_version): - self.input_shape = input_case - self._prepare_input = lambda: ( - np.random.uniform(0, 10, input_case).astype(dtype), - np.random.uniform(-np.pi, np.pi, input_case).astype(dtype) - ) - self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, - use_convert_model=True, kwargs_to_prepare_input={}) + atol = 1e-4 if precision == "FP32" else 1e-3 + rtol = 1e-4 + self._test(*self.create_model(), ie_device, precision, ir_version, + kwargs_to_prepare_input={"input_shape": input_case, "dtype": dtype}, + trace_model=True, use_convert_model=True, custom_eps=atol, dynamic_shapes=False)