
Commit 00b7c0f

[GPU] Fix reshape output padding
1 parent 7c0db8b · commit 00b7c0f

2 files changed: +78 −1 lines


src/plugins/intel_gpu/src/graph/reshape.cpp

+5 −1
@@ -228,7 +228,11 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& node,
         output_format = node.get_preferred_output_fmt();
     }
 
-    return { layout {output_shapes[0], input_layout.data_type, format::adjust_to_rank(output_format, output_shapes[0].size()), out_pad} };
+    auto new_out_pad = out_pad;
+    if (new_out_pad == padding())
+        new_out_pad = impl_param.get_output_layout(0).data_padding;
+
+    return { layout {output_shapes[0], input_layout.data_type, format::adjust_to_rank(output_format, output_shapes[0].size()), new_out_pad} };
 }
 
 template std::vector<layout> reshape_inst::calc_output_layouts<ov::PartialShape>(reshape_node const& node, const kernel_impl_params& impl_param);

src/plugins/intel_gpu/tests/unit/test_cases/reshape_gpu_test.cpp

+73
@@ -1688,3 +1688,76 @@ TEST(reshape_gpu_f32, followed_by_convolution_dynamic) {
         }
     }
 }
+
+TEST(reshape_gpu_f32, followed_by_convolution_dynamic_w_pad) {
+    auto& engine = get_test_engine();
+
+    ov::Shape in0_shape = { 1, 1, 4, 5 };
+    auto in0_dyn_layout = layout{ov::PartialShape::dynamic(in0_shape.size()), data_types::f32, format::bfyx};
+    auto weights = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 3, 2 } });
+    set_values(weights, {
+        1.0f, 2.0f, 1.0f,
+        2.0f, 1.0f, 2.0f
+    });
+
+    topology topology(
+        input_layout("input", in0_dyn_layout),
+        shape_of("shape_of_input", input_info("input"), data_types::i32),
+        reshape("reshape", input_info("input"), input_info("shape_of_input"), false, ov::PartialShape::dynamic(4),
+                cldnn::reshape::reshape_mode::base, padding({0, 0, 1, 1}, {0, 0, 2, 2})),
+        data("weights", weights),
+        pooling("pooling", input_info("weights"), pooling_mode::max, ov::Shape{3, 3}, { 1, 1 }, {0, 0}, {0, 0}, tensor(3, 3, 1, 1), data_types::f32),
+        convolution("conv", input_info("reshape"), "pooling", "", 1, { 1, 1 }, {1, 1}, {2, 2}, {0, 0}, false)
+    );
+
+    ExecutionConfig config = get_test_default_config(engine);
+    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+    config.set_property(ov::intel_gpu::allow_static_input_reorder(true));
+
+    network network(engine, topology, config);
+
+    // execute
+    {
+        auto input0 = engine.allocate_memory({ in0_shape, data_types::f32, format::bfyx });
+        set_values(input0, {
+            1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+            2.0f, 2.0f, 3.0f, 4.0f, 6.0f,
+            3.0f, 3.0f, 3.0f, 5.0f, 1.0f,
+            1.0f, 1.0f, 1.0f, 1.0f, 1.0f
+        });
+        network.set_input_data("input", input0);
+
+        auto outputs = network.execute();
+
+        // check 'conv'
+        auto output_memory = outputs.at("conv").get_memory();
+        auto output_layout = output_memory->get_layout();
+        cldnn::mem_lock<float> output_ptr(output_memory, get_test_stream());
+
+        int y_size = output_layout.spatial(1);
+        int x_size = output_layout.spatial(0);
+        int f_size = output_layout.feature();
+        int b_size = output_layout.batch();
+
+        ASSERT_EQ(output_layout.format, format::bfyx);
+        ASSERT_EQ(y_size, 6);
+        ASSERT_EQ(x_size, 7);
+        ASSERT_EQ(f_size, 1);
+        ASSERT_EQ(b_size, 1);
+
+        VVF<float> output_vec = {
+            { 0, 0, 0, 0, 0, 0, 0 },
+            { 0, 0, 0, 0, 0, 0, 0 },
+            { 0, 0, 2, 4, 6, 8, 10 },
+            { 0, 0, 4, 4, 6, 8, 12 },
+            { 0, 0, 6, 6, 6, 10, 2 },
+            { 0, 0, 2, 2, 2, 2, 2 }
+        };
+
+        for (int y = 0; y < y_size; ++y) {
+            for (int x = 0; x < x_size; ++x) {
+                ASSERT_EQ(output_vec[y][x], output_ptr[y * x_size + x]);
+            }
+        }
+    }
+}
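The new test builds a dynamically shaped 1x1x4x5 input, reshapes it with an explicit output padding (lower {0, 0, 1, 1}, upper {0, 0, 2, 2}) and feeds the result to a convolution, so it only passes if that padding survives calc_output_layouts. As a small worked example of what such a padding implies for the allocated buffer (the padded extent of an axis is its logical size plus the lower and upper pads; the exact axis order of the padding constructor is an assumption here, though both spatial pads are equal anyway):

// padded_extent_sketch.cpp -- illustrative only; "padded_extent" is a
// hypothetical helper, not part of cldnn.
#include <iostream>

int padded_extent(int logical, int pad_lower, int pad_upper) {
    // Padded buffer extent = logical size + lower pad + upper pad.
    return logical + pad_lower + pad_upper;
}

int main() {
    // Reshape output in the test: logical spatial size 4 (y) by 5 (x),
    // spatial padding lower 1 and upper 2 on each axis.
    std::cout << "padded y extent: " << padded_extent(4, 1, 2) << "\n";  // 7
    std::cout << "padded x extent: " << padded_extent(5, 1, 2) << "\n";  // 8
    return 0;
}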
