
Commit 9ab2c1a

[PDPD] fix ops for baidu customer (#27222)
### Details:
- fix the negative dim issue when running into the `reduce_fusion` pass
- enable the `eye` op
- enable the `elu` op
- add tests
- upgrade PaddlePaddle to 2.6.2
- upgrade the opset to version 14

### Tickets:
- N/A
1 parent: 2ef42d4

File tree: 11 files changed (+175 / -5 lines)

src/frontends/paddle/src/default_opset.hpp

+2 -2

@@ -2,13 +2,13 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include "openvino/opsets/opset9.hpp"
+#include "openvino/opsets/opset14.hpp"
 
 namespace ov {
 namespace frontend {
 namespace paddle {
 namespace op {
-namespace default_opset = ov::opset9;
+namespace default_opset = ov::opset14;
 
 } // namespace op
 } // namespace paddle

src/frontends/paddle/src/op/elu.cpp

+23
@@ -0,0 +1,23 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "default_opset.hpp"
+#include "openvino/frontend/paddle/node_context.hpp"
+#include "openvino/frontend/paddle/visibility.hpp"
+
+namespace ov {
+namespace frontend {
+namespace paddle {
+namespace op {
+NamedOutputs elu(const NodeContext& node) {
+    auto data = node.get_input("X");
+    auto alpha = node.get_attribute<float>("alpha", 1.0);
+    const auto& elu_node = std::make_shared<default_opset::Elu>(data, alpha);
+    return node.default_single_output_mapping({elu_node}, {"Out"});
+}
+
+} // namespace op
+} // namespace paddle
+} // namespace frontend
+} // namespace ov
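
For reference, `default_opset::Elu` computes the standard ELU activation, and Paddle's `alpha` attribute defaults to 1.0. A minimal sketch of the per-element math (reference only, not from this commit):

    #include <cmath>

    // Reference-only sketch of the ELU activation (standard definition):
    // elu(x) = x                     if x > 0
    //        = alpha * (exp(x) - 1)  if x <= 0
    float elu_reference(float x, float alpha = 1.0f) {
        return x > 0.0f ? x : alpha * (std::exp(x) - 1.0f);
    }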

src/frontends/paddle/src/op/expand_v2.cpp

+9 -1

@@ -19,8 +19,16 @@ NamedOutputs expand_v2(const NodeContext& node) {
         auto inputs = node.get_ng_inputs("expand_shapes_tensor");
         ov::NodeVector node_vec;
         for (auto& input : inputs) {
+            if (input.get_partial_shape().rank().get_length() == 0) {
+                // should unsqueeze the input with non-shape.
+                auto unsqueeze_scalar = default_opset::Constant::create(ov::element::i32, {}, {0});
+                input = std::make_shared<default_opset::Unsqueeze>(input, unsqueeze_scalar);
+            }
+            PADDLE_OP_CHECK(node,
+                            input.get_partial_shape().rank().get_length() == 1,
+                            "the rank of conv input must == 1");
             auto cast = std::make_shared<Convert>(input, element::i32);
-            node_vec.push_back(cast);
+            node_vec.emplace_back(cast);
         }
         shape_expected_node = std::make_shared<Concat>(node_vec, 0);
     } else {
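
The unsqueeze added above is needed because Concat only joins tensors of rank 1 or higher: a 0-D (scalar) entry in `expand_shapes_tensor` is promoted to a one-element vector before the shape pieces are concatenated. A minimal sketch of the intended effect (plain C++, not from this commit; `concat_shape_pieces` is a hypothetical helper):

    #include <cstdint>
    #include <vector>

    // Illustration only: a scalar shape piece (e.g. 4) behaves like the 1-D
    // vector [4], so pieces such as [2], [4], [3] concatenate to [2, 4, 3].
    std::vector<int32_t> concat_shape_pieces(const std::vector<std::vector<int32_t>>& pieces) {
        std::vector<int32_t> shape;
        for (const auto& piece : pieces)
            shape.insert(shape.end(), piece.begin(), piece.end());
        return shape;
    }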

src/frontends/paddle/src/op/eye.cpp

+36
@@ -0,0 +1,36 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "default_opset.hpp"
+#include "openvino/frontend/paddle/node_context.hpp"
+
+namespace ov {
+namespace frontend {
+namespace paddle {
+namespace op {
+NamedOutputs eye(const NodeContext& node) {
+    auto row = node.get_attribute<int64_t>("num_rows");
+    auto col = node.get_attribute<int64_t>("num_columns", row);
+    auto dtype = node.get_attribute<ov::element::Type>("dtype", ov::element::f32);
+
+    const auto& row_node = std::make_shared<default_opset::Constant>(ov::element::i64, Shape{}, (row));
+    const auto& col_node = std::make_shared<default_opset::Constant>(ov::element::i64, Shape{}, (col));
+    const auto& diagonal_index_node = std::make_shared<default_opset::Constant>(ov::element::i32, Shape{}, (0));
+
+    std::shared_ptr<Node> out_node;
+    if (dtype == ov::element::i32 || dtype == ov::element::i64) {
+        out_node = std::make_shared<default_opset::Eye>(row_node, col_node, diagonal_index_node, dtype);
+    } else {
+        const auto& eye_node =
+            std::make_shared<default_opset::Eye>(row_node, col_node, diagonal_index_node, ov::element::i32);
+        out_node = std::make_shared<default_opset::Convert>(eye_node, dtype);
+    }
+
+    return node.default_single_output_mapping({out_node}, {"Out"});
+}
+
+} // namespace op
+} // namespace paddle
+} // namespace frontend
+} // namespace ov
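
For reference, Paddle's `eye` yields a `num_rows` x `num_columns` matrix with ones on the main diagonal and zeros elsewhere; for dtypes other than i32/i64 the converter above builds the result in i32 and appends a Convert. A minimal sketch of the computation (reference only, not from this commit):

    #include <cstdint>
    #include <vector>

    // Reference-only sketch of eye(num_rows, num_columns): ones on the main
    // diagonal, zeros elsewhere, e.g. eye(2, 3) -> {{1,0,0},{0,1,0}}.
    std::vector<std::vector<float>> eye_reference(int64_t rows, int64_t cols) {
        std::vector<std::vector<float>> m(rows, std::vector<float>(cols, 0.0f));
        for (int64_t i = 0; i < rows && i < cols; ++i)
            m[i][i] = 1.0f;
        return m;
    }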

src/frontends/paddle/src/op/fill_constant.cpp

+4
@@ -29,6 +29,10 @@ NamedOutputs fill_constant(const NodeContext& node) {
         PADDLE_OP_CHECK(node, false, "fill_constant only supports i32, f32, i64");
     }
 
+    if (shape.empty()) {
+        shape.emplace_back(1);
+    }
+
     PADDLE_OP_CHECK(node,
                     shape.size() > 0 || node.has_input("ShapeTensor") || node.has_input("ShapeTensorList"),
                     "fill_constant shape not set");

src/frontends/paddle/src/op/interp.cpp

+4 -2

@@ -4,6 +4,7 @@
 
 #include "default_opset.hpp"
 #include "openvino/frontend/paddle/node_context.hpp"
+#include "openvino/opsets/opset4.hpp"
 
 namespace ov {
 namespace frontend {
@@ -147,8 +148,9 @@ static NamedOutputs interpolate(const NodeContext& node,
     attrs.pads_begin = {0, 0, 0, 0};
     attrs.pads_end = {0, 0, 0, 0};
 
-    return node.default_single_output_mapping({std::make_shared<Interpolate>(x, target_spatial_shape, scales, attrs)},
-                                              {"Out"});
+    return node.default_single_output_mapping(
+        {std::make_shared<ov::opset4::Interpolate>(x, target_spatial_shape, scales, attrs)},
+        {"Out"});
 }
 
 NamedOutputs linear_interp_v2(const NodeContext& node) {

src/frontends/paddle/src/op/reduce_ops.hpp

+4
@@ -31,6 +31,10 @@ NamedOutputs reduce_ops(const NodeContext& node) {
         dims = node.get_attribute<std::vector<int64_t>>("dim");
     }
 
+    std::transform(dims.begin(), dims.end(), dims.begin(), [&input_rank](int64_t value) {
+        return value >= 0 ? value : value + input_rank;
+    });
+
     int64_t axis_size = static_cast<int64_t>(dims.size());
     reduce_all = reduce_all || (axis_size == input_rank || axis_size == 0);
 
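
The added `std::transform` normalizes negative reduction axes before the axis count is compared with the input rank; this is the negative-dim fix for the `reduce_fusion` pass mentioned in the commit message. A standalone sketch of the same normalization (reference only, not from this commit):

    #include <cstdint>
    #include <vector>

    // Negative Paddle "dim" values count from the end, so for input_rank = 4
    // the axes {-1, 2} normalize to {3, 2}.
    std::vector<int64_t> normalize_dims(std::vector<int64_t> dims, int64_t input_rank) {
        for (auto& d : dims) {
            if (d < 0)
                d += input_rank;
        }
        return dims;
    }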

src/frontends/paddle/src/op_table.cpp

+4
@@ -39,9 +39,11 @@ OP_CONVERTER(elementwise_sub);
 OP_CONVERTER(equal);
 OP_CONVERTER(greater_equal);
 OP_CONVERTER(not_equal);
+OP_CONVERTER(elu);
 OP_CONVERTER(embedding);
 OP_CONVERTER(exp);
 OP_CONVERTER(expand_v2);
+OP_CONVERTER(eye);
 OP_CONVERTER(flip);
 OP_CONVERTER(flatten_contiguous_range);
 OP_CONVERTER(floor);
@@ -173,9 +175,11 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
             {"elementwise_sub", op::elementwise_sub},
             {"dropout", op::dropout},
             {"elementwise_pow", op::elementwise_pow},
+            {"elu", op::elu},
             {"equal", op::equal},
             {"exp", op::exp},
             {"expand_v2", op::expand_v2},
+            {"eye", op::eye},
             {"fill_any_like", op::fill_any_like},
             {"fill_constant", op::fill_constant},
             {"fill_constant_batch_size_like", op::fill_constant_batch_size_like},

src/frontends/paddle/tests/op_fuzzy.cpp

+4
@@ -188,6 +188,7 @@ static const std::vector<std::string> models{
     std::string("elementwise_floordiv_int64_2/elementwise_floordiv_int64_2.pdmodel"),
     std::string("elementwise_floordiv_int64_3/elementwise_floordiv_int64_3.pdmodel"),
     std::string("elementwise_mul_bool1/elementwise_mul_bool1.pdmodel"),
+    std::string("elu/elu.pdmodel"),
     std::string("embedding_0/embedding_0.pdmodel"),
     std::string("embedding_sparse/embedding_sparse.pdmodel"),
     std::string("embedding_none_weight/embedding_none_weight.pdmodel"),
@@ -201,6 +202,9 @@ static const std::vector<std::string> models{
     std::string("expand_v2_tensor_list/expand_v2_tensor_list.pdmodel"),
     std::string("expand_v2_tensor_list2/expand_v2_tensor_list2.pdmodel"),
     std::string("exp_test_float32/exp_test_float32.pdmodel"),
+    std::string("eye/eye.pdmodel"),
+    std::string("eye_int32/eye_int32.pdmodel"),
+    std::string("eye_int64/eye_int64.pdmodel"),
     std::string("flip_1/flip_1.pdmodel"),
     std::string("flip_2/flip_2.pdmodel"),
     std::string("flip_3/flip_3.pdmodel"),

@@ -0,0 +1,44 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+#
+# elu paddle model generator
+#
+import numpy as np
+from save_model import saveModel
+import paddle
+import sys
+
+
+def elu(name: str, x, alpha=None, data_type='float32'):
+    paddle.enable_static()
+
+    with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
+        node_x = paddle.static.data(name='x', shape=x.shape, dtype=data_type)
+
+        if paddle.__version__ >= '2.0.0':
+            out = paddle.nn.functional.elu(node_x, alpha, name='elu')
+        else:
+            out = paddle.fluid.layers.elu(node_x, alpha, name='elu')
+        cpu = paddle.static.cpu_places(1)
+        exe = paddle.static.Executor(cpu[0])
+        # startup program will call initializer to initialize the parameters.
+        exe.run(paddle.static.default_startup_program())
+
+        outs = exe.run(
+            feed={'x': x},
+            fetch_list=[out])
+
+        saveModel(name, exe, feed_vars=[node_x], fetchlist=[out],
+                  inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
+
+    return outs[0]
+
+
+def main():
+    data_type = 'float32'
+    data = np.random.randn(2, 3, 4).astype('float32')
+    elu("elu", data)
+
+if __name__ == "__main__":
+    main()

@@ -0,0 +1,41 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+#
+# eye paddle model generator
+#
+import numpy as np
+from save_model import saveModel
+import paddle
+import sys
+
+
+def eye(name : str, rows, cols = None, dtype = None):
+    paddle.enable_static()
+    with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
+        if paddle.__version__ >= '2.0.0':
+            x1 = paddle.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill')
+            x2 = paddle.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill')
+        else:
+            x1 = paddle.fluid.layers.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill_constant')
+            x2 = paddle.fluid.layers.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill_constant')
+        out = paddle.add(x1, x2)
+        cpu = paddle.static.cpu_places(1)
+        exe = paddle.static.Executor(cpu[0])
+        # startup program will call initializer to initialize the parameters.
+        exe.run(paddle.static.default_startup_program())
+
+        outs = exe.run(
+            fetch_list=[out])
+
+        saveModel(name, exe, feed_vars=[], fetchlist=[out], inputs=[], outputs=[outs[0]], target_dir=sys.argv[1])
+
+    return outs[0]
+
+def main():
+    eye("eye", 3)
+    eye("eye_int32", 2, 3, "int32")
+    eye("eye_int64", 2, 3, "int64")
+
+if __name__ == "__main__":
+    main()
