Skip to content

Commit b991c1e

Browse files
[GPU] Shift padding of crop created during buffer fusing
1 parent c4d6d2b commit b991c1e

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,26 @@ void prepare_buffer_fusing::run(program& p) {
847847
if (user_info.first) {
848848
node.get_users().front()->set_output_layout(user_info.second);
849849
}
850+
851+
// When the crop feeds the gemm weight (second input) and the weight rank is less than 4,
852+
// the weight shape is extended to 4 dims by prepending 1s at the front.
853+
// Therefore, the padding of crop_layout must be shifted right by the same number of dims.
854+
const size_t TDIM = 4;
855+
auto user = node.get_users().front();
856+
if (user->is_type<gemm>() && pred_layout.is_static() && user->get_dependency(1).id().compare(node.id()) == 0) {
857+
auto input_rank = user->get_kernel_impl_params()->typed_desc<gemm>()->weight_rank;
858+
if (input_rank < TDIM) {
859+
std::vector<int32_t> l_pad = {0, 0, 0, 0};
860+
std::vector<int32_t> u_pad = {0, 0, 0, 0};
861+
862+
//shift right
863+
size_t shift_right = TDIM - input_rank;
864+
std::copy_n(crop_layout.data_padding._lower_size.begin(), l_pad.size() - shift_right, l_pad.begin() + shift_right);
865+
std::copy_n(crop_layout.data_padding._upper_size.begin(), u_pad.size() - shift_right, u_pad.begin() + shift_right);
866+
867+
crop_layout.data_padding = padding(l_pad, u_pad);
868+
}
869+
}
850870
}
851871
node.set_output_layout(crop_layout);
852872
node.can_be_optimized(true);

src/plugins/intel_gpu/tests/unit/passes/prepare_buffer_fusing_test.cpp

+50
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
#include "intel_gpu/graph/program.hpp"
1313
#include "data_inst.h"
14+
#include "concatenation_inst.h"
15+
#include "gemm_inst.h"
1416
#include "crop_inst.h"
1517
#include "convolution_inst.h"
1618
#include "gather_inst.h"
@@ -707,6 +709,54 @@ TEST(prepare_buffer_fusing, in_place_crop_static) {
707709
ASSERT_EQ(output_ptr_2[i], out2[i]);
708710
}
709711

712+
// Builds the graph  concat_input -> concat(axis 2) -> crop -> gemm(input 1)  with a 4D
// gemm_input on gemm input 0, runs it with optimize_data enabled, and checks that:
//   1. the crop primitive reports can_be_optimized() == true, and
//   2. the reordered gemm output matches the reference values element by element.
// NOTE(review): the trailing gemm arguments (4, 3) are presumably input_rank / weight_rank,
// i.e. the crop result is consumed as a rank-3 weight - confirm against the gemm primitive.
TEST(prepare_buffer_fusing, in_place_crop_static_padding_and_gemm) {
    auto& engine = get_test_engine();

    auto gemm_input_mem = engine.allocate_memory({ {1, 4, 4, 2}, data_types::f32, format::bfyx });
    auto concat_input_mem = engine.allocate_memory({ {1, 4, 2}, data_types::f32, format::bfyx });

    set_values(gemm_input_mem, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
                                 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
                                 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
                                 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f });
    set_values(concat_input_mem, { -0.5f, 2.0f, 0.5f, 1.0f, 0.5f, -2.0f, -0.5f, -1.0f });

    // Reference output of the gemm, in the order it is read back from the output buffer.
    const std::vector<float> expected = { 0.5f, 4.f, 0.5f, 10.f, 0.5f, 16.f, 0.5f, 22.f,
                                          0.5f, 4.f, 0.5f, 10.f, 0.5f, 16.f, 0.5f, 22.f,
                                          0.5f, 4.f, 0.5f, 10.f, 0.5f, 16.f, 0.5f, 22.f,
                                          0.5f, 4.f, 0.5f, 10.f, 0.5f, 16.f, 0.5f, 22.f };
    const cldnn::tensor refSize = {1, 2, 1, 2};

    topology topology(
        input_layout("gemm_input", gemm_input_mem->get_layout()),
        input_layout("concat_input", concat_input_mem->get_layout()),
        // The same input is concatenated with itself along axis 2.
        concatenation("concat", { input_info("concat_input"), input_info("concat_input") }, 2),
        crop("crop", input_info("concat"), refSize, tensor(0, 0, 0, 0)),
        // The crop result is the second gemm input (dependency index 1).
        gemm("gemm", { input_info("gemm_input"), input_info("crop") }, data_types::f32, false, false, 1.0, 0.0, 4, 3),
        reorder("output", input_info("gemm"), format::bfyx, data_types::f32)
    );

    auto config = get_test_default_config(engine);
    config.set_property(ov::intel_gpu::optimize_data(true));
    network network(engine, topology, config);

    network.set_input_data("gemm_input", gemm_input_mem);
    network.set_input_data("concat_input", concat_input_mem);

    auto outputs = network.execute();

    // The crop must have been turned into an in-place (optimized-out) primitive.
    auto crop_prim = network.get_primitive("crop");
    ASSERT_TRUE(crop_prim->can_be_optimized());

    auto output = outputs.at("output").get_memory();
    cldnn::mem_lock<float> output_ptr(output, get_test_stream());
    for (size_t i = 0; i < expected.size(); ++i) {
        // Float-aware comparison; the index is reported so a mismatch can be located.
        ASSERT_FLOAT_EQ(output_ptr[i], expected[i]) << "mismatch at output index " << i;
    }
}
759+
710760
TEST(prepare_buffer_fusing, in_place_crop_dynamic) {
711761
auto& engine = get_test_engine();
712762

0 commit comments

Comments
 (0)