|
6 | 6 |
|
7 | 7 | #include "openvino/frontend/pytorch/node_context.hpp"
|
8 | 8 | #include "openvino/op/add.hpp"
|
| 9 | +#include "openvino/op/broadcast.hpp" |
9 | 10 | #include "openvino/op/concat.hpp"
|
10 | 11 | #include "openvino/op/constant.hpp"
|
11 | 12 | #include "openvino/op/equal.hpp"
|
12 | 13 | #include "openvino/op/if.hpp"
|
| 14 | +#include "openvino/op/non_zero.hpp" |
| 15 | +#include "openvino/op/not_equal.hpp" |
13 | 16 | #include "openvino/op/range.hpp"
|
| 17 | +#include "openvino/op/reshape.hpp" |
14 | 18 | #include "openvino/op/scatter_elements_update.hpp"
|
| 19 | +#include "openvino/op/shape_of.hpp" |
15 | 20 | #include "openvino/op/unsqueeze.hpp"
|
16 | 21 | #include "utils.hpp"
|
17 | 22 |
|
@@ -81,7 +86,47 @@ OutputVector translate_t(const NodeContext& context) {
|
81 | 86 | if_node->set_input(input, param_then, param_else);
|
82 | 87 | return {if_node->set_output(result_then, result_else)};
|
83 | 88 | }
|
84 |
| -} |
| 89 | +}; |
| 90 | + |
| 91 | +OutputVector translate_movedim(const NodeContext& context) { |
| 92 | + // aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) |
| 93 | + // aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) |
| 94 | + // based on https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TensorShape.cpp#L3816 |
| 95 | + num_inputs_check(context, 3, 3); |
| 96 | + auto x = context.get_input(0); |
| 97 | + auto src_dims = context.get_input(1); |
| 98 | + auto dst_dims = context.get_input(2); |
| 99 | + Output<Node> rank; |
| 100 | + std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true); |
| 101 | + src_dims = normalize_axis(context, src_dims, rank); |
| 102 | + dst_dims = normalize_axis(context, dst_dims, rank); |
| 103 | + auto const_0 = context.mark_node(v0::Constant::create(element::i32, {}, {0})); |
| 104 | + auto const_1 = context.mark_node(v0::Constant::create(element::i32, {}, {1})); |
| 105 | + auto range = context.mark_node(std::make_shared<v4::Range>(const_0, rank, const_1, element::i32)); |
| 106 | + auto dims_1d_shape = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); |
| 107 | + // operation accepts 0d and 1d source and destination, make them always 1d |
| 108 | + src_dims = context.mark_node(std::make_shared<v1::Reshape>(src_dims, dims_1d_shape, false)); |
| 109 | + dst_dims = context.mark_node(std::make_shared<v1::Reshape>(dst_dims, dims_1d_shape, false)); |
| 110 | + auto dims_shape = context.mark_node(std::make_shared<v3::ShapeOf>(src_dims, element::i32)); |
| 111 | + auto minus_one_replaces = context.mark_node(std::make_shared<v1::Broadcast>(dims_1d_shape, dims_shape)); |
| 112 | + // update position for the dim provided by user and mark used dims for source and destination as -1 |
| 113 | + auto perm_dims = context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(range, dst_dims, src_dims, const_0)); |
| 114 | + auto src_perm_dims = |
| 115 | + context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(range, src_dims, minus_one_replaces, const_0)); |
| 116 | + auto dst_perm_dims = |
| 117 | + context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(range, dst_dims, minus_one_replaces, const_0)); |
| 118 | + // Remove the dims whose position we already know, the ones marked with -1 in previous step |
| 119 | + auto not_changed_src = context.mark_node(std::make_shared<v1::NotEqual>(src_perm_dims, dims_1d_shape)); |
| 120 | + auto not_changed_dst = context.mark_node(std::make_shared<v1::NotEqual>(dst_perm_dims, dims_1d_shape)); |
| 121 | + auto indices = context.mark_node(std::make_shared<v3::NonZero>(not_changed_dst, element::i32)); |
| 122 | + auto updates = context.mark_node(std::make_shared<v3::NonZero>(not_changed_src, element::i32)); |
| 123 | + // Update the position of the remaining dimensions. indices now contains the original position |
| 124 | + // updates contains the new position it will shifted to after considering the user inputs. |
| 125 | + indices = context.mark_node(std::make_shared<v1::Reshape>(indices, dims_1d_shape, false)); |
| 126 | + updates = context.mark_node(std::make_shared<v1::Reshape>(updates, dims_1d_shape, false)); |
| 127 | + auto scatter = std::make_shared<v3::ScatterElementsUpdate>(perm_dims, indices, updates, const_0); |
| 128 | + return {context.mark_node(std::make_shared<v1::Transpose>(x, scatter))}; |
| 129 | +}; |
85 | 130 |
|
86 | 131 | } // namespace op
|
87 | 132 | } // namespace pytorch
|
|
0 commit comments