Skip to content

Commit 8822480

Browse files
authored
#20927 support inputs that have no batch (#26778)
#20927
### Details:
- *add batch dimension before pool*
- *remove batch dimension after pool*
1 parent 55d8c47 commit 8822480

File tree

4 files changed

+231
-58
lines changed

4 files changed

+231
-58
lines changed

src/frontends/pytorch/src/op/avg_poolnd.cpp

+66-5
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@
33
//
44

55
#include "openvino/frontend/pytorch/node_context.hpp"
6+
#include "openvino/op/add.hpp"
67
#include "openvino/op/avg_pool.hpp"
78
#include "openvino/op/broadcast.hpp"
89
#include "openvino/op/concat.hpp"
910
#include "openvino/op/constant.hpp"
1011
#include "openvino/op/pad.hpp"
11-
#include "openvino/op/subtract.hpp"
12+
#include "openvino/op/reshape.hpp"
13+
#include "openvino/op/shape_of.hpp"
14+
#include "openvino/op/slice.hpp"
15+
#include "openvino/op/squeeze.hpp"
16+
#include "openvino/op/unsqueeze.hpp"
1217
#include "utils.hpp"
1318

1419
namespace ov {
@@ -17,10 +22,31 @@ namespace pytorch {
1722
namespace op {
1823

1924
using namespace ov::op;
20-
21-
OutputVector translate_avg_poolnd(const NodeContext& context) {
25+
OutputVector translate_avg_pool_base(const NodeContext& context, int dims) {
2226
num_inputs_check(context, 2, 7);
2327
auto input = context.get_input(0);
28+
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
29+
30+
auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
31+
auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});
32+
bool is_static = input.get_partial_shape().rank().is_static();
33+
bool no_batch_dim = is_static && input.get_partial_shape().rank().get_length() == dims + 1;
34+
35+
if (is_static) {
36+
if (no_batch_dim) {
37+
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
38+
}
39+
} else {
40+
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
41+
auto unsqueeze_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
42+
auto rank = context.mark_node(std::make_shared<v0::ShapeOf>(unsqueeze_shape));
43+
auto end_index = context.mark_node(std::make_shared<v1::Add>(rank, const_1));
44+
auto start_index = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims - 2}));
45+
auto reshape_pattern =
46+
context.mark_node(std::make_shared<v8::Slice>(unsqueeze_shape, start_index, end_index, const_1, const_0));
47+
input = context.mark_node(std::make_shared<v1::Reshape>(input, reshape_pattern, true));
48+
}
49+
2450
auto kernel = context.const_input<Shape>(1);
2551
Strides strides;
2652
if (!context.input_is_none(2)) {
@@ -47,8 +73,43 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
4773
}
4874
PYTORCH_OP_CONVERSION_CHECK(context.input_is_none(6),
4975
"Translation for aten::avg_pool2d do not support divisor_override input.");
50-
return {context.mark_node(
51-
std::make_shared<v14::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};
76+
auto res = context.mark_node(
77+
std::make_shared<v14::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type));
78+
79+
if (is_static) {
80+
if (no_batch_dim) {
81+
res = context.mark_node(std::make_shared<v0::Squeeze>(res, const_0));
82+
}
83+
} else {
84+
auto pooled_output_shape = context.mark_node(std::make_shared<v3::ShapeOf>(res));
85+
86+
auto start_index_input = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
87+
auto slice_input_shape =
88+
context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, start_index_input, const_1, const_0));
89+
90+
auto start_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
91+
auto end_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + dims}));
92+
auto slice_pooled_output_shape = context.mark_node(
93+
std::make_shared<v8::Slice>(pooled_output_shape, start_index_pooled, end_index_pooled, const_1, const_0));
94+
95+
auto concat_shape = context.mark_node(
96+
std::make_shared<v0::Concat>(OutputVector{slice_input_shape, slice_pooled_output_shape}, 0));
97+
res = context.mark_node(std::make_shared<v1::Reshape>(res, concat_shape, true));
98+
}
99+
100+
return {res};
101+
};
102+
103+
// Lowers aten::avg_pool1d by delegating to the shared average-pool
// conversion with one spatial dimension.
OutputVector translate_avg_pool1d(const NodeContext& context) {
    constexpr int spatial_dims = 1;
    return translate_avg_pool_base(context, spatial_dims);
};
106+
107+
// Lowers aten::avg_pool2d by delegating to the shared average-pool
// conversion with two spatial dimensions.
OutputVector translate_avg_pool2d(const NodeContext& context) {
    constexpr int spatial_dims = 2;
    return translate_avg_pool_base(context, spatial_dims);
};
110+
111+
// Lowers aten::avg_pool3d by delegating to the shared average-pool
// conversion with three spatial dimensions.
OutputVector translate_avg_pool3d(const NodeContext& context) {
    constexpr int spatial_dims = 3;
    return translate_avg_pool_base(context, spatial_dims);
};
53114

