Skip to content

Commit b4f33c7

Browse files
luweizhou2016 and azhai219
authored and committed
[FORK][FEATURE] Migrate legacy post ops and zero points on runtime data pointers
ONEDNN 3.2 migration squashed commits: - fix depthwise nwc conv - Fix deconv 3D post OPs segment fault issue. (openvinotoolkit#130) - fix in avx2 conv+fakequant post ops +nxc last channel wrong result (openvinotoolkit#128) - Luwei/fix deconv 3d postops bug (openvinotoolkit#136) -- Fix the deconv fused with depthwise issue in cpuFuncTests -- Switch to use jit_uni_depthwise_injector API. -- Fix potential conflicts in registers and YMM. -- Update with optimization. - fix legacyOps with stock src_zero_point in jit_avx512_core_amx - Fix incorrect offset to rsp - Preserve bf16emu scratch register when conflict with legacy post ops - fix per-OC legacyPostOps for jit_avx512_dw_conv_fwd_kernel_bf16 - Fix segment fault caused by dest scale. ONEDNN 3.5 migration squashed commits: [FIX][FORK][FEATURE] Introduced Depthwise and Quantization post ops [FORK] [FIX] Fix legacy zero point issue on AVX2 + AVX512. [FEATURE] Migrate legacy post ops and zero points on runtime data pointers [FORK][FIX] jit_uni_dw_conv_row_f32: fixed post ops start idx [FEATURE] Migrate legacy post ops and zero points on runtime data pointers [Fix] Update the attr checking caused by forked onednn. [FORK][FIX][x64] Add proper post op checks to gemm_conv is split partially and squashed. [FORK][FIX] jit_uni_dw_conv_kernel_f32: fixed register conflict [FIX] Split the removed unused variables [FIX] fix avx512 bf16 dw stack pointer [ARM] Fixed legacy post-ops changes for ARM target
1 parent 10d508a commit b4f33c7

File tree

96 files changed

+1717
-949
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+1717
-949
lines changed

include/oneapi/dnnl/dnnl.h

+9-16
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "oneapi/dnnl/dnnl_config.h"
2525
#include "oneapi/dnnl/dnnl_types.h"
2626
#include "oneapi/dnnl/dnnl_version.h"
27+
#include <stdbool.h>
2728

2829
#ifdef __cplusplus
2930
extern "C" {
@@ -519,19 +520,13 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_compensations(
519520
const_dnnl_primitive_attr_t attr, int *count, int *mask, const int32_t **compensations);
520521

521522
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_compensations(
522-
dnnl_primitive_attr_t attr, int count, int mask, const int32_t *compensations);
523-
524-
dnnl_status_t DNNL_API dnnl_primitive_attr_get_input_zero_points(
525-
const_dnnl_primitive_attr_t attr, int *count, int *mask, const uint8_t **zero_points);
523+
dnnl_primitive_attr_t attr, int count, int mask);
526524

527525
dnnl_status_t DNNL_API dnnl_primitive_attr_set_input_zero_points(
528-
dnnl_primitive_attr_t attr, int count, int mask, const uint8_t *zero_points);
529-
530-
dnnl_status_t DNNL_API dnnl_primitive_attr_get_weights_zero_points(
531-
const_dnnl_primitive_attr_t attr, int *count, int *mask, const float **zero_points);
526+
dnnl_primitive_attr_t attr, int count, int mask);
532527

533528
dnnl_status_t DNNL_API dnnl_primitive_attr_set_weights_zero_points(
534-
dnnl_primitive_attr_t attr, int count, int mask, const float *zero_points);
529+
dnnl_primitive_attr_t attr, int count, int mask);
535530

536531
/// Returns primitive attributes post-ops.
537532
///
@@ -739,8 +734,7 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
739734
///
740735
/// The kind of this post operation is #dnnl_convolution.
741736
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_conv(
742-
dnnl_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt,
743-
const float* weights_data, const float* biases_data);
737+
dnnl_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt);
744738

745739
/// Appends a binary post-op.
746740
///
@@ -822,14 +816,13 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
822816
const_dnnl_post_ops_t post_ops, int index, int *mask);
823817

