
Commit f05391b

[PT FE]: support aten::floor_divide_ and aten::movedim (openvinotoolkit#23446)
### Details:
- *aten::floor_divide_*
- *aten::movedim*

### Tickets:
- *ticket-id*
1 parent 8ba1ae3
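For reference, a minimal PyTorch snippet that exercises both newly supported ops (an illustrative sketch, not part of the commit; `FloorDivMoveDim` is a made-up name):

```python
import torch

class FloorDivMoveDim(torch.nn.Module):
    def forward(self, x, y):
        x = x.clone()
        x.floor_divide_(y)             # lowers to aten::floor_divide_
        return torch.movedim(x, 0, 2)  # lowers to aten::movedim

# Tracing produces the TorchScript graph that the PyTorch frontend converts.
example = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
traced = torch.jit.trace(FloorDivMoveDim(), example)
```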

4 files changed (+111 −1 lines changed)

src/frontends/pytorch/src/op/transpose.cpp (+46 −1)
@@ -6,12 +6,17 @@

 #include "openvino/frontend/pytorch/node_context.hpp"
 #include "openvino/op/add.hpp"
+#include "openvino/op/broadcast.hpp"
 #include "openvino/op/concat.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/equal.hpp"
 #include "openvino/op/if.hpp"
+#include "openvino/op/non_zero.hpp"
+#include "openvino/op/not_equal.hpp"
 #include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
 #include "openvino/op/scatter_elements_update.hpp"
+#include "openvino/op/shape_of.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "utils.hpp"

@@ -81,7 +86,47 @@ OutputVector translate_t(const NodeContext& context) {
         if_node->set_input(input, param_then, param_else);
         return {if_node->set_output(result_then, result_else)};
     }
-}
+};
+
+OutputVector translate_movedim(const NodeContext& context) {
+    // aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+    // aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+    // based on https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TensorShape.cpp#L3816
+    num_inputs_check(context, 3, 3);
+    auto x = context.get_input(0);
+    auto src_dims = context.get_input(1);
+    auto dst_dims = context.get_input(2);
+    Output<Node> rank;
+    std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
+    src_dims = normalize_axis(context, src_dims, rank);
+    dst_dims = normalize_axis(context, dst_dims, rank);
+    auto const_0 = context.mark_node(v0::Constant::create(element::i32, {}, {0}));
+    auto const_1 = context.mark_node(v0::Constant::create(element::i32, {}, {1}));
+    auto range = context.mark_node(std::make_shared<v4::Range>(const_0, rank, const_1, element::i32));
+    auto dims_1d_shape = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+    // operation accepts 0d and 1d source and destination, make them always 1d
+    src_dims = context.mark_node(std::make_shared<v1::Reshape>(src_dims, dims_1d_shape, false));
+    dst_dims = context.mark_node(std::make_shared<v1::Reshape>(dst_dims, dims_1d_shape, false));
+    auto dims_shape = context.mark_node(std::make_shared<v3::ShapeOf>(src_dims, element::i32));
+    auto minus_one_replaces = context.mark_node(std::make_shared<v1::Broadcast>(dims_1d_shape, dims_shape));
+    // update position for the dim provided by user and mark used dims for source and destination as -1
+    auto perm_dims = context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(range, dst_dims, src_dims, const_0));
+    auto src_perm_dims =
+        context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(range, src_dims, minus_one_replaces, const_0));
+    auto dst_perm_dims =
+        context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(range, dst_dims, minus_one_replaces, const_0));
+    // Remove the dims whose position we already know, the ones marked with -1 in the previous step
+    auto not_changed_src = context.mark_node(std::make_shared<v1::NotEqual>(src_perm_dims, dims_1d_shape));
+    auto not_changed_dst = context.mark_node(std::make_shared<v1::NotEqual>(dst_perm_dims, dims_1d_shape));
+    auto indices = context.mark_node(std::make_shared<v3::NonZero>(not_changed_dst, element::i32));
+    auto updates = context.mark_node(std::make_shared<v3::NonZero>(not_changed_src, element::i32));
+    // Update the position of the remaining dimensions. indices now contains the original position,
+    // updates contains the new position it will be shifted to after considering the user inputs.
+    indices = context.mark_node(std::make_shared<v1::Reshape>(indices, dims_1d_shape, false));
+    updates = context.mark_node(std::make_shared<v1::Reshape>(updates, dims_1d_shape, false));
+    auto scatter = std::make_shared<v3::ScatterElementsUpdate>(perm_dims, indices, updates, const_0);
+    return {context.mark_node(std::make_shared<v1::Transpose>(x, scatter))};
+};

 } // namespace op
 } // namespace pytorch
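
The converter assembles the Transpose permutation entirely from graph ops (Range, ScatterElementsUpdate, NonZero). A rough NumPy sketch of the same index arithmetic, for intuition only (`movedim_perm` is a hypothetical helper, not frontend code):

```python
import numpy as np

def movedim_perm(rank, src, dst):
    # Normalize negative axes, as normalize_axis does in the converter.
    src = np.atleast_1d(np.asarray(src)) % rank
    dst = np.atleast_1d(np.asarray(dst)) % rank
    perm = np.arange(rank)
    perm[dst] = src  # place the user-specified dims (first ScatterElementsUpdate)
    # Mark consumed source dims / destination slots with -1 (the Broadcast of -1).
    src_marked = np.arange(rank)
    src_marked[src] = -1
    dst_marked = np.arange(rank)
    dst_marked[dst] = -1
    # NonZero over the != -1 masks yields free slots and still-unplaced dims, in order.
    free_slots = np.nonzero(dst_marked != -1)[0]
    leftover_dims = np.nonzero(src_marked != -1)[0]
    perm[free_slots] = leftover_dims  # final ScatterElementsUpdate
    return perm

assert movedim_perm(4, 0, 2).tolist() == [1, 2, 0, 3]  # same order as np.moveaxis
```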

src/frontends/pytorch/src/op_table.cpp (+3 −0)
@@ -135,6 +135,7 @@ OP_CONVERTER(translate_mean);
 OP_CONVERTER(translate_meshgrid);
 OP_CONVERTER(translate_min);
 OP_CONVERTER(translate_minimum);
+OP_CONVERTER(translate_movedim);
 OP_CONVERTER(translate_multinomial);
 OP_CONVERTER(translate_narrow);
 OP_CONVERTER(translate_native_multi_head_attention);
@@ -433,6 +434,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::floor", op::optional_out<op::translate_1to1_match_1_inputs<opset10::Floor>, 1>},
         {"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
         {"aten::floor_divide", op::translate_floor_divide},
+        {"aten::floor_divide_", op::inplace_op<op::translate_floor_divide>},
         {"aten::floordiv", op::translate_floor_divide},
         {"aten::fmod", op::translate_fmod},
         {"aten::frobenius_norm", op::translate_frobenius_norm},
@@ -520,6 +522,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::mish", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Mish>},
         {"aten::mish_", op::inplace_op<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Mish>>},
         {"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
+        {"aten::movedim", op::translate_movedim},
         {"aten::mul", op::translate_mul},
         {"aten::mul_", op::translate_mul_},
         {"aten::multiply", op::translate_mul},

tests/layer_tests/pytorch_tests/test_floor_divide.py (+38 −0)
@@ -41,6 +41,21 @@ def forward(self, input_tensor, other_tensor):

         return aten_floor_divide(), ref_net, "aten::floor_divide"

+    def create_model_inplace(self):
+        import torch
+
+        class aten_floor_divide_(torch.nn.Module):
+            def __init__(self):
+                super(aten_floor_divide_, self).__init__()
+
+            def forward(self, input_tensor, other_tensor):
+                return input_tensor.floor_divide_(other_tensor), input_tensor
+
+        ref_net = None
+
+        return aten_floor_divide_(), ref_net, "aten::floor_divide_"
+
+
     @pytest.mark.parametrize('input_tensor',
                              ([
                                  [5], [5, 5, 1], [1, 1, 5, 5],
@@ -65,6 +80,29 @@ def test_floor_divide(self, input_tensor, other_tensor, ie_device, precision, ir_version):
         self.other_tensor = other_tensor
         self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)

+    @pytest.mark.parametrize('input_tensor',
+                             ([
+                                 [5, 5, 5], [1, 1, 5, 5],
+                             ]))
+    @pytest.mark.parametrize('other_tensor',
+                             ([
+                                 np.array([0.5]).astype(np.float32), [5], [5, 1], [1, 5]
+                             ]))
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
+                       reason='Ticket - 122715')
+    def test_floor_divide_(self, input_tensor, other_tensor, ie_device, precision, ir_version):
+        if type(input_tensor) is list:
+            self.input_tensor = np.random.randn(*input_tensor).astype(np.float32)
+        else:
+            self.input_tensor = input_tensor
+        if type(other_tensor) is list:
+            self.other_tensor = np.random.randn(*other_tensor).astype(np.float32)
+        else:
+            self.other_tensor = other_tensor
+        self._test(*self.create_model_inplace(), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)
+
     @pytest.mark.parametrize('input_data',
                              [
                                  { "tensor": [5], "low": 0, "high": 10 },

tests/layer_tests/pytorch_tests/test_transpose.py (+24 −0)
@@ -50,6 +50,30 @@ def test_transpose(self, dim0, dim1, op_type, ie_device, precision, ir_version):
         self._test(*self.create_model(dim0, dim1, op_type), ie_device, precision, ir_version, trace_model=True)


+class TestMoveDim(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(2, 3, 4, 5).astype(np.float32),)
+
+    def create_model(self, dim0, dim1):
+        class aten_move_dim(torch.nn.Module):
+            def __init__(self, dim0, dim1):
+                super(aten_move_dim, self).__init__()
+                self.dim0 = dim0
+                self.dim1 = dim1
+
+            def forward(self, x):
+                return torch.movedim(x, self.dim0, self.dim1)
+
+        ref_net = None
+
+        return aten_move_dim(dim0, dim1), ref_net, "aten::movedim"
+
+    @pytest.mark.parametrize(("dim0", "dim1"), [[0, 1], [-1, 0], [2, -2], [3, 1], [3, 3], [[1, 2], [3, 0]], [[-4, 1], [1, -1]], [[1, 3, 2], [0, 1, 2]]])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_move_dim(self, dim0, dim1, ie_device, precision, ir_version):
+        self._test(*self.create_model(dim0, dim1), ie_device, precision, ir_version, trace_model=True)
+
 class TestTSmall(PytorchLayerTest):
     def _prepare_input(self, num_dims=2, input_dtype="float32"):
         shape = (2, 3)
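
The parametrization exercises both `movedim` overloads, single ints (`movedim.int`) and lists (`movedim.intlist`), including negative axes and no-op moves like `[3, 3]`. Expected shape effects in plain PyTorch, for reference:

```python
import torch

x = torch.randn(2, 3, 4, 5)
print(torch.movedim(x, 0, 1).shape)            # movedim.int:     (3, 2, 4, 5)
print(torch.movedim(x, [1, 2], [3, 0]).shape)  # movedim.intlist: (4, 2, 5, 3)
```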