54115
} // namespace op

src/frontends/pytorch/src/op/max_poolnd.cpp

+94-11
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
#include "openvino/op/multiply.hpp"
1313
#include "openvino/op/pad.hpp"
1414
#include "openvino/op/range.hpp"
15+
#include "openvino/op/reshape.hpp"
1516
#include "openvino/op/select.hpp"
1617
#include "openvino/op/shape_of.hpp"
18+
#include "openvino/op/slice.hpp"
19+
#include "openvino/op/squeeze.hpp"
1720
#include "openvino/op/subtract.hpp"
21+
#include "openvino/op/unsqueeze.hpp"
1822
#include "openvino/op/util/framework_node.hpp"
1923
#include "utils.hpp"
2024

@@ -24,9 +28,31 @@ namespace pytorch {
2428
namespace op {
2529

2630
using namespace ov::op;
27-
28-
OutputVector translate_max_poolnd(const NodeContext& context) {
31+
OutputVector translate_max_pool_base(const NodeContext& context, int dims) {
2932
num_inputs_check(context, 3, 6);
33+
auto input = context.get_input(0);
34+
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
35+
36+
auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
37+
auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});
38+
bool is_static = input.get_partial_shape().rank().is_static();
39+
bool no_batch_dim = is_static && input.get_partial_shape().rank().get_length() == dims + 1;
40+
41+
if (is_static) {
42+
if (no_batch_dim) {
43+
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
44+
}
45+
} else {
46+
input = context.mark_node(std::make_shared<v0::Unsqueeze>(input, const_0));
47+
auto unsqueeze_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
48+
auto rank = context.mark_node(std::make_shared<v0::ShapeOf>(unsqueeze_shape));
49+
auto end_index = context.mark_node(std::make_shared<v1::Add>(rank, const_1));
50+
auto start_index = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims - 2}));
51+
auto reshape_pattern =
52+
context.mark_node(std::make_shared<v8::Slice>(unsqueeze_shape, start_index, end_index, const_1, const_0));
53+
input = context.mark_node(std::make_shared<v1::Reshape>(input, reshape_pattern, true));
54+
}
55+
3056
auto kernel = context.const_input<Shape>(1);
3157
Strides strides;
3258
if (!context.input_is_none(2)) {
@@ -53,7 +79,7 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
5379
rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL_TORCH : RoundingType::FLOOR;
5480
}
5581

