Skip to content

Commit bf3eabd

Browse files
dmitry-gorokhov and luweizhou2016
authored and committed
[FORK][FEATURE] Introduced Depthwise and Quantization post ops
Primitives that support new post ops: - Jit Convolutions (FP32,BF16,INT8) - Jit Deconvolution (INT8) - Jit,Ref Pooling (INT8) AMX primitives: explicitly pass dst_type into has_default_values checks Extended int8 AMX convolutions to support depthwise/quantization post ops ONEDNN 3.2 migration squashed commits: - Correct assert for reg id in dw conv kernel ONEDNN 3.5 squash list: [FIX] fix avx2 int8 binary postops reg conflict
1 parent 115e9fa commit bf3eabd

File tree

94 files changed

+2078
-237
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

94 files changed

+2078
-237
lines changed

include/oneapi/dnnl/dnnl.h

+10
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,16 @@ dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
795795
dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
796796
const_dnnl_post_ops_t post_ops, int index, int *mask);
797797

798+
dnnl_status_t DNNL_API dnnl_post_ops_append_depthwise(
799+
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
800+
const float* weights_data, const float* biases_data);
801+
802+
dnnl_status_t DNNL_API dnnl_post_ops_append_quantization(
803+
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
804+
const void* crop_low, const void* crop_high,
805+
const void* input_scale, const void* input_shift,
806+
const void* output_scale, const void* output_shift);
807+
798808
/// @} dnnl_api_attributes
799809

800810
/// @} dnnl_api_primitives

include/oneapi/dnnl/dnnl.hpp

+22
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,12 @@ enum class algorithm {
496496
softmax_accurate = dnnl_softmax_accurate,
497497
/// LogSoftmax, numerically stable
498498
softmax_log = dnnl_softmax_log,
499+
500+
depthwise_scale_shift = dnnl_depthwise_scale_shift,
501+
depthwise_prelu = dnnl_depthwise_prelu,
502+
503+
quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize,
504+
quantization_quantize = dnnl_quantization_quantize,
499505
};
500506

501507
/// Converts algorithm kind enum value from C++ API to C API type.
@@ -3924,6 +3930,22 @@ struct post_ops : public handle<dnnl_post_ops_t> {
39243930
error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
39253931
"could not get parameters of a binary post-op");
39263932
}
3933+
3934+
/// Appends a depthwise post-op.
///
/// @param alg Depthwise algorithm kind (e.g. scale-shift or prelu).
/// @param weights_data Non-owning pointer to per-channel weights.
/// @param biases_data Non-owning pointer to per-channel biases.
void append_depthwise(algorithm alg, const float* weights_data,
        const float* biases_data) {
    const dnnl_status_t status = dnnl_post_ops_append_depthwise(
            get(), convert_to_c(alg), weights_data, biases_data);
    error::wrap_c_api(status, "could not append depthwise");
}
3940+
3941+
/// Appends a quantization post-op.
///
/// @param alg Quantization algorithm kind.
/// @param crop_low Non-owning pointer to low clamp values.
/// @param crop_high Non-owning pointer to high clamp values.
/// @param input_scale Non-owning pointer to input scales.
/// @param input_shift Non-owning pointer to input shifts.
/// @param output_scale Non-owning pointer to output scales.
/// @param output_shift Non-owning pointer to output shifts.
void append_quantization(algorithm alg,
        const void* crop_low, const void* crop_high,
        const void* input_scale, const void* input_shift,
        const void* output_scale, const void* output_shift) {
    const dnnl_status_t status = dnnl_post_ops_append_quantization(get(),
            convert_to_c(alg), crop_low, crop_high, input_scale, input_shift,
            output_scale, output_shift);
    error::wrap_c_api(status, "could not append quantization");
}
39273949
};
39283950

39293951
/// @cond DO_NOT_DOCUMENT_THIS

include/oneapi/dnnl/dnnl_types.h