824818
dnnl_status_t DNNL_API dnnl_post_ops_append_depthwise(
825-
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
826-
const float* weights_data, const float* biases_data);
819+
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, size_t offset_size, const size_t* offset);
827820

828821
dnnl_status_t DNNL_API dnnl_post_ops_append_quantization(
829822
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
830-
const void* crop_low, const void* crop_high,
831-
const void* input_scale, const void* input_shift,
832-
const void* output_scale, const void* output_shift);
823+
size_t per_channel_size, const bool* per_channel,
824+
size_t all_default_size, const bool* all_default,
825+
size_t offset_size, const size_t* offset);
833826

834827
dnnl_status_t DNNL_API dnnl_post_ops_append_binarization(
835828
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, const float* weights_data, const float* output_mask);

include/oneapi/dnnl/dnnl.hpp

+20-65
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <memory>
3030
#include <string>
3131
#include <vector>
32+
#include <array>
3233
#include <unordered_map>
3334

3435
#include "oneapi/dnnl/dnnl.h"
@@ -148,6 +149,10 @@ struct primitive : public handle<dnnl_primitive_t> {
148149
layer_normalization = dnnl_layer_normalization,
149150
/// A group normalization primitive
150151
group_normalization = dnnl_group_normalization,
152+
153+
depthwise = dnnl_depthwise,
154+
quantization = dnnl_quantization,
155+
binarization = dnnl_binarization,
151156
};
152157

153158
using handle::handle;
@@ -168,7 +173,7 @@ struct primitive : public handle<dnnl_primitive_t> {
168173
const std::vector<uint8_t> &cache_blob);
169174

170175
/// Constructs a primitive from a primitive descriptor.
171-
///
176+
///src/common/deconvolution_pd.hpp
172177
/// @param pd Primitive descriptor.
173178
primitive(const primitive_desc &pd);
174179

@@ -3864,10 +3869,9 @@ struct post_ops : public handle<dnnl_post_ops_t> {
38643869
"could not append a binary post-op");
38653870
}
38663871

3867-
void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt,
3868-
const float* weights_data, const float* biases_data) {
3872+
void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt) {
38693873
error::wrap_c_api(dnnl_post_ops_append_dw_conv(get(),
3870-
in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt, weights_data, biases_data),
3874+
in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt),
38713875
"could not append dw conv");
38723876
}
38733877

@@ -3956,19 +3960,15 @@ struct post_ops : public handle<dnnl_post_ops_t> {
39563960
"could not get parameters of a binary post-op");
39573961
}
39583962

3959-
void append_depthwise(algorithm alg, const float* weights_data,
3960-
const float* biases_data) {
3961-
error::wrap_c_api(dnnl_post_ops_append_depthwise(get(),
3962-
convert_to_c(alg), weights_data, biases_data),
3963+
void append_depthwise(algorithm alg, const std::array<size_t, 2>& offset) {
3964+
error::wrap_c_api(dnnl_post_ops_append_depthwise(get(), convert_to_c(alg), offset.size(), offset.data()),
39633965
"could not append depthwise");
39643966
}
39653967

3966-
void append_quantization(algorithm alg,
3967-
const void* crop_low, const void* crop_high,
3968-
const void* input_scale, const void* input_shift,
3969-
const void* output_scale, const void* output_shift) {
3970-
error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), crop_low, crop_high,
3971-
input_scale, input_shift, output_scale, output_shift),
3968+
void append_quantization(algorithm alg, const std::array<bool, 6>& per_channel, const std::array<bool, 6>& all_default,
3969+
const std::array<size_t, 6>& offset) {
3970+
error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), per_channel.size(), per_channel.data(),
3971+
all_default.size(), all_default.data(), offset.size(), offset.data()),
39723972
"could not append quantization");
39733973
}
39743974

@@ -4226,66 +4226,21 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
42264226
"could not set zero points primitive attribute");
42274227
}
42284228