56-
auto res = context.mark_node(std::make_shared<v14::MaxPool>(context.get_input(0),
82+
auto res = context.mark_node(std::make_shared<v14::MaxPool>(input,
5783
strides,
5884
dilations,
5985
pads,
@@ -63,19 +89,76 @@ OutputVector translate_max_poolnd(const NodeContext& context) {
6389
PadType::EXPLICIT,
6490
element::i64,
6591
2));
66-
if (context.get_output_size() == 2) {
67-
auto out1 = res->output(0);
68-
auto out2 = res->output(1);
69-
return {std::move(out1), std::move(out2)};
92+
if (is_static) {
93+
if (no_batch_dim) {
94+
if (context.get_output_size() == 2) {
95+
auto out1 = res->output(0);
96+
auto out2 = res->output(1);
97+
out1 = context.mark_node(std::make_shared<v0::Squeeze>(out1, const_0));
98+
out2 = context.mark_node(std::make_shared<v0::Squeeze>(out2, const_0));
99+
return {std::move(out1), std::move(out2)};
100+
} else {
101+
res = context.mark_node(std::make_shared<v0::Squeeze>(res, const_0));
102+
return {res};
103+
}
104+
} else {
105+
if (context.get_output_size() == 2) {
106+
auto out1 = res->output(0);
107+
auto out2 = res->output(1);
108+
return {std::move(out1), std::move(out2)};
109+
} else {
110+
return {res};
111+
}
112+
}
113+
70114
} else {
71-
return {res};
115+
auto pooled_output_shape = context.mark_node(std::make_shared<v3::ShapeOf>(res));
116+
117+
auto start_index_input = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
118+
auto slice_input_shape =
119+
context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, start_index_input, const_1, const_0));
120+
121+
auto start_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-dims}));
122+
auto end_index_pooled = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + dims}));
123+
auto slice_pooled_output_shape = context.mark_node(
124+
std::make_shared<v8::Slice>(pooled_output_shape, start_index_pooled, end_index_pooled, const_1, const_0));
125+
126+
auto concat_shape = context.mark_node(
127+
std::make_shared<v0::Concat>(OutputVector{slice_input_shape, slice_pooled_output_shape}, 0));
128+
if (context.get_output_size() == 2) {
129+
auto out1 = res->output(0);
130+
auto out2 = res->output(1);
131+
out1 = context.mark_node(std::make_shared<v1::Reshape>(out1, concat_shape, true));
132+
out2 = context.mark_node(std::make_shared<v1::Reshape>(out2, concat_shape, true));
133+
return {std::move(out1), std::move(out2)};
134+
} else {
135+
res = context.mark_node(std::make_shared<v1::Reshape>(res, concat_shape, true));
136+
return {res};
137+
}
72138
}
73139
};
74140

75-
OutputVector translate_max_poolnd_fx(const NodeContext& context) {
76-
auto output = translate_max_poolnd(context);
141+
// Lowers aten::max_pool1d (and its _with_indices form) via the shared
// max-pool conversion with one spatial dimension.
OutputVector translate_max_pool1d(const NodeContext& context) {
    constexpr int spatial_dims = 1;
    return translate_max_pool_base(context, spatial_dims);
};
144+
145+
// Lowers aten::max_pool2d (and its _with_indices form) via the shared
// max-pool conversion with two spatial dimensions.
OutputVector translate_max_pool2d(const NodeContext& context) {
    constexpr int spatial_dims = 2;
    return translate_max_pool_base(context, spatial_dims);
};
148+
149+
// Lowers aten::max_pool3d (and its _with_indices form) via the shared
// max-pool conversion with three spatial dimensions.
OutputVector translate_max_pool3d(const NodeContext& context) {
    constexpr int spatial_dims = 3;
    return translate_max_pool_base(context, spatial_dims);
};
152+
153+
// FX variant of the 2-D max-pool conversion: wraps the converted outputs
// in a list-construct node, which is the result shape the FX graph expects.
OutputVector translate_max_pool2d_fx(const NodeContext& context) {
    const auto outputs = translate_max_pool2d(context);
    return {context.mark_node(make_list_construct(outputs))};
};
157+
158+
// FX variant of the 3-D max-pool conversion: wraps the converted outputs
// in a list-construct node, which is the result shape the FX graph expects.
OutputVector translate_max_pool3d_fx(const NodeContext& context) {
    const auto outputs = translate_max_pool3d(context);
    return {context.mark_node(make_list_construct(outputs))};
};
79162

80163
} // namespace op
81164
} // namespace pytorch

src/frontends/pytorch/src/op_table.cpp