+10
Original file line numberDiff line numberDiff line change
@@ -1992,6 +1992,10 @@ typedef enum {
19921992
dnnl_deconvolution,
19931993
/// An element-wise primitive.
19941994
dnnl_eltwise,
1995+
/// A depthwise primitive.
1996+
dnnl_depthwise,
1997+
/// A quantization primitive.
1998+
dnnl_quantization,
19951999
/// An LRN primitive.
19962000
dnnl_lrn,
19972001
/// A batch normalization primitive.
@@ -2176,6 +2180,12 @@ typedef enum {
21762180
dnnl_softmax_accurate = 0x30000,
21772181
/// Logsoftmax
21782182
dnnl_softmax_log,
2183+
2184+
dnnl_depthwise_scale_shift = 0x3fff0,
2185+
dnnl_depthwise_prelu = 0x3fff1,
2186+
2187+
dnnl_quantization_quantize_dequantize = 0x4fff0,
2188+
dnnl_quantization_quantize = 0x4fff1,
21792189
} dnnl_alg_kind_t;
21802190

21812191
/// Flags for normalization primitives.

src/common/c_types_map.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ const alg_kind_t reduction_norm_lp_power_p_sum
141141
= dnnl_reduction_norm_lp_power_p_sum;
142142
const alg_kind_t softmax_accurate = dnnl_softmax_accurate;
143143
const alg_kind_t softmax_log = dnnl_softmax_log;
144+
const alg_kind_t depthwise_scale_shift = dnnl_depthwise_scale_shift;
145+
const alg_kind_t depthwise_prelu = dnnl_depthwise_prelu;
146+
const alg_kind_t quantization_quantize_dequantize = dnnl_quantization_quantize_dequantize;
147+
const alg_kind_t quantization_quantize = dnnl_quantization_quantize;
144148
} // namespace alg_kind
145149

146150
using data_type_t = dnnl_data_type_t;
@@ -1949,6 +1953,8 @@ const primitive_kind_t reduction = dnnl_reduction;
19491953
const primitive_kind_t softmax = dnnl_softmax;
19501954
const primitive_kind_t layer_normalization = dnnl_layer_normalization;
19511955
const primitive_kind_t group_normalization = dnnl_group_normalization;
1956+
const primitive_kind_t depthwise = dnnl_depthwise;
1957+
const primitive_kind_t quantization = dnnl_quantization;
19521958

19531959
// Internal only primitive kinds.
19541960
const primitive_kind_t internal_only_start = (primitive_kind_t)(1 << 12);

src/common/dnnl_debug_autogenerated.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -1753,6 +1753,8 @@ const char *dnnl_prim_kind2str(dnnl_primitive_kind_t v) {
17531753
if (v == dnnl_softmax) return "softmax";
17541754
if (v == dnnl_layer_normalization) return "layer_normalization";
17551755
if (v == dnnl_group_normalization) return "group_normalization";
1756+
if (v == dnnl_depthwise) return "depthwise";
1757+
if (v == dnnl_quantization) return "quantization";
17561758
if (v == dnnl_primitive_kind_max) return "primitive_kind_max";
17571759
if (v == dnnl::impl::primitive_kind::sdpa) return "sdpa";
17581760
assert(!"unknown prim_kind");
@@ -1830,6 +1832,10 @@ const char *dnnl_alg_kind2str(dnnl_alg_kind_t v) {
18301832
if (v == dnnl_reduction_norm_lp_power_p_sum) return "reduction_norm_lp_power_p_sum";
18311833
if (v == dnnl_softmax_accurate) return "softmax_accurate";
18321834
if (v == dnnl_softmax_log) return "softmax_log";
1835+
if (v == dnnl_depthwise_scale_shift) return "depthwise_scale_shift";
1836+
if (v == dnnl_depthwise_prelu) return "depthwise_prelu";
1837+
if (v == dnnl_quantization_quantize_dequantize) return "quantization_quantize_dequantize";
1838+
if (v == dnnl_quantization_quantize) return "quantization_quantize";
18331839
assert(!"unknown alg_kind");
18341840
return "unknown alg_kind";
18351841
}

src/common/ittnotify.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ void primitive_task_start(primitive_kind_t kind) {
8080
CASE(layer_normalization),
8181
CASE(group_normalization),
8282
CASE(sdpa),
83+
CASE(depthwise),
84+
CASE(quantization),
8385
};
8486
#undef CASE
8587
int kind_idx = (int)kind;

src/common/math_utils.hpp

+37
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,43 @@ inline float stochastic_round_fwd(
567567
return r;
568568
}
569569

570+
inline float get_bias(const char *bias, size_t offset, data_type_t data_type) {
571+
if (!bias) return 0.0f;
572+
573+
#define CASE(dt) \
574+
case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
575+
576+
switch (data_type) {
577+
CASE(data_type::s8);
578+
CASE(data_type::u8);
579+
CASE(data_type::bf16);
580+
CASE(data_type::s32);
581+
CASE(data_type::f32);
582+
default: assert(!"unimplemented");
583+
}
584+
return 0; // never happens (should probably be a NaN)
585+
#undef CASE
586+
}
587+
588+
inline float get_sum(char *sum, size_t offset, data_type_t data_type)
589+
{
590+
if (!sum)
591+
return 0.0f;
592+
593+
#define CASE(dt) \
594+
case dt: return (float)((const prec_traits<dt>::type *)sum)[offset]
595+
596+
switch (data_type) {
597+
CASE(data_type::s8);
598+
CASE(data_type::u8);
599+
CASE(data_type::s32);
600+
CASE(data_type::f32);
601+
default: assert(!"unimplemented");
602+
}
603+
return 0; // never happens (should probably be a NaN)
604+
#undef CASE
605+
}
606+
570607
} // namespace math
571608
} // namespace impl
572609
} // namespace dnnl

