@@ -54,6 +54,10 @@ struct gemm : public primitive_base<gemm> {
54
54
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
55
55
transpose_input0 (transpose_input0 ? 1 : 0 ),
56
56
transpose_input1 (transpose_input1 ? 1 : 0 ),
57
+ input0_broadcast_target_shape ({}),
58
+ input1_broadcast_target_shape ({}),
59
+ input0_reshape_pattern ({}),
60
+ input1_reshape_pattern ({}),
57
61
alpha (alpha),
58
62
beta (beta),
59
63
input_rank (input_rank),
@@ -70,9 +74,9 @@ struct gemm : public primitive_base<gemm> {
70
74
return order;
71
75
};
72
76
73
- input0_order = get_transposed_order (input_rank, transpose_input0);
74
- input1_order = get_transposed_order (weight_rank, transpose_input1);
75
- output_order = {};
77
+ input0_transpose_order = get_transposed_order (input_rank, transpose_input0);
78
+ input1_transpose_order = get_transposed_order (weight_rank, transpose_input1);
79
+ output_transpose_order = {};
76
80
}
77
81
78
82
/// @brief Constructs gemm layer.
@@ -86,69 +90,89 @@ struct gemm : public primitive_base<gemm> {
86
90
gemm (const primitive_id& id,
87
91
const std::vector<input_info>& inputs,
88
92
const data_types data_type,
89
- const std::vector<int64_t >& input0_order = {0 , 1 , 2 , 3 },
90
- const std::vector<int64_t >& input1_order = {0 , 1 , 2 , 3 },
91
- const std::vector<int64_t >& output_order = {},
93
+ const std::vector<int32_t >& input0_broadcast_target_shape = {},
94
+ const std::vector<int32_t >& input1_broadcast_target_shape = {},
95
+ const std::vector<int64_t >& input0_reshape_pattern = {},
96
+ const std::vector<int64_t >& input1_reshape_pattern = {},
97
+ const std::vector<int64_t >& input0_transpose_order = {0 , 1 , 2 , 3 },
98
+ const std::vector<int64_t >& input1_transpose_order = {0 , 1 , 2 , 3 },
99
+ const std::vector<int64_t >& output_transpose_order = {},
92
100
const float alpha = 1 .0f ,
93
101
const float beta = 0 .0f ,
94
102
const padding& output_padding = padding())
95
103
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
96
- input0_order (input0_order),
97
- input1_order (input1_order),
98
- output_order (output_order),
104
+ input0_broadcast_target_shape (input0_broadcast_target_shape),
105
+ input1_broadcast_target_shape (input1_broadcast_target_shape),
106
+ input0_reshape_pattern (input0_reshape_pattern),
107
+ input1_reshape_pattern (input1_reshape_pattern),
108
+ input0_transpose_order (input0_transpose_order),
109
+ input1_transpose_order (input1_transpose_order),
110
+ output_transpose_order (output_transpose_order),
99
111
alpha (alpha),
100
112
beta (beta),
101
- input_rank (input0_order .size()),
102
- weight_rank (input1_order .size()) {
113
+ input_rank (input0_transpose_order .size()),
114
+ weight_rank (input1_transpose_order .size()) {
103
115
if (inputs.size () != 2 && inputs.size () != 3 ) {
104
116
throw std::invalid_argument (" Invalid inputs count - gemm expects either two or three inputs" );
105
117
}
106
118
107
- transpose_input0 = get_transpose_mode (input0_order );
108
- transpose_input1 = get_transpose_mode (input1_order );
119
+ transpose_input0 = get_transpose_mode (input0_transpose_order );
120
+ transpose_input1 = get_transpose_mode (input1_transpose_order );
109
121
}
110
122
111
123
gemm (const primitive_id& id,
112
124
const std::vector<input_info>& inputs,
113
125
const input_info& beam_table,
114
126
const data_types data_type,
115
- const std::vector<int64_t >& input0_order ,
116
- const std::vector<int64_t >& input1_order ,
117
- const std::vector<int64_t >& output_order ,
127
+ const std::vector<int64_t >& input0_transpose_order ,
128
+ const std::vector<int64_t >& input1_transpose_order ,
129
+ const std::vector<int64_t >& output_transpose_order ,
118
130
bool indirect_a,
119
131
bool indirect_b,
120
132
const float alpha = 1 .0f ,
121
133
const float beta = 0 .0f ,
122
134
const padding& output_padding = padding())
123
135
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
124
- input0_order (input0_order),
125
- input1_order (input1_order),
126
- output_order (output_order),
136
+ input0_broadcast_target_shape ({}),
137
+ input1_broadcast_target_shape ({}),
138
+ input0_reshape_pattern ({}),
139
+ input1_reshape_pattern ({}),
140
+ input0_transpose_order (input0_transpose_order),
141
+ input1_transpose_order (input1_transpose_order),
142
+ output_transpose_order (output_transpose_order),
127
143
alpha (alpha),
128
144
beta (beta),
129
- input_rank (input0_order .size()),
130
- weight_rank (input1_order .size()),
145
+ input_rank (input0_transpose_order .size()),
146
+ weight_rank (input1_transpose_order .size()),
131
147
beam_table (beam_table),
132
148
indirect_a (indirect_a),
133
149
indirect_b (indirect_b) {
134
150
if (inputs.size () != 2 && inputs.size () != 3 ) {
135
151
throw std::invalid_argument (" Invalid inputs count - gemm expects either two or three inputs" );
136
152
}
137
153
138
- transpose_input0 = get_transpose_mode (input0_order );
139
- transpose_input1 = get_transpose_mode (input1_order );
154
+ transpose_input0 = get_transpose_mode (input0_transpose_order );
155
+ transpose_input1 = get_transpose_mode (input1_transpose_order );
140
156
}
141
157
142
158
/// @brief Flag for transposing first input matrix
uint32_t transpose_input0 = 0;
/// @brief Flag for transposing second input matrix
uint32_t transpose_input1 = 0;
/// @brief broadcasted target shape of input 0
std::vector<int32_t> input0_broadcast_target_shape;
/// @brief broadcasted target shape of input 1
std::vector<int32_t> input1_broadcast_target_shape;
/// @brief reshaped output pattern of input 0
std::vector<int64_t> input0_reshape_pattern;
/// @brief reshaped output pattern of input 1
std::vector<int64_t> input1_reshape_pattern;
/// @brief transpose order of input 0
std::vector<int64_t> input0_transpose_order;
/// @brief transpose order of input 1
std::vector<int64_t> input1_transpose_order;
/// @brief transpose order of output
std::vector<int64_t> output_transpose_order;
/// @brief Variable containing ALPHA parameter
float alpha = 1.0f;
154
178
// / @brief Variable containing BETA parameter
@@ -169,12 +193,13 @@ struct gemm : public primitive_base<gemm> {
169
193
seed = hash_combine (seed, transpose_input1);
170
194
seed = hash_combine (seed, indirect_a);
171
195
seed = hash_combine (seed, indirect_b);
172
- for (auto order : input0_order)
173
- seed = hash_combine (seed, order);
174
- for (auto order : input1_order)
175
- seed = hash_combine (seed, order);
176
- for (auto order : output_order)
177
- seed = hash_combine (seed, order);
196
+ seed = hash_range (seed, input0_broadcast_target_shape.begin (), input0_broadcast_target_shape.end ());
197
+ seed = hash_range (seed, input1_broadcast_target_shape.begin (), input1_broadcast_target_shape.end ());
198
+ seed = hash_range (seed, input0_reshape_pattern.begin (), input0_reshape_pattern.end ());
199
+ seed = hash_range (seed, input1_reshape_pattern.begin (), input1_reshape_pattern.end ());
200
+ seed = hash_range (seed, input0_transpose_order.begin (), input0_transpose_order.end ());
201
+ seed = hash_range (seed, input1_transpose_order.begin (), input1_transpose_order.end ());
202
+ seed = hash_range (seed, output_transpose_order.begin (), output_transpose_order.end ());
178
203
seed = hash_combine (seed, alpha);
179
204
seed = hash_combine (seed, beta);
180
205
return seed;
@@ -200,9 +225,13 @@ struct gemm : public primitive_base<gemm> {
200
225
primitive_base<gemm>::save (ob);
201
226
ob << transpose_input0;
202
227
ob << transpose_input1;
203
- ob << input0_order;
204
- ob << input1_order;
205
- ob << output_order;
228
+ ob << input0_broadcast_target_shape;
229
+ ob << input1_broadcast_target_shape;
230
+ ob << input0_reshape_pattern;
231
+ ob << input1_reshape_pattern;
232
+ ob << input0_transpose_order;
233
+ ob << input1_transpose_order;
234
+ ob << output_transpose_order;
206
235
ob << alpha;
207
236
ob << beta;
208
237
ob << input_rank;
@@ -217,9 +246,13 @@ struct gemm : public primitive_base<gemm> {
217
246
primitive_base<gemm>::load (ib);
218
247
ib >> transpose_input0;
219
248
ib >> transpose_input1;
220
- ib >> input0_order;
221
- ib >> input1_order;
222
- ib >> output_order;
249
+ ib >> input0_broadcast_target_shape;
250
+ ib >> input1_broadcast_target_shape;
251
+ ib >> input0_reshape_pattern;
252
+ ib >> input1_reshape_pattern;
253
+ ib >> input0_transpose_order;
254
+ ib >> input1_transpose_order;
255
+ ib >> output_transpose_order;
223
256
ib >> alpha;
224
257
ib >> beta;
225
258
ib >> input_rank;
0 commit comments