4229-
void get_output_compensations(int &mask, std::vector<int32_t> &compensations) const
4229+
void set_output_compensations(dnnl_dim_t count, int mask)
42304230
{
4231-
int count, c_mask;
4232-
const int32_t *c_compensations;
4233-
error::wrap_c_api(dnnl_primitive_attr_get_output_compensations(get(),
4234-
&count, &c_mask, &c_compensations),
4235-
"could not get int output compensations");
4236-
compensations.resize(count);
4237-
4238-
mask = c_mask;
4239-
for (int c = 0; c < count; ++c)
4240-
compensations[c] = c_compensations[c];
4241-
}
4242-
4243-
void set_output_compensations(int mask, const std::vector<int32_t> &compensations)
4244-
{
4245-
error::wrap_c_api(dnnl_primitive_attr_set_output_compensations(get(),
4246-
(int)compensations.size(), mask, &compensations[0]),
4231+
error::wrap_c_api(dnnl_primitive_attr_set_output_compensations(get(), count, mask),
42474232
"could not set int output compensations");
42484233
}
42494234

4250-
void get_input_zero_points(int &mask, std::vector<uint8_t> &zero_points) const
4235+
void set_input_zero_points(dnnl_dim_t count, int mask)
42514236
{
4252-
int count, c_mask;
4253-
const uint8_t *c_zero_points;
4254-
error::wrap_c_api(dnnl_primitive_attr_get_input_zero_points(get(),
4255-
&count, &c_mask, &c_zero_points),
4256-
"could not get int input zero_points");
4257-
zero_points.resize(count);
4258-
4259-
mask = c_mask;
4260-
for (int c = 0; c < count; ++c)
4261-
zero_points[c] = c_zero_points[c];
4262-
}
4263-
4264-
void set_input_zero_points(int mask, const std::vector<uint8_t> &zero_points)
4265-
{
4266-
error::wrap_c_api(dnnl_primitive_attr_set_input_zero_points(get(),
4267-
(int)zero_points.size(), mask, &zero_points[0]),
4237+
error::wrap_c_api(dnnl_primitive_attr_set_input_zero_points(get(), count, mask),
42684238
"could not set int input zero_points");
42694239
}
42704240

4271-
void get_weights_zero_points(int &mask, std::vector<int8_t> &zero_points) const
4272-
{
4273-
int count, c_mask;
4274-
const float *c_zero_points;
4275-
error::wrap_c_api(dnnl_primitive_attr_get_weights_zero_points(get(),
4276-
&count, &c_mask, &c_zero_points),
4277-
"could not get int weights zero_points");
4278-
zero_points.resize(count);
4279-
4280-
mask = c_mask;
4281-
for (int c = 0; c < count; ++c)
4282-
zero_points[c] = c_zero_points[c];
4283-
}
4284-
4285-
void set_weights_zero_points(int mask, const std::vector<float> &zero_points)
4241+
void set_weights_zero_points(dnnl_dim_t count, int mask)
42864242
{
4287-
error::wrap_c_api(dnnl_primitive_attr_set_weights_zero_points(get(),
4288-
(int)zero_points.size(), mask, &zero_points[0]),
4243+
error::wrap_c_api(dnnl_primitive_attr_set_weights_zero_points(get(), count, mask),
42894244
"could not set int weights zero_points");
42904245
}
42914246

src/common/convolution.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,10 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine,
171171
if (is_int8)
172172
fwd_attr_mask |= smask_t::scales_runtime
173173
| smask_t::zero_points_runtime
174-
| smask_t::zero_points_runtime_data_type;
174+
| smask_t::zero_points_runtime_data_type
175+
| smask_t::input_zero_points
176+
| smask_t::output_compensations
177+
| smask_t::weights_zero_points;
175178

176179
VCHECK_CONV_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
177180
VERBOSE_UNSUPPORTED_ATTR);
@@ -208,17 +211,17 @@ status_t conv_attr_check(const convolution_desc_t &desc, const engine_t *engine,
208211
const auto &po = attr->post_ops_;
209212
using namespace primitive_kind;
210213
VCHECK_CONV_UNIMPL(po.has_default_values({binary, eltwise, prelu,
211-
sum, convolution}),
214+
sum, convolution, depthwise, quantization}),
212215
VERBOSE_UNSUPPORTED_POSTOP);
213216

214217
// Check sum
215218
VCHECK_CONV_UNIMPL(po.check_sum_consistency(dst_dt, is_int8, true),
216219
VERBOSE_UNSUPPORTED_POSTOP);
217220
}
218-
} else {
219-
auto bwd_attr_mask = smask_t::fpmath_mode;
220-
VCHECK_CONV_UNIMPL(attr->has_default_values(bwd_attr_mask),
221-
VERBOSE_UNSUPPORTED_ATTR);
221+
// } else {
222+
// auto bwd_attr_mask = smask_t::fpmath_mode;
223+
// VCHECK_CONV_UNIMPL(attr->has_default_values(bwd_attr_mask),
224+
// VERBOSE_UNSUPPORTED_ATTR);
222225
}
223226

