@@ -100,225 +100,6 @@ bool key_t::operator==(const key_t &rhs) const {
100
100
return true ;
101
101
}
102
102
103
- // Combine hash of each memory_desc_t data member
104
- size_t get_md_hash (const memory_desc_t &md) {
105
- size_t seed = 0 ;
106
- seed = get_array_hash (seed, md.dims , md.ndims );
107
- seed = hash_combine (seed, static_cast <size_t >(md.data_type ));
108
- seed = get_array_hash (seed, md.padded_dims , md.ndims );
109
- seed = get_array_hash (seed, md.padded_offsets , md.ndims );
110
- seed = hash_combine (seed, md.offset0 );
111
- seed = hash_combine (seed, static_cast <size_t >(md.format_kind ));
112
- // format desc
113
- switch ((int )md.format_kind ) {
114
- case format_kind::undef:
115
- case format_kind::any: break ;
116
- case format_kind::blocked:
117
- for (int i = 0 ; i < md.ndims ; i++) {
118
- if (md.dims [i] == 1 && md.padded_dims [i] == 1 ) continue ;
119
- seed = hash_combine (seed, md.format_desc .blocking .strides [i]);
120
- }
121
- seed = hash_combine (seed, md.format_desc .blocking .inner_nblks );
122
- seed = get_array_hash (seed, md.format_desc .blocking .inner_blks ,
123
- md.format_desc .blocking .inner_nblks );
124
- seed = get_array_hash (seed, md.format_desc .blocking .inner_idxs ,
125
- md.format_desc .blocking .inner_nblks );
126
- break ;
127
- case format_kind::wino:
128
- seed = hash_combine (seed,
129
- static_cast <size_t >(md.format_desc .wino_desc .wino_format ));
130
- seed = hash_combine (seed, md.format_desc .wino_desc .r );
131
- seed = hash_combine (seed, md.format_desc .wino_desc .alpha );
132
- seed = hash_combine (seed, md.format_desc .wino_desc .ic );
133
- seed = hash_combine (seed, md.format_desc .wino_desc .oc );
134
- seed = hash_combine (seed, md.format_desc .wino_desc .ic_block );
135
- seed = hash_combine (seed, md.format_desc .wino_desc .oc_block );
136
- seed = hash_combine (seed, md.format_desc .wino_desc .ic2_block );
137
- seed = hash_combine (seed, md.format_desc .wino_desc .oc2_block );
138
- seed = hash_combine (seed, md.format_desc .wino_desc .adj_scale );
139
- seed = hash_combine (seed, md.format_desc .wino_desc .size );
140
- break ;
141
- case format_kind::rnn_packed:
142
- seed = hash_combine (seed,
143
- static_cast <size_t >(md.format_desc .rnn_packed_desc .format ));
144
- seed = hash_combine (seed, md.format_desc .rnn_packed_desc .n_parts );
145
- seed = hash_combine (seed, md.format_desc .rnn_packed_desc .n );
146
- seed = hash_combine (seed, md.format_desc .rnn_packed_desc .ldb );
147
- {
148
- int n_parts = md.format_desc .rnn_packed_desc .n_parts ;
149
- seed = get_array_hash (
150
- seed, md.format_desc .rnn_packed_desc .parts , n_parts);
151
- seed = get_array_hash (seed,
152
- md.format_desc .rnn_packed_desc .part_pack_size , n_parts);
153
- seed = get_array_hash (seed,
154
- md.format_desc .rnn_packed_desc .pack_part , n_parts);
155
- }
156
- seed = hash_combine (
157
- seed, md.format_desc .rnn_packed_desc .offset_compensation );
158
- seed = hash_combine (seed, md.format_desc .rnn_packed_desc .size );
159
- break ;
160
- #ifdef DNNL_EXPERIMENTAL_SPARSE
161
- case format_kind::sparse:
162
- seed = hash_combine (seed,
163
- static_cast <size_t >(md.format_desc .sparse_desc .encoding ));
164
- seed = hash_combine (seed, md.format_desc .sparse_desc .nnz );
165
- seed = get_array_hash (seed,
166
- md.format_desc .sparse_desc .metadata_types ,
167
- sparse_desc_t ::max_metadata_types);
168
- // User cannot initialize `packed_desc` therefore `packed_desc`
169
- // is always zero initialized.
170
- break ;
171
- #endif
172
- default : assert (!" unknown format_kind" );
173
- }
174
-
175
- if (md.extra .flags != dnnl_memory_extra_flag_none) {
176
- seed = hash_combine (seed, md.extra .flags );
177
- if ((md.extra .flags
178
- & (dnnl_memory_extra_flag_compensation_conv_s8s8
179
- | dnnl_memory_extra_flag_rnn_u8s8_compensation))
180
- && !types::extra_flag_rnn_s8s8_compensation_is_set (
181
- md.extra .flags )) {
182
- seed = hash_combine (seed, md.extra .compensation_mask );
183
- }
184
-
185
- if (md.extra .flags & dnnl_memory_extra_flag_scale_adjust) {
186
- seed = hash_combine (seed, md.extra .scale_adjust );
187
- }
188
-
189
- if (md.extra .flags
190
- & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
191
- seed = hash_combine (seed, md.extra .asymm_compensation_mask );
192
- }
193
- }
194
- // Combined hash for a memory descriptor
195
- return seed;
196
- }
197
-
198
- // Combine hash of each primitive_attr_t data member
199
- size_t get_attr_hash (const primitive_attr_t &attr) {
200
- size_t seed = 0 ;
201
- // scratchpad_mode
202
- seed = hash_combine (seed, static_cast <size_t >(attr.scratchpad_mode_ ));
203
- // fpmath_mode
204
- seed = hash_combine (seed, static_cast <size_t >(attr.fpmath_ .mode_ ));
205
- seed = hash_combine (seed, static_cast <size_t >(attr.fpmath_ .apply_to_int_ ));
206
- // deterministic
207
- seed = hash_combine (seed, static_cast <size_t >(attr.deterministic_ ));
208
- // acc_mode
209
- seed = hash_combine (seed, static_cast <size_t >(attr.acc_mode_ ));
210
-
211
- if (!attr.output_scales_ .has_default_values ()) {
212
- // output_scales: mask
213
- seed = hash_combine (seed, attr.output_scales_ .mask_ );
214
- } else if (!attr.scales_ .has_default_values ()) {
215
- // go through scales for all arguments
216
- for (const auto &p : attr.scales_ .scales_ ) {
217
- // scales: arg
218
- seed = hash_combine (seed, p.first );
219
- // scales: mask
220
- seed = hash_combine (seed, p.second .mask_ );
221
- // scales: groups
222
- const int ndims = p.second .ndims_ ;
223
- seed = hash_combine (seed, ndims);
224
- if (ndims > 0 )
225
- seed = get_array_hash (seed, p.second .group_dims_ , ndims);
226
- // scales: data type
227
- seed = hash_combine (seed, static_cast <size_t >(p.second .data_type_ ));
228
- }
229
- }
230
- // zero_points
231
- for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
232
- if (!attr.zero_points_ .has_default_values (arg)) {
233
- const auto &zps = attr.zero_points_ ;
234
- // zero_points: arg
235
- seed = hash_combine (seed, arg);
236
- int mask = 0 ;
237
- zps.get (arg, &mask);
238
- // zero_points: mask
239
- seed = hash_combine (seed, mask);
240
- // zero points: groups
241
- const int ndims = zps.get_groups_ndims (arg);
242
- seed = hash_combine (seed, ndims);
243
- if (ndims > 0 )
244
- seed = get_array_hash (seed, zps.get_groups (arg), ndims);
245
- // zero points: data type
246
- seed = hash_combine (
247
- seed, static_cast <size_t >(zps.get_data_type (arg)));
248
- }
249
- // post_ops: entry[:]
250
- for (int i = 0 ; i < attr.post_ops_ .len (); i++) {
251
- const auto &entry = attr.post_ops_ .entry_ [i];
252
- switch (entry.kind ) {
253
- case primitive_kind::eltwise:
254
- seed = hash_combine (
255
- seed, static_cast <size_t >(entry.eltwise .alg ));
256
- seed = hash_combine (seed, entry.eltwise .scale );
257
- seed = hash_combine (seed, entry.eltwise .alpha );
258
- seed = hash_combine (seed, entry.eltwise .beta );
259
- break ;
260
- case primitive_kind::sum:
261
- seed = hash_combine (seed, entry.sum .scale );
262
- seed = hash_combine (seed, entry.sum .zero_point );
263
- seed = hash_combine (seed, static_cast <size_t >(entry.sum .dt ));
264
- break ;
265
- case primitive_kind::convolution:
266
- seed = hash_combine (
267
- seed, static_cast <size_t >(entry.depthwise_conv .kernel ));
268
- seed = hash_combine (
269
- seed, static_cast <size_t >(entry.depthwise_conv .stride ));
270
- seed = hash_combine (seed,
271
- static_cast <size_t >(entry.depthwise_conv .padding ));
272
- seed = hash_combine (
273
- seed, static_cast <size_t >(entry.depthwise_conv .wei_dt ));
274
- seed = hash_combine (seed,
275
- static_cast <size_t >(entry.depthwise_conv .bias_dt ));
276
- seed = hash_combine (
277
- seed, static_cast <size_t >(entry.depthwise_conv .dst_dt ));
278
- break ;
279
- case primitive_kind::binary:
280
- seed = hash_combine (
281
- seed, static_cast <size_t >(entry.binary .alg ));
282
- seed = hash_combine (
283
- seed, get_md_hash (entry.binary .user_src1_desc ));
284
- break ;
285
- case primitive_kind::prelu:
286
- seed = hash_combine (
287
- seed, static_cast <size_t >(entry.prelu .mask ));
288
- break ;
289
- case primitive_kind::depthwise:
290
- seed = hash_combine (seed, static_cast <size_t >(entry.depthwise .alg ));
291
- seed = hash_combine (seed, reinterpret_cast <size_t >(entry.depthwise .weights_data ));
292
- seed = hash_combine (seed, reinterpret_cast <size_t >(entry.depthwise .biases_data ));
293
- break ;
294
- case primitive_kind::quantization:
295
- seed = hash_combine (seed, static_cast <size_t >(entry.quantization .alg ));
296
- seed = get_array_hash (seed, entry.quantization .per_channel , entry.quantization .fields_count );
297
- seed = get_array_hash (seed, entry.quantization .all_default , entry.quantization .fields_count );
298
- seed = get_array_hash (seed, entry.quantization .data , entry.quantization .fields_count );
299
- break ;
300
- default : assert (!" unknown post_op" );
301
- }
302
- }
303
- // rnn_data_qparams: scale, shift
304
- seed = hash_combine (seed, attr.rnn_data_qparams_ .scale_ );
305
- seed = hash_combine (seed, attr.rnn_data_qparams_ .shift_ );
306
- if (!attr.rnn_weights_qparams_ .has_default_values ()) {
307
- // rnn_weights_qparams: mask
308
- seed = hash_combine (seed, attr.rnn_weights_qparams_ .mask_ );
309
- // rnn_weights_qparams: count
310
- seed = hash_combine (seed, attr.rnn_weights_qparams_ .count_ );
311
- // rnn_weights_qparams: scales[:]
312
- seed = get_array_hash (seed, attr.rnn_weights_qparams_ .scales_ ,
313
- attr.rnn_weights_qparams_ .count_ );
314
- }
315
- if (attr.gpu_attr_ ) {
316
- seed = hash_combine (seed, attr.gpu_attr_ ->get_hash ());
317
- }
318
- // Combined hash for attributes
319
- return seed;
320
- }
321
-
322
103
// Functions that compute hash for different op_descs
323
104
size_t get_desc_hash (const concat_desc_t &desc) {
324
105
size_t seed = 0 ;
0 commit comments