Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c8eb642

Browse files
maxnickluweizhou2016
authored andcommittedJul 29, 2024
[FIX] Hash utility functions were extracted to a separate module for reuse
1 parent b0c50e8 commit c8eb642

5 files changed

+382
-255
lines changed
 

‎src/common/primitive_hashing.cpp

-219
Original file line numberDiff line numberDiff line change
@@ -100,225 +100,6 @@ bool key_t::operator==(const key_t &rhs) const {
100100
return true;
101101
}
102102

103-
// Combine hash of each memory_desc_t data member
104-
size_t get_md_hash(const memory_desc_t &md) {
105-
size_t seed = 0;
106-
seed = get_array_hash(seed, md.dims, md.ndims);
107-
seed = hash_combine(seed, static_cast<size_t>(md.data_type));
108-
seed = get_array_hash(seed, md.padded_dims, md.ndims);
109-
seed = get_array_hash(seed, md.padded_offsets, md.ndims);
110-
seed = hash_combine(seed, md.offset0);
111-
seed = hash_combine(seed, static_cast<size_t>(md.format_kind));
112-
// format desc
113-
switch ((int)md.format_kind) {
114-
case format_kind::undef:
115-
case format_kind::any: break;
116-
case format_kind::blocked:
117-
for (int i = 0; i < md.ndims; i++) {
118-
if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue;
119-
seed = hash_combine(seed, md.format_desc.blocking.strides[i]);
120-
}
121-
seed = hash_combine(seed, md.format_desc.blocking.inner_nblks);
122-
seed = get_array_hash(seed, md.format_desc.blocking.inner_blks,
123-
md.format_desc.blocking.inner_nblks);
124-
seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs,
125-
md.format_desc.blocking.inner_nblks);
126-
break;
127-
case format_kind::wino:
128-
seed = hash_combine(seed,
129-
static_cast<size_t>(md.format_desc.wino_desc.wino_format));
130-
seed = hash_combine(seed, md.format_desc.wino_desc.r);
131-
seed = hash_combine(seed, md.format_desc.wino_desc.alpha);
132-
seed = hash_combine(seed, md.format_desc.wino_desc.ic);
133-
seed = hash_combine(seed, md.format_desc.wino_desc.oc);
134-
seed = hash_combine(seed, md.format_desc.wino_desc.ic_block);
135-
seed = hash_combine(seed, md.format_desc.wino_desc.oc_block);
136-
seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block);
137-
seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block);
138-
seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale);
139-
seed = hash_combine(seed, md.format_desc.wino_desc.size);
140-
break;
141-
case format_kind::rnn_packed:
142-
seed = hash_combine(seed,
143-
static_cast<size_t>(md.format_desc.rnn_packed_desc.format));
144-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts);
145-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n);
146-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb);
147-
{
148-
int n_parts = md.format_desc.rnn_packed_desc.n_parts;
149-
seed = get_array_hash(
150-
seed, md.format_desc.rnn_packed_desc.parts, n_parts);
151-
seed = get_array_hash(seed,
152-
md.format_desc.rnn_packed_desc.part_pack_size, n_parts);
153-
seed = get_array_hash(seed,
154-
md.format_desc.rnn_packed_desc.pack_part, n_parts);
155-
}
156-
seed = hash_combine(
157-
seed, md.format_desc.rnn_packed_desc.offset_compensation);
158-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size);
159-
break;
160-
#ifdef DNNL_EXPERIMENTAL_SPARSE
161-
case format_kind::sparse:
162-
seed = hash_combine(seed,
163-
static_cast<size_t>(md.format_desc.sparse_desc.encoding));
164-
seed = hash_combine(seed, md.format_desc.sparse_desc.nnz);
165-
seed = get_array_hash(seed,
166-
md.format_desc.sparse_desc.metadata_types,
167-
sparse_desc_t::max_metadata_types);
168-
// User cannot initialize `packed_desc` therefore `packed_desc`
169-
// is always zero initialized.
170-
break;
171-
#endif
172-
default: assert(!"unknown format_kind");
173-
}
174-
175-
if (md.extra.flags != dnnl_memory_extra_flag_none) {
176-
seed = hash_combine(seed, md.extra.flags);
177-
if ((md.extra.flags
178-
& (dnnl_memory_extra_flag_compensation_conv_s8s8
179-
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
180-
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
181-
md.extra.flags)) {
182-
seed = hash_combine(seed, md.extra.compensation_mask);
183-
}
184-
185-
if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
186-
seed = hash_combine(seed, md.extra.scale_adjust);
187-
}
188-
189-
if (md.extra.flags
190-
& dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
191-
seed = hash_combine(seed, md.extra.asymm_compensation_mask);
192-
}
193-
}
194-
// Combined hash for a memory descriptor
195-
return seed;
196-
}
197-
198-
// Combine hash of each primitive_attr_t data member
199-
size_t get_attr_hash(const primitive_attr_t &attr) {
200-
size_t seed = 0;
201-
// scratchpad_mode
202-
seed = hash_combine(seed, static_cast<size_t>(attr.scratchpad_mode_));
203-
// fpmath_mode
204-
seed = hash_combine(seed, static_cast<size_t>(attr.fpmath_.mode_));
205-
seed = hash_combine(seed, static_cast<size_t>(attr.fpmath_.apply_to_int_));
206-
// deterministic
207-
seed = hash_combine(seed, static_cast<size_t>(attr.deterministic_));
208-
// acc_mode
209-
seed = hash_combine(seed, static_cast<size_t>(attr.acc_mode_));
210-
211-
if (!attr.output_scales_.has_default_values()) {
212-
// output_scales: mask
213-
seed = hash_combine(seed, attr.output_scales_.mask_);
214-
} else if (!attr.scales_.has_default_values()) {
215-
// go through scales for all arguments
216-
for (const auto &p : attr.scales_.scales_) {
217-
// scales: arg
218-
seed = hash_combine(seed, p.first);
219-
// scales: mask
220-
seed = hash_combine(seed, p.second.mask_);
221-
// scales: groups
222-
const int ndims = p.second.ndims_;
223-
seed = hash_combine(seed, ndims);
224-
if (ndims > 0)
225-
seed = get_array_hash(seed, p.second.group_dims_, ndims);
226-
// scales: data type
227-
seed = hash_combine(seed, static_cast<size_t>(p.second.data_type_));
228-
}
229-
}
230-
// zero_points
231-
for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
232-
if (!attr.zero_points_.has_default_values(arg)) {
233-
const auto &zps = attr.zero_points_;
234-
// zero_points: arg
235-
seed = hash_combine(seed, arg);
236-
int mask = 0;
237-
zps.get(arg, &mask);
238-
// zero_points: mask
239-
seed = hash_combine(seed, mask);
240-
// zero points: groups
241-
const int ndims = zps.get_groups_ndims(arg);
242-
seed = hash_combine(seed, ndims);
243-
if (ndims > 0)
244-
seed = get_array_hash(seed, zps.get_groups(arg), ndims);
245-
// zero points: data type
246-
seed = hash_combine(
247-
seed, static_cast<size_t>(zps.get_data_type(arg)));
248-
}
249-
// post_ops: entry[:]
250-
for (int i = 0; i < attr.post_ops_.len(); i++) {
251-
const auto &entry = attr.post_ops_.entry_[i];
252-
switch (entry.kind) {
253-
case primitive_kind::eltwise:
254-
seed = hash_combine(
255-
seed, static_cast<size_t>(entry.eltwise.alg));
256-
seed = hash_combine(seed, entry.eltwise.scale);
257-
seed = hash_combine(seed, entry.eltwise.alpha);
258-
seed = hash_combine(seed, entry.eltwise.beta);
259-
break;
260-
case primitive_kind::sum:
261-
seed = hash_combine(seed, entry.sum.scale);
262-
seed = hash_combine(seed, entry.sum.zero_point);
263-
seed = hash_combine(seed, static_cast<size_t>(entry.sum.dt));
264-
break;
265-
case primitive_kind::convolution:
266-
seed = hash_combine(
267-
seed, static_cast<size_t>(entry.depthwise_conv.kernel));
268-
seed = hash_combine(
269-
seed, static_cast<size_t>(entry.depthwise_conv.stride));
270-
seed = hash_combine(seed,
271-
static_cast<size_t>(entry.depthwise_conv.padding));
272-
seed = hash_combine(
273-
seed, static_cast<size_t>(entry.depthwise_conv.wei_dt));
274-
seed = hash_combine(seed,
275-
static_cast<size_t>(entry.depthwise_conv.bias_dt));
276-
seed = hash_combine(
277-
seed, static_cast<size_t>(entry.depthwise_conv.dst_dt));
278-
break;
279-
case primitive_kind::binary:
280-
seed = hash_combine(
281-
seed, static_cast<size_t>(entry.binary.alg));
282-
seed = hash_combine(
283-
seed, get_md_hash(entry.binary.user_src1_desc));
284-
break;
285-
case primitive_kind::prelu:
286-
seed = hash_combine(
287-
seed, static_cast<size_t>(entry.prelu.mask));
288-
break;
289-
case primitive_kind::depthwise:
290-
seed = hash_combine(seed, static_cast<size_t>(entry.depthwise.alg));
291-
seed = hash_combine(seed, reinterpret_cast<size_t>(entry.depthwise.weights_data));
292-
seed = hash_combine(seed, reinterpret_cast<size_t>(entry.depthwise.biases_data));
293-
break;
294-
case primitive_kind::quantization:
295-
seed = hash_combine(seed, static_cast<size_t>(entry.quantization.alg));
296-
seed = get_array_hash(seed, entry.quantization.per_channel, entry.quantization.fields_count);
297-
seed = get_array_hash(seed, entry.quantization.all_default, entry.quantization.fields_count);
298-
seed = get_array_hash(seed, entry.quantization.data, entry.quantization.fields_count);
299-
break;
300-
default: assert(!"unknown post_op");
301-
}
302-
}
303-
// rnn_data_qparams: scale, shift
304-
seed = hash_combine(seed, attr.rnn_data_qparams_.scale_);
305-
seed = hash_combine(seed, attr.rnn_data_qparams_.shift_);
306-
if (!attr.rnn_weights_qparams_.has_default_values()) {
307-
// rnn_weights_qparams: mask
308-
seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_);
309-
// rnn_weights_qparams: count
310-
seed = hash_combine(seed, attr.rnn_weights_qparams_.count_);
311-
// rnn_weights_qparams: scales[:]
312-
seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_,
313-
attr.rnn_weights_qparams_.count_);
314-
}
315-
if (attr.gpu_attr_) {
316-
seed = hash_combine(seed, attr.gpu_attr_->get_hash());
317-
}
318-
// Combined hash for attributes
319-
return seed;
320-
}
321-
322103
// Functions that compute hash for different op_descs
323104
size_t get_desc_hash(const concat_desc_t &desc) {
324105
size_t seed = 0;

‎src/common/primitive_hashing.hpp

+1-35
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "oneapi/dnnl/dnnl.h"
2727
#include "primitive_attr.hpp"
2828
#include "type_helpers.hpp"
29+
#include "primitive_hashing_utils.hpp"
2930

3031
namespace dnnl {
3132
namespace impl {
@@ -72,8 +73,6 @@ struct key_t {
7273
std::thread::id thread_id_;
7374
};
7475

75-
size_t get_md_hash(const memory_desc_t &md);
76-
size_t get_attr_hash(const primitive_attr_t &attr);
7776
size_t get_desc_hash(const concat_desc_t &desc);
7877
size_t get_desc_hash(const batch_normalization_desc_t &desc);
7978
size_t get_desc_hash(const binary_desc_t &desc);
@@ -96,39 +95,6 @@ size_t get_desc_hash(const softmax_desc_t &desc);
9695
size_t get_desc_hash(const sum_desc_t &desc);
9796
size_t get_desc_hash(const zero_pad_desc_t &desc);
9897

99-
template <typename T>
100-
size_t get_array_hash(size_t seed, const T *v, int size) {
101-
for (int i = 0; i < size; i++) {
102-
seed = hash_combine(seed, v[i]);
103-
}
104-
return seed;
105-
}
106-
107-
template <>
108-
inline size_t get_array_hash<memory_desc_t>(
109-
size_t seed, const memory_desc_t *v, int size) {
110-
for (int i = 0; i < size; i++) {
111-
seed = hash_combine(seed, get_md_hash(v[i]));
112-
}
113-
return seed;
114-
}
115-
116-
inline size_t get_array_hash(
117-
size_t seed, const std::vector<const memory_desc_t *> &mds) {
118-
for (const auto *md : mds)
119-
seed = hash_combine(seed, get_md_hash(*md));
120-
return seed;
121-
}
122-
123-
template <>
124-
inline size_t get_array_hash<data_type_t>(
125-
size_t seed, const data_type_t *v, int size) {
126-
for (int i = 0; i < size; i++) {
127-
seed = hash_combine(seed, static_cast<size_t>(v[i]));
128-
}
129-
return seed;
130-
}
131-
13298
} // namespace primitive_hashing
13399
} // namespace impl
134100
} // namespace dnnl

0 commit comments

Comments
 (0)
Please sign in to comment.