Skip to content

Commit 2182709

Browse files
authored
[GPU] Separate input and weight rank check for reordered case in gemm (openvinotoolkit#29430)
### Details: - Separate the input and weight rank checks for the reordered case in gemm, to avoid an exception caused by the non-reordered input (input0 in the layer below) ![image](https://github.com/user-attachments/assets/283e7a7c-1bb2-42d7-a655-3dad6531850f) - convert_data_tensor() returns a wrong data tensor when the size of format.dims_order() differs from shape.size(), so gemm needs to set a format matching shape.size() in the input/output layouts. ### Tickets: - 163982
1 parent 89c9c7a commit 2182709

File tree

5 files changed

+27
-7
lines changed

5 files changed

+27
-7
lines changed

src/plugins/intel_gpu/src/graph/gemm.cpp

+14-3
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,12 @@ std::vector<layout> gemm_inst::calc_output_layouts(gemm_node const& node, const
132132
prim->output_transpose_order);
133133

134134
cldnn::format output_format = input0_layout.format;
135+
if (output_shapes[0].size() > output_format.dimension()) {
136+
// This happens when input0.rank=2, input1.rank=5, and output0.rank=5.
137+
// Output should use format like bfzyx, but it was taken from input0_layout which is bfyx.
138+
// Therefore, adjust output_format to the proper rank (e.g., bfzyx).
139+
output_format = cldnn::format::adjust_to_rank(output_format, output_shapes[0].size());
140+
}
135141
if (node.get_preferred_output_fmt() != format::any)
136142
output_format = node.get_preferred_output_fmt();
137143

@@ -141,7 +147,8 @@ std::vector<layout> gemm_inst::calc_output_layouts(gemm_node const& node, const
141147
template std::vector<layout> gemm_inst::calc_output_layouts<ov::PartialShape>(gemm_node const& node, const kernel_impl_params& impl_param);
142148

143149
std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<const gemm> primitive,
144-
const std::vector<layout>& input_layouts) {
150+
const std::vector<layout>& input_layouts,
151+
const bool allow_new_shape_infer) {
145152
auto get_transposed_input_shape = [&](const ov::PartialShape& input_pshape, size_t input_rank, size_t output_rank, bool transpose, bool first_input) {
146153
ov::PartialShape transposed_input_pshape;
147154

@@ -181,13 +188,17 @@ std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<con
181188

182189
bool reordered = primitive->input_rank > 4 || primitive->weight_rank > 4;
183190
size_t output_rank = std::max(primitive->input_rank, primitive->weight_rank);
184-
size_t input_rank = reordered ? output_rank : primitive->input_rank;
185-
size_t weight_rank = reordered ? output_rank : primitive->weight_rank;
191+
// No need to get output_rank for rank>4 inputs when allow_new_shape_infer=true
192+
size_t input_rank = (reordered && !allow_new_shape_infer) ? output_rank : primitive->input_rank;
193+
size_t weight_rank = (reordered && !allow_new_shape_infer) ? output_rank : primitive->weight_rank;
186194

187195
auto transposed_input0_pshape = get_transposed_input_shape(input0_pshape, input_rank, output_rank, primitive->transpose_input0, true);
188196
auto transposed_input1_pshape = get_transposed_input_shape(input1_pshape, weight_rank, output_rank, primitive->transpose_input1, false);
189197

190198
std::vector<layout> layouts = input_layouts;
199+
// Format update for rank > 4 case
200+
if (layouts[0].format.dimension() < transposed_input0_pshape.size())
201+
layouts[0].format = cldnn::format::get_default_format(transposed_input0_pshape.size());
191202
layouts[0].set_partial_shape(transposed_input0_pshape);
192203
layouts[1].set_partial_shape(transposed_input1_pshape);
193204

src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ struct gemm_impl : multi_stage_primitive<gemm> {
275275
const auto& primitive = impl_params.typed_desc<gemm>();
276276
auto updated_impl_params = canonicalize_fused_shapes(impl_params);
277277

278-
updated_impl_params.input_layouts = gemm_inst::transform_input_layouts(primitive, impl_params.input_layouts);
278+
updated_impl_params.input_layouts = gemm_inst::transform_input_layouts(primitive, impl_params.input_layouts,
279+
impl_params.get_program().is_new_shape_infer());
279280
updated_impl_params.output_layouts[0] = gemm_inst::transform_output_layout(primitive, updated_impl_params.input_layouts, impl_params.output_layouts[0]);
280281

281282
for (auto& input_layout : updated_impl_params.input_layouts) {

src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
8484
if (gemm_with_bias) {
8585
in_layouts.emplace_back(impl_params.get_input_layout(2));
8686
}
87-
in_layouts = gemm_inst::transform_input_layouts(prim, in_layouts);
87+
in_layouts = gemm_inst::transform_input_layouts(prim, in_layouts, impl_params.get_program().is_new_shape_infer());
8888
out_l = gemm_inst::transform_output_layout(prim, in_layouts, out_l);
8989

9090
const auto& in0_l = in_layouts[0];

src/plugins/intel_gpu/src/graph/include/gemm_inst.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class typed_primitive_inst<gemm> : public typed_primitive_inst_base<gemm> {
3434
static std::string to_string(gemm_node const& node);
3535

3636
static std::vector<layout> transform_input_layouts(const std::shared_ptr<const gemm> primitive,
37-
const std::vector<layout>& input_layouts);
37+
const std::vector<layout>& input_layouts,
38+
const bool allow_new_shape_infer);
3839
static layout transform_output_layout(const std::shared_ptr<const gemm> primitive, const std::vector<layout>& input_layouts, const layout& output_layout);
3940

4041
static bool is_fusable_permute_input_order_onednn(const std::vector<size_t>& permute_order, format& fmt) {

src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/matmul.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,14 @@ const std::vector<ShapeRelatedParams> IS_Dynamic = {
575575
{{ {1, 5}, 12, -1, 4 }, {{ 1, 12, 16, 4 }, { 1, 12, 16, 4 }}} // input 1
576576
},
577577
{false, false}
578-
}
578+
},
579+
{
580+
{ // dynamic case description: for each input, a pair of {{dynamic shape}, {{static shape case1}, {static shape case2}, ...}}
581+
{{}, {{64, 64}}}, // input 0
582+
{{-1, 128, 33, 64, 1}, {{1, 128, 33, 64, 1}}} // input 1
583+
},
584+
{false, false}
585+
},
579586
};
580587

581588
const std::vector<ShapeRelatedParams> IS_Dynamic_nightly = {

0 commit comments

Comments
 (0)