src/common/nstl.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,10 @@ class vector : public c_compatible {
339339
}
340340
void clear() { _impl.clear(); }
341341
void push_back(const T &t) { _impl.push_back(t); }
342+
template<typename... Args>
343+
void emplace_back(Args&&... args) {
344+
_impl.emplace_back(std::forward<Args>(args)...);
345+
}
342346
void resize(size_type count) { _impl.resize(count); }
343347
void reserve(size_type count) { _impl.reserve(count); }
344348
};

src/common/primitive_attr.cpp

+97
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,29 @@ status_t scales_t::set(dim_t count, int mask, const float *scales) {
7373
return status::success;
7474
}
7575

76+
77+
template <typename T>
78+
status_t shifts_t<T>::set(int count, int mask, const T *shifts) {
79+
cleanup();
80+
81+
count_ = count;
82+
mask_ = mask;
83+
84+
if (count_ == 1) {
85+
shifts_ = shifts_buf_;
86+
utils::array_set(shifts_, shifts[0], shifts_buf_size);
87+
} else {
88+
shifts_ = (T *)impl::malloc(count_ * sizeof(*shifts_), 64);
89+
if (shifts_ == nullptr)
90+
return status::out_of_memory;
91+
92+
for (int c = 0; c < count_; ++c)
93+
shifts_[c] = shifts[c];
94+
}
95+
96+
return status::success;
97+
}
98+
7699
status_t zero_points_t::get(int arg, int *mask, data_type_t *dt) const {
77100
if (mask) *mask = get_mask(arg);
78101
if (dt) *dt = get_data_type(arg);
@@ -182,6 +205,14 @@ bool primitive_attr_t::has_default_values(dnnl_primitive_attr::skip_mask_t mask,
182205
#undef CHECK_ARG
183206
}
184207

208+
bool primitive_attr_t::has_asymmetric_quantization() const {
209+
return true
210+
&& output_scales_.has_default_values()
211+
&& rnn_data_qparams_.has_default_values()
212+
&& rnn_weights_qparams_.has_default_values()
213+
&& (!input_zero_points_.has_default_values() || !weights_zero_points_.has_default_values());
214+
}
215+
185216
bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
186217
using smask_t = skip_mask_t;
187218
bool ok = true;
@@ -313,6 +344,47 @@ status_t post_ops_t::append_prelu(int mask) {
313344
return success;
314345
}
315346

347+
// Appends a depthwise post-op entry. The weight/bias pointers are stored
// as-is (non-owning); the caller keeps them alive for the attr's lifetime.
status_t post_ops_t::append_depthwise(alg_kind_t alg, const float* weights_data, const float* biases_data) {
    using namespace dnnl::impl::alg_kind;
    if (len() == post_ops_limit) return out_of_memory;
    if (!one_of(alg, depthwise_scale_shift, depthwise_prelu))
        return invalid_arguments;

    entry_.emplace_back();
    auto &entry = entry_.back();
    entry.kind = primitive_kind::depthwise;
    entry.depthwise.alg = alg;
    entry.depthwise.weights_data = weights_data;
    entry.depthwise.biases_data = biases_data;
    return success;
}
363+
364+
// Appends a quantization post-op entry. The six data pointers are stored as
// type-erased handles to caller-owned scales_t / shifts_t tables (non-owning).
status_t post_ops_t::append_quantization(alg_kind_t alg,
        const void* crop_low, const void* crop_high,
        const void* input_scale, const void* input_shift,
        const void* output_scale, const void* output_shift) {
    using namespace dnnl::impl::alg_kind;
    if (len() == post_ops_limit) return out_of_memory;
    if (!one_of(alg, quantization_quantize_dequantize, quantization_quantize))
        return invalid_arguments;

    entry_.emplace_back();
    auto &entry = entry_.back();
    entry.kind = primitive_kind::quantization;
    entry.quantization.alg = alg;
    entry.quantization.crop_low_data
            = reinterpret_cast<const shifts_t<float>*>(crop_low);
    entry.quantization.crop_high_data
            = reinterpret_cast<const shifts_t<float>*>(crop_high);
    entry.quantization.input_scale_data
            = reinterpret_cast<const scales_t*>(input_scale);
    entry.quantization.input_shift_data
            = reinterpret_cast<const shifts_t<float>*>(input_shift);
    entry.quantization.output_scale_data
            = reinterpret_cast<const scales_t*>(output_scale);
    entry.quantization.output_shift_data
            = reinterpret_cast<const shifts_t<float>*>(output_shift);
    return success;
}
387+
316388
bool post_ops_t::defined() const {
317389
for (int idx = 0; idx < len(); ++idx) {
318390
auto kind = entry_[idx].kind;
@@ -327,6 +399,10 @@ bool post_ops_t::defined() const {
327399
primitive_kind::prelu,
328400
primitive_kind::convolution)) {
329401
// binary is always defined
402+
} else if (kind == primitive_kind::depthwise) {
403+
// depthwise is always defined
404+
} else if (kind == primitive_kind::quantization) {
405+
// quantization is always defined
330406
} else {
331407
assert(!"unreachable");
332408
}
@@ -787,6 +863,23 @@ status_t dnnl_post_ops_get_params_prelu(
787863
return success;
788864
}
789865

866+
// C API entry point: validates the handle, then forwards to post_ops_t.
status_t dnnl_post_ops_append_depthwise(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
        const float* weights_data, const float* biases_data) {
    if (!post_ops) return invalid_arguments;
    return post_ops->append_depthwise(alg, weights_data, biases_data);
}
872+
873+
// C API entry point: validates the handle, then forwards to post_ops_t.
status_t dnnl_post_ops_append_quantization(post_ops_t *post_ops, alg_kind_t kind,
        const void* crop_low, const void* crop_high,
        const void* input_scale, const void* input_shift,
        const void* output_scale, const void* output_shift) {
    if (!post_ops) return invalid_arguments;
    return post_ops->append_quantization(kind, crop_low, crop_high,
            input_scale, input_shift, output_scale, output_shift);
}
882+
790883
status_t dnnl_primitive_attr_set_rnn_data_qparams(
791884
primitive_attr_t *attr, const float scale, const float shift) {
792885
if (attr == nullptr) return invalid_arguments;
@@ -854,3 +947,7 @@ status_t DNNL_API dnnl_primitive_attr_set_rnn_tparams(
854947

855948
return attr->rnn_tparams_.set(mode, ngates, scales, cscale);
856949
}
950+
951+
template struct dnnl::impl::shifts_t<uint8_t>;
952+
template struct dnnl::impl::shifts_t<int32_t>;
953+
template struct dnnl::impl::shifts_t<float>;

0 commit comments

Comments
 (0)