Skip to content

Commit b639368

Browse files
Update output shape to 4d for static reshape not using new_shape_infer
1 parent 693b983 commit b639368

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

src/plugins/intel_gpu/src/graph/eltwise.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,6 @@ eltwise_inst::typed_primitive_inst(network& network, eltwise_node const& node) :
396396
bool use_new_shape_infer = network.get_config().get_property(ov::intel_gpu::allow_new_shape_infer);
397397
auto input0_pshape = node.get_input_pshape(0);
398398

399-
if (!use_new_shape_infer && input0_pshape.size() < 4)
400-
input0_pshape.insert(input0_pshape.end(), 4 - input0_pshape.size(), 1);
401-
402399
for (size_t i = 1; i < inputs_count; ++i) {
403400
auto input_pshape = node.get_input_pshape(i);
404401

src/plugins/intel_gpu/src/graph/reshape.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ layout reshape_inst::calc_output_layout(reshape_node const& node, kernel_impl_pa
115115
if (desc->output_shape.count() == 0) {
116116
if (desc->output_partial_shape.size() != 0) {
117117
format out_fmt = format::adjust_to_rank(input_layout.format, desc->output_partial_shape.rank().get_length());
118-
return layout{desc->output_partial_shape, input_layout.data_type, out_fmt};
118+
auto use_new_shape_infer = node.get_program().is_new_shape_infer();
119+
auto output_shape = desc->output_partial_shape;
120+
if (!use_new_shape_infer && output_shape.size() < 4)
121+
output_shape.insert(output_shape.end(), 4 - output_shape.size(), 1);
122+
return layout{output_shape, input_layout.data_type, out_fmt};
119123
} else {
120124
OPENVINO_ASSERT("[GPU] Output shape is not provided");
121125
}

src/plugins/intel_gpu/tests/unit/test_cases/eltwise_gpu_test.cpp

+11-8
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
#include <intel_gpu/primitives/eltwise.hpp>
1111
#include <intel_gpu/primitives/gather.hpp>
1212
#include <intel_gpu/primitives/reorder.hpp>
13+
#include <intel_gpu/primitives/reshape.hpp>
1314
#include <intel_gpu/primitives/data.hpp>
1415

1516
#include "eltwise_inst.h"
17+
#include "reshape_inst.h"
1618

1719
using namespace cldnn;
1820
using namespace ::tests;
@@ -3442,7 +3444,7 @@ TEST(eltwise_gpu_f32, broadcast_test_dim3_dim4) {
34423444
TEST(eltwise_gpu_f32, broadcast_test_dim3_dim4_new_shape_infer_false) {
34433445
auto& engine = get_test_engine();
34443446

3445-
ov::Shape in2_shape = {1, 4, 1, 1};
3447+
ov::Shape in2_shape = {1, 1, 4, 1};
34463448
auto input2 = engine.allocate_memory({ ov::PartialShape(in2_shape), data_types::f32, format::bfyx });
34473449

34483450
std::vector<float> const_input = {
@@ -3457,26 +3459,27 @@ TEST(eltwise_gpu_f32, broadcast_test_dim3_dim4_new_shape_infer_false) {
34573459
});
34583460

34593461
float answers[16] = {
3460-
1.5, 0.5, 7.5, 4,
3461-
2.5, 0.5, 8.5, 7.7,
3462-
3.5, 1, 9.5, 14.5,
3463-
4.5, 0, 10.5, 10.5
3462+
1.5, 2.5, 5.5, 4,
3463+
2.5, 2.5, 6.5, 7.7,
3464+
3.5, 3, 7.5, 14.5,
3465+
4.5, 2, 8.5, 10.5
34643466
};
34653467

34663468
ExecutionConfig config = get_test_default_config(engine);
34673469
config.set_property(ov::intel_gpu::allow_new_shape_infer(false));
34683470

3469-
// in1:dim3, int2:dim4
3471+
// Eltwise in1:dim3, int2:dim4
34703472
{
3471-
ov::Shape in1_shape = {2, 4, 2};
3473+
ov::Shape in1_shape = {1, 2, 2, 4};
34723474

34733475
auto input = engine.allocate_memory({ ov::PartialShape(in1_shape), data_types::f32, format::bfyx });
34743476
set_values(input, const_input);
34753477

34763478
topology topology;
34773479
topology.add(input_layout("input", input->get_layout()));
34783480
topology.add(input_layout("input2", input2->get_layout()));
3479-
topology.add(eltwise("eltwise", { input_info("input"), input_info("input2") }, eltwise_mode::sum));
3481+
topology.add(reshape("reshape_input1", input_info("input"), false, {}, ov::PartialShape({2, 2, 4})));
3482+
topology.add(eltwise("eltwise", { input_info("reshape_input1"), input_info("input2") }, eltwise_mode::sum));
34803483

34813484
network network(engine, topology, config);
34823485

0 commit comments

Comments
 (0)