224227
return status::success;

src/common/convolution_pd.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,8 @@ struct convolution_fwd_pd_t : public convolution_pd_t {
298298
}
299299

300300
int n_inputs() const override {
301-
// todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported
302-
return 2 + with_bias() /* + attr_post_op_dw_inputs() */ + n_binary_po_inputs()
303-
+ n_prelu_po_inputs();
301+
return 2 + with_bias() + attr_post_op_dw_inputs() + n_binary_po_inputs()
302+
+ n_prelu_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs();
304303
}
305304

306305
int n_outputs() const override { return 1; }
@@ -330,8 +329,7 @@ struct convolution_fwd_pd_t : public convolution_pd_t {
330329
const auto &po = attr_.post_ops_;
331330
int conv = po.find(primitive_kind::convolution);
332331
if (conv == -1) return 0;
333-
return po.entry_[conv].depthwise_conv.bias_dt == data_type::undef ? 1
334-
: 2;
332+
return 2;
335333
}
336334
};
337335

@@ -379,7 +377,9 @@ struct convolution_bwd_data_pd_t : public convolution_pd_t {
379377
return &glob_zero_md;
380378
}
381379

382-
int n_inputs() const override { return 2 + with_bias(); }
380+
int n_inputs() const override {
381+
return 2 + with_bias() + n_depthwise_po_inputs() + n_quantization_po_inputs();
382+
}
383383
int n_outputs() const override { return 1; }
384384

385385
virtual bool support_bias() const { return false; }

src/common/deconvolution_pd.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ struct deconvolution_fwd_pd_t : public deconvolution_pd_t {
245245
}
246246

247247
int n_inputs() const override {
248-
return 2 + with_bias() + n_prelu_po_inputs() + n_binary_po_inputs();
248+
return 2 + with_bias() + n_prelu_po_inputs() + n_binary_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs();
249249
}
250250
int n_outputs() const override { return 1; }
251251

src/common/inner_product.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ status_t ip_attr_check(const inner_product_desc_t &desc, const engine_t *engine,
125125
is_int8 = is_int8
126126
|| utils::one_of(dst_dt, data_type::s8, data_type::u8,
127127
data_type::s32);
128-
if (is_int8) fwd_attr_mask |= smask_t::scales_runtime;
128+
if (is_int8) fwd_attr_mask |= smask_t::scales_runtime | smask_t::zero_points_runtime;
129129

130130
VCHECK_IP_UNIMPL(attr->has_default_values(fwd_attr_mask, dst_dt),
131131
VERBOSE_UNSUPPORTED_ATTR);

src/common/pooling.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ status_t pooling_attr_check(const pooling_desc_t &desc, const engine_t *engine,
151151
if (!attr->post_ops_.has_default_values()) {
152152
const auto &po = attr->post_ops_;
153153
using namespace primitive_kind;
154-
VCHECK_POOLING_IMPL(po.has_default_values({binary, eltwise}),
154+
VCHECK_POOLING_IMPL(po.has_default_values({binary, eltwise, quantization}),
155155
VERBOSE_UNSUPPORTED_POSTOP);
156156
}
157157
} else {

src/common/pooling_pd.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ struct pooling_fwd_pd_t : public pooling_pd_t {
215215
: &glob_zero_md;
216216
}
217217

218-
int n_inputs() const override { return 1 + n_binary_po_inputs(); }
218+
int n_inputs() const override { return 1 + n_binary_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs(); }
219219
int n_outputs() const override {
220220
return 1 + (!types::is_zero_md(workspace_md()));
221221
}

0 commit comments

Comments
 (0)