Skip to content

Commit 133b139

Browse files
[GPU] Extend gemm to fuse broadcast and reshape layers (openvinotoolkit#23513)
### Details: - Fuse `broadcast` and `reshape` layers into `gemm` layer for LLM's 2nd latency optimization - before : [`broadcast`] --> [`reshape`] --> `gemm` - after : `gemm` - `gemm` is extended to have `input0_target_shape`, `input1_target_shape`, `input0_output_pattern` and `input1_output_pattern` from `broadcast` and `reshape` layers ### Tickets: - 128343 --------- Signed-off-by: Andrew Park <andrew.park@intel.com>
1 parent dbef32e commit 133b139

19 files changed

+925
-109
lines changed

src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp

+26-3
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,30 @@ class Gemm : public ov::op::v0::MatMul {
2626
const std::vector<int64_t>& order_c,
2727
const ov::element::Type output_type = ov::element::undefined);
2828

29+
Gemm(const ov::Output<Node>& A,
30+
const ov::Output<Node>& B,
31+
const std::vector<int32_t>& target_shape_a,
32+
const std::vector<int32_t>& target_shape_b,
33+
const std::vector<int64_t>& output_pattern_a,
34+
const std::vector<int64_t>& output_pattern_b,
35+
const std::vector<int64_t>& order_a,
36+
const std::vector<int64_t>& order_b,
37+
const std::vector<int64_t>& order_c,
38+
const ov::element::Type output_type = ov::element::undefined);
39+
2940
bool visit_attributes(ov::AttributeVisitor &visitor) override;
3041

3142
void validate_and_infer_types() override;
3243

3344
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
3445

35-
std::vector<int64_t> get_input0_order() const { return m_order_a; }
36-
std::vector<int64_t> get_input1_order() const { return m_order_b; }
37-
std::vector<int64_t> get_output_order() const { return m_order_c; }
46+
std::vector<int32_t> get_input0_broadcast_target_shape() const { return m_target_shape_a; }
47+
std::vector<int32_t> get_input1_broadcast_target_shape() const { return m_target_shape_b; }
48+
std::vector<int64_t> get_input0_reshape_pattern() const { return m_output_pattern_a; }
49+
std::vector<int64_t> get_input1_reshape_pattern() const { return m_output_pattern_b; }
50+
std::vector<int64_t> get_input0_transpose_order() const { return m_order_a; }
51+
std::vector<int64_t> get_input1_transpose_order() const { return m_order_b; }
52+
std::vector<int64_t> get_output_transpose_order() const { return m_order_c; }
3853
ov::element::Type get_output_type() const { return m_output_type; }
3954

4055
static std::vector<int64_t> default_order(size_t rank) {
@@ -44,6 +59,10 @@ class Gemm : public ov::op::v0::MatMul {
4459
}
4560

4661
protected:
62+
std::vector<int32_t> m_target_shape_a;
63+
std::vector<int32_t> m_target_shape_b;
64+
std::vector<int64_t> m_output_pattern_a;
65+
std::vector<int64_t> m_output_pattern_b;
4766
std::vector<int64_t> m_order_a;
4867
std::vector<int64_t> m_order_b;
4968
std::vector<int64_t> m_order_c;
@@ -52,6 +71,10 @@ class Gemm : public ov::op::v0::MatMul {
5271

5372
std::vector<ov::PartialShape> shape_infer(const Gemm* op,
5473
std::vector<ov::PartialShape> input_shapes,
74+
const std::vector<int32_t>& target_shape_a,
75+
const std::vector<int32_t>& target_shape_b,
76+
const std::vector<int64_t>& output_pattern_a,
77+
const std::vector<int64_t>& output_pattern_b,
5578
const std::vector<int64_t>& order_a,
5679
const std::vector<int64_t>& order_b,
5780
const std::vector<int64_t>& order_c);

src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp

+71-38
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ struct gemm : public primitive_base<gemm> {
5454
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
5555
transpose_input0(transpose_input0 ? 1 : 0),
5656
transpose_input1(transpose_input1 ? 1 : 0),
57+
input0_broadcast_target_shape({}),
58+
input1_broadcast_target_shape({}),
59+
input0_reshape_pattern({}),
60+
input1_reshape_pattern({}),
5761
alpha(alpha),
5862
beta(beta),
5963
input_rank(input_rank),
@@ -70,9 +74,9 @@ struct gemm : public primitive_base<gemm> {
7074
return order;
7175
};
7276

73-
input0_order = get_transposed_order(input_rank, transpose_input0);
74-
input1_order = get_transposed_order(weight_rank, transpose_input1);
75-
output_order = {};
77+
input0_transpose_order = get_transposed_order(input_rank, transpose_input0);
78+
input1_transpose_order = get_transposed_order(weight_rank, transpose_input1);
79+
output_transpose_order = {};
7680
}
7781

7882
/// @brief Constructs gemm layer.
@@ -86,69 +90,89 @@ struct gemm : public primitive_base<gemm> {
8690
gemm(const primitive_id& id,
8791
const std::vector<input_info>& inputs,
8892
const data_types data_type,
89-
const std::vector<int64_t>& input0_order = {0, 1, 2, 3},
90-
const std::vector<int64_t>& input1_order = {0, 1, 2, 3},
91-
const std::vector<int64_t>& output_order = {},
93+
const std::vector<int32_t>& input0_broadcast_target_shape = {},
94+
const std::vector<int32_t>& input1_broadcast_target_shape = {},
95+
const std::vector<int64_t>& input0_reshape_pattern = {},
96+
const std::vector<int64_t>& input1_reshape_pattern = {},
97+
const std::vector<int64_t>& input0_transpose_order = {0, 1, 2, 3},
98+
const std::vector<int64_t>& input1_transpose_order = {0, 1, 2, 3},
99+
const std::vector<int64_t>& output_transpose_order = {},
92100
const float alpha = 1.0f,
93101
const float beta = 0.0f,
94102
const padding& output_padding = padding())
95103
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
96-
input0_order(input0_order),
97-
input1_order(input1_order),
98-
output_order(output_order),
104+
input0_broadcast_target_shape(input0_broadcast_target_shape),
105+
input1_broadcast_target_shape(input1_broadcast_target_shape),
106+
input0_reshape_pattern(input0_reshape_pattern),
107+
input1_reshape_pattern(input1_reshape_pattern),
108+
input0_transpose_order(input0_transpose_order),
109+
input1_transpose_order(input1_transpose_order),
110+
output_transpose_order(output_transpose_order),
99111
alpha(alpha),
100112
beta(beta),
101-
input_rank(input0_order.size()),
102-
weight_rank(input1_order.size()) {
113+
input_rank(input0_transpose_order.size()),
114+
weight_rank(input1_transpose_order.size()) {
103115
if (inputs.size() != 2 && inputs.size() != 3) {
104116
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
105117
}
106118

107-
transpose_input0 = get_transpose_mode(input0_order);
108-
transpose_input1 = get_transpose_mode(input1_order);
119+
transpose_input0 = get_transpose_mode(input0_transpose_order);
120+
transpose_input1 = get_transpose_mode(input1_transpose_order);
109121
}
110122

111123
gemm(const primitive_id& id,
112124
const std::vector<input_info>& inputs,
113125
const input_info& beam_table,
114126
const data_types data_type,
115-
const std::vector<int64_t>& input0_order,
116-
const std::vector<int64_t>& input1_order,
117-
const std::vector<int64_t>& output_order,
127+
const std::vector<int64_t>& input0_transpose_order,
128+
const std::vector<int64_t>& input1_transpose_order,
129+
const std::vector<int64_t>& output_transpose_order,
118130
bool indirect_a,
119131
bool indirect_b,
120132
const float alpha = 1.0f,
121133
const float beta = 0.0f,
122134
const padding& output_padding = padding())
123135
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
124-
input0_order(input0_order),
125-
input1_order(input1_order),
126-
output_order(output_order),
136+
input0_broadcast_target_shape({}),
137+
input1_broadcast_target_shape({}),
138+
input0_reshape_pattern({}),
139+
input1_reshape_pattern({}),
140+
input0_transpose_order(input0_transpose_order),
141+
input1_transpose_order(input1_transpose_order),
142+
output_transpose_order(output_transpose_order),
127143
alpha(alpha),
128144
beta(beta),
129-
input_rank(input0_order.size()),
130-
weight_rank(input1_order.size()),
145+
input_rank(input0_transpose_order.size()),
146+
weight_rank(input1_transpose_order.size()),
131147
beam_table(beam_table),
132148
indirect_a(indirect_a),
133149
indirect_b(indirect_b) {
134150
if (inputs.size() != 2 && inputs.size() != 3) {
135151
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
136152
}
137153

138-
transpose_input0 = get_transpose_mode(input0_order);
139-
transpose_input1 = get_transpose_mode(input1_order);
154+
transpose_input0 = get_transpose_mode(input0_transpose_order);
155+
transpose_input1 = get_transpose_mode(input1_transpose_order);
140156
}
141157

142158
/// @brief Flag for transposing first input matrix
143159
uint32_t transpose_input0 = 0;
144160
/// @brief Flag for transposing second input matrix
145161
uint32_t transpose_input1 = 0;
162+
/// @brief broadcasted target shape of input 0
163+
std::vector<int32_t> input0_broadcast_target_shape;
164+
/// @brief broadcasted target shape of input 1
165+
std::vector<int32_t> input1_broadcast_target_shape;
166+
/// @brief reshaped output pattern of input 0
167+
std::vector<int64_t> input0_reshape_pattern;
168+
/// @brief reshaped output pattern of input 1
169+
std::vector<int64_t> input1_reshape_pattern;
146170
/// @brief order of input 0
147-
std::vector<int64_t> input0_order;
171+
std::vector<int64_t> input0_transpose_order;
148172
/// @brief order of input 1
149-
std::vector<int64_t> input1_order;
173+
std::vector<int64_t> input1_transpose_order;
150174
/// @brief order of output
151-
std::vector<int64_t> output_order;
175+
std::vector<int64_t> output_transpose_order;
152176
/// @brief Variable containing ALPHA parameter
153177
float alpha = 1.0f;
154178
/// @brief Variable containing BETA parameter
@@ -169,12 +193,13 @@ struct gemm : public primitive_base<gemm> {
169193
seed = hash_combine(seed, transpose_input1);
170194
seed = hash_combine(seed, indirect_a);
171195
seed = hash_combine(seed, indirect_b);
172-
for (auto order : input0_order)
173-
seed = hash_combine(seed, order);
174-
for (auto order : input1_order)
175-
seed = hash_combine(seed, order);
176-
for (auto order : output_order)
177-
seed = hash_combine(seed, order);
196+
seed = hash_range(seed, input0_broadcast_target_shape.begin(), input0_broadcast_target_shape.end());
197+
seed = hash_range(seed, input1_broadcast_target_shape.begin(), input1_broadcast_target_shape.end());
198+
seed = hash_range(seed, input0_reshape_pattern.begin(), input0_reshape_pattern.end());
199+
seed = hash_range(seed, input1_reshape_pattern.begin(), input1_reshape_pattern.end());
200+
seed = hash_range(seed, input0_transpose_order.begin(), input0_transpose_order.end());
201+
seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end());
202+
seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
178203
seed = hash_combine(seed, alpha);
179204
seed = hash_combine(seed, beta);
180205
return seed;
@@ -200,9 +225,13 @@ struct gemm : public primitive_base<gemm> {
200225
primitive_base<gemm>::save(ob);
201226
ob << transpose_input0;
202227
ob << transpose_input1;
203-
ob << input0_order;
204-
ob << input1_order;
205-
ob << output_order;
228+
ob << input0_broadcast_target_shape;
229+
ob << input1_broadcast_target_shape;
230+
ob << input0_reshape_pattern;
231+
ob << input1_reshape_pattern;
232+
ob << input0_transpose_order;
233+
ob << input1_transpose_order;
234+
ob << output_transpose_order;
206235
ob << alpha;
207236
ob << beta;
208237
ob << input_rank;
@@ -217,9 +246,13 @@ struct gemm : public primitive_base<gemm> {
217246
primitive_base<gemm>::load(ib);
218247
ib >> transpose_input0;
219248
ib >> transpose_input1;
220-
ib >> input0_order;
221-
ib >> input1_order;
222-
ib >> output_order;
249+
ib >> input0_broadcast_target_shape;
250+
ib >> input1_broadcast_target_shape;
251+
ib >> input0_reshape_pattern;
252+
ib >> input1_reshape_pattern;
253+
ib >> input0_transpose_order;
254+
ib >> input1_transpose_order;
255+
ib >> output_transpose_order;
223256
ib >> alpha;
224257
ib >> beta;
225258
ib >> input_rank;

0 commit comments

Comments
 (0)