+21-16
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ OP_CONVERTER(translate_argmax);
4242
OP_CONVERTER(translate_argmin);
4343
OP_CONVERTER(translate_as_strided);
4444
OP_CONVERTER(translate_as_tensor);
45-
OP_CONVERTER(translate_avg_poolnd);
45+
OP_CONVERTER(translate_avg_pool1d);
46+
OP_CONVERTER(translate_avg_pool2d);
47+
OP_CONVERTER(translate_avg_pool3d);
4648
OP_CONVERTER(translate_bool);
4749
OP_CONVERTER(translate_batch_norm);
4850
OP_CONVERTER(translate_bitwise_and);
@@ -139,7 +141,9 @@ OP_CONVERTER(translate_masked_scatter);
139141
OP_CONVERTER(translate_masked_select);
140142
OP_CONVERTER(translate_max);
141143
OP_CONVERTER(translate_maximum);
142-
OP_CONVERTER(translate_max_poolnd);
144+
OP_CONVERTER(translate_max_pool1d);
145+
OP_CONVERTER(translate_max_pool2d);
146+
OP_CONVERTER(translate_max_pool3d);
143147
OP_CONVERTER(translate_mean);
144148
OP_CONVERTER(translate_meshgrid);
145149
OP_CONVERTER(translate_min);
@@ -281,7 +285,8 @@ OP_CONVERTER(translate_leaky_relu_fx);
281285
OP_CONVERTER(translate_log_sigmoid_fx);
282286
OP_CONVERTER(translate_log_softmax_fx);
283287
OP_CONVERTER(translate_max_dim_fx);
284-
OP_CONVERTER(translate_max_poolnd_fx);
288+
OP_CONVERTER(translate_max_pool2d_fx);
289+
OP_CONVERTER(translate_max_pool3d_fx);
285290
OP_CONVERTER(translate_mean_fx);
286291
OP_CONVERTER(translate_min_dim_fx);
287292
OP_CONVERTER(translate_new_full_fx);
@@ -380,9 +385,9 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
380385
{"aten::atanh",
381386
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>, 1>},
382387
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
383-
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
384-
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
385-
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
388+
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_pool1d>},
389+
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_pool2d>},
390+
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_pool3d>},
386391
{"aten::baddbmm", op::translate_addmm},
387392
{"aten::batch_norm", op::translate_batch_norm},
388393
{"aten::bitwise_and", op::translate_bitwise_and},
@@ -534,12 +539,12 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
534539
{"aten::max", op::translate_max},
535540
{"aten::mv", op::translate_1to1_match_2_inputs<opset10::MatMul>},
536541
{"aten::maximum", op::translate_maximum},
537-
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
538-
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
539-
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
540-
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
541-
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
542-
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_poolnd>},
542+
{"aten::max_pool1d", op::quantizable_op<op::translate_max_pool1d>},
543+
{"aten::max_pool1d_with_indices", op::quantizable_op<op::translate_max_pool1d>},
544+
{"aten::max_pool2d", op::quantizable_op<op::translate_max_pool2d>},
545+
{"aten::max_pool2d_with_indices", op::quantizable_op<op::translate_max_pool2d>},
546+
{"aten::max_pool3d", op::quantizable_op<op::translate_max_pool3d>},
547+
{"aten::max_pool3d_with_indices", op::quantizable_op<op::translate_max_pool3d>},
543548
{"aten::mean", op::quantizable_op<op::translate_mean>},
544549
{"aten::meshgrid", op::translate_meshgrid},
545550
{"aten::min", op::translate_min},
@@ -771,8 +776,8 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
771776
{"aten.asinh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asinh>},
772777
{"aten.atan.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atan>},
773778
{"aten.atanh.default", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
774-
{"aten.avg_pool2d.default", op::translate_avg_poolnd},
775-
{"aten.avg_pool3d.default", op::translate_avg_poolnd},
779+
{"aten.avg_pool2d.default", op::translate_avg_pool2d},
780+
{"aten.avg_pool3d.default", op::translate_avg_pool3d},
776781
{"aten.baddbmm.default", op::translate_addmm_fx},
777782
{"aten.bitwise_and.Scalar", op::translate_bitwise_and},
778783
{"aten.bitwise_and.Tensor", op::translate_bitwise_and},
@@ -870,8 +875,8 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
870875
{"aten.masked_fill_.Tensor", op::inplace_op<op::translate_masked_fill>},
871876
{"aten.max.default", op::translate_max},
872877
{"aten.max.dim", op::translate_max_dim_fx},
873-
{"aten.max_pool2d_with_indices.default", op::translate_max_poolnd_fx},
874-
{"aten.max_pool3d_with_indices.default", op::translate_max_poolnd_fx},
878+
{"aten.max_pool2d_with_indices.default", op::translate_max_pool2d_fx},
879+
{"aten.max_pool3d_with_indices.default", op::translate_max_pool3d_fx},
875880
{"aten.maximum.default", op::translate_maximum},
876881
{"aten.mean.default", op::translate_mean_fx},
877882
{"aten.mean.dim", op::translate_mean_fx},

0 commit comments

Comments (0)