Skip to content

Commit 879e531

Browse files
maxnickluweizhou2016
authored andcommitted
[FIX] Hash utility functions were extracted to a separate module for reuse
1 parent d7c312e commit 879e531

5 files changed

+359
-232
lines changed

src/common/primitive_hashing.cpp

-196
Original file line numberDiff line numberDiff line change
@@ -98,202 +98,6 @@ bool key_t::operator==(const key_t &rhs) const {
9898
return true;
9999
}
100100

101-
// Combine hash of each memory_desc_t data member
102-
size_t get_md_hash(const memory_desc_t &md) {
103-
size_t seed = 0;
104-
seed = get_array_hash(seed, md.dims, md.ndims);
105-
seed = hash_combine(seed, static_cast<size_t>(md.data_type));
106-
seed = get_array_hash(seed, md.padded_dims, md.ndims);
107-
seed = get_array_hash(seed, md.padded_offsets, md.ndims);
108-
seed = hash_combine(seed, md.offset0);
109-
seed = hash_combine(seed, static_cast<size_t>(md.format_kind));
110-
// format desc
111-
switch ((int)md.format_kind) {
112-
case format_kind::undef:
113-
case format_kind::any: break;
114-
case format_kind::blocked:
115-
for (int i = 0; i < md.ndims; i++) {
116-
if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue;
117-
seed = hash_combine(seed, md.format_desc.blocking.strides[i]);
118-
}
119-
seed = hash_combine(seed, md.format_desc.blocking.inner_nblks);
120-
seed = get_array_hash(seed, md.format_desc.blocking.inner_blks,
121-
md.format_desc.blocking.inner_nblks);
122-
seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs,
123-
md.format_desc.blocking.inner_nblks);
124-
break;
125-
case format_kind::wino:
126-
seed = hash_combine(seed,
127-
static_cast<size_t>(md.format_desc.wino_desc.wino_format));
128-
seed = hash_combine(seed, md.format_desc.wino_desc.r);
129-
seed = hash_combine(seed, md.format_desc.wino_desc.alpha);
130-
seed = hash_combine(seed, md.format_desc.wino_desc.ic);
131-
seed = hash_combine(seed, md.format_desc.wino_desc.oc);
132-
seed = hash_combine(seed, md.format_desc.wino_desc.ic_block);
133-
seed = hash_combine(seed, md.format_desc.wino_desc.oc_block);
134-
seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block);
135-
seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block);
136-
seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale);
137-
seed = hash_combine(seed, md.format_desc.wino_desc.size);
138-
break;
139-
case format_kind::rnn_packed:
140-
seed = hash_combine(seed,
141-
static_cast<size_t>(md.format_desc.rnn_packed_desc.format));
142-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts);
143-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n);
144-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb);
145-
{
146-
int n_parts = md.format_desc.rnn_packed_desc.n_parts;
147-
seed = get_array_hash(
148-
seed, md.format_desc.rnn_packed_desc.parts, n_parts);
149-
seed = get_array_hash(seed,
150-
md.format_desc.rnn_packed_desc.part_pack_size, n_parts);
151-
seed = get_array_hash(seed,
152-
md.format_desc.rnn_packed_desc.pack_part, n_parts);
153-
}
154-
seed = hash_combine(
155-
seed, md.format_desc.rnn_packed_desc.offset_compensation);
156-
seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size);
157-
break;
158-
#ifdef DNNL_EXPERIMENTAL_SPARSE
159-
case format_kind::sparse:
160-
seed = hash_combine(seed,
161-
static_cast<size_t>(md.format_desc.sparse_desc.encoding));
162-
seed = hash_combine(seed, md.format_desc.sparse_desc.nnz);
163-
seed = get_array_hash(seed,
164-
md.format_desc.sparse_desc.metadata_types,
165-
sparse_desc_t::max_metadata_types);
166-
break;
167-
#endif
168-
default: assert(!"unknown format_kind");
169-
}
170-
171-
if (md.extra.flags != dnnl_memory_extra_flag_none) {
172-
seed = hash_combine(seed, md.extra.flags);
173-
if ((md.extra.flags
174-
& (dnnl_memory_extra_flag_compensation_conv_s8s8
175-
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
176-
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
177-
md.extra.flags)) {
178-
seed = hash_combine(seed, md.extra.compensation_mask);
179-
}
180-
181-
if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
182-
seed = hash_combine(seed, md.extra.scale_adjust);
183-
}
184-
185-
if (md.extra.flags
186-
& dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
187-
seed = hash_combine(seed, md.extra.asymm_compensation_mask);
188-
}
189-
}
190-
// Combined hash for a memory descriptor
191-
return seed;
192-
}
193-
194-
// Combine hash of each primitive_attr_t data member
195-
size_t get_attr_hash(const primitive_attr_t &attr) {
196-
size_t seed = 0;
197-
// scratchpad_mode
198-
seed = hash_combine(seed, static_cast<size_t>(attr.scratchpad_mode_));
199-
// fpmath_mode
200-
seed = hash_combine(seed, static_cast<size_t>(attr.fpmath_mode_));
201-
202-
if (!attr.output_scales_.has_default_values()) {
203-
// output_scales: mask
204-
seed = hash_combine(seed, attr.output_scales_.mask_);
205-
} else if (!attr.scales_.has_default_values()) {
206-
// go through scales for all arguments
207-
for (const auto &p : attr.scales_.scales_) {
208-
// scales: arg
209-
seed = hash_combine(seed, p.first);
210-
// scales: mask
211-
seed = hash_combine(seed, p.second.mask_);
212-
}
213-
}
214-
// zero_points
215-
for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
216-
if (!attr.zero_points_.has_default_values(arg)) {
217-
// zero_points: arg
218-
seed = hash_combine(seed, arg);
219-
int mask = 0;
220-
attr.zero_points_.get(arg, &mask);
221-
// zero_points: mask
222-
seed = hash_combine(seed, mask);
223-
}
224-
// post_ops: entry[:]
225-
for (int i = 0; i < attr.post_ops_.len(); i++) {
226-
const auto &entry = attr.post_ops_.entry_[i];
227-
switch (entry.kind) {
228-
case primitive_kind::eltwise:
229-
seed = hash_combine(
230-
seed, static_cast<size_t>(entry.eltwise.alg));
231-
seed = hash_combine(seed, entry.eltwise.scale);
232-
seed = hash_combine(seed, entry.eltwise.alpha);
233-
seed = hash_combine(seed, entry.eltwise.beta);
234-
break;
235-
case primitive_kind::sum:
236-
seed = hash_combine(seed, entry.sum.scale);
237-
seed = hash_combine(seed, entry.sum.zero_point);
238-
seed = hash_combine(seed, static_cast<size_t>(entry.sum.dt));
239-
break;
240-
case primitive_kind::convolution:
241-
seed = hash_combine(
242-
seed, static_cast<size_t>(entry.depthwise_conv.kernel));
243-
seed = hash_combine(
244-
seed, static_cast<size_t>(entry.depthwise_conv.stride));
245-
seed = hash_combine(seed,
246-
static_cast<size_t>(entry.depthwise_conv.padding));
247-
seed = hash_combine(
248-
seed, static_cast<size_t>(entry.depthwise_conv.wei_dt));
249-
seed = hash_combine(seed,
250-
static_cast<size_t>(entry.depthwise_conv.bias_dt));
251-
seed = hash_combine(
252-
seed, static_cast<size_t>(entry.depthwise_conv.dst_dt));
253-
break;
254-
case primitive_kind::binary:
255-
seed = hash_combine(
256-
seed, static_cast<size_t>(entry.binary.alg));
257-
seed = hash_combine(
258-
seed, get_md_hash(entry.binary.user_src1_desc));
259-
break;
260-
case primitive_kind::prelu:
261-
seed = hash_combine(
262-
seed, static_cast<size_t>(entry.prelu.mask));
263-
break;
264-
case primitive_kind::depthwise:
265-
seed = hash_combine(seed, static_cast<size_t>(entry.depthwise.alg));
266-
seed = hash_combine(seed, reinterpret_cast<size_t>(entry.depthwise.weights_data));
267-
seed = hash_combine(seed, reinterpret_cast<size_t>(entry.depthwise.biases_data));
268-
break;
269-
case primitive_kind::quantization:
270-
seed = hash_combine(seed, static_cast<size_t>(entry.quantization.alg));
271-
seed = get_array_hash(seed, entry.quantization.per_channel, entry.quantization.fields_count);
272-
seed = get_array_hash(seed, entry.quantization.all_default, entry.quantization.fields_count);
273-
seed = get_array_hash(seed, entry.quantization.data, entry.quantization.fields_count);
274-
break;
275-
default: assert(!"unknown post_op");
276-
}
277-
}
278-
// rnn_data_qparams: scale, shift
279-
seed = hash_combine(seed, attr.rnn_data_qparams_.scale_);
280-
seed = hash_combine(seed, attr.rnn_data_qparams_.shift_);
281-
if (!attr.rnn_weights_qparams_.has_default_values()) {
282-
// rnn_weights_qparams: mask
283-
seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_);
284-
// rnn_weights_qparams: count
285-
seed = hash_combine(seed, attr.rnn_weights_qparams_.count_);
286-
// rnn_weights_qparams: scales[:]
287-
seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_,
288-
attr.rnn_weights_qparams_.count_);
289-
}
290-
if (attr.gpu_attr_) {
291-
seed = hash_combine(seed, attr.gpu_attr_->get_hash());
292-
}
293-
// Combined hash for attributes
294-
return seed;
295-
}
296-
297101
// Functions that compute hash for different op_descs
298102
size_t get_desc_hash(const concat_desc_t &desc) {
299103
size_t seed = 0;

src/common/primitive_hashing.hpp

+1-35
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "oneapi/dnnl/dnnl.h"
2727
#include "primitive_attr.hpp"
2828
#include "type_helpers.hpp"
29+
#include "primitive_hashing_utils.hpp"
2930

3031
namespace dnnl {
3132
namespace impl {
@@ -71,8 +72,6 @@ struct key_t {
7172
std::thread::id thread_id_;
7273
};
7374

74-
size_t get_md_hash(const memory_desc_t &md);
75-
size_t get_attr_hash(const primitive_attr_t &attr);
7675
size_t get_desc_hash(const concat_desc_t &desc);
7776
size_t get_desc_hash(const batch_normalization_desc_t &desc);
7877
size_t get_desc_hash(const binary_desc_t &desc);
@@ -95,39 +94,6 @@ size_t get_desc_hash(const softmax_desc_t &desc);
9594
size_t get_desc_hash(const sum_desc_t &desc);
9695
size_t get_desc_hash(const zero_pad_desc_t &desc);
9796

98-
template <typename T>
99-
size_t get_array_hash(size_t seed, const T *v, int size) {
100-
for (int i = 0; i < size; i++) {
101-
seed = hash_combine(seed, v[i]);
102-
}
103-
return seed;
104-
}
105-
106-
template <>
107-
inline size_t get_array_hash<memory_desc_t>(
108-
size_t seed, const memory_desc_t *v, int size) {
109-
for (int i = 0; i < size; i++) {
110-
seed = hash_combine(seed, get_md_hash(v[i]));
111-
}
112-
return seed;
113-
}
114-
115-
inline size_t get_array_hash(
116-
size_t seed, const std::vector<const memory_desc_t *> &mds) {
117-
for (const auto *md : mds)
118-
seed = hash_combine(seed, get_md_hash(*md));
119-
return seed;
120-
}
121-
122-
template <>
123-
inline size_t get_array_hash<data_type_t>(
124-
size_t seed, const data_type_t *v, int size) {
125-
for (int i = 0; i < size; i++) {
126-
seed = hash_combine(seed, static_cast<size_t>(v[i]));
127-
}
128-
return seed;
129-
}
130-
13197
} // namespace primitive_hashing
13298
} // namespace impl
13399
} // namespace dnnl

0 commit comments

Comments
 (0)