Skip to content

Commit b08ed65

Browse files
dmitry-gorokhov authored and luweizhou2016 committed
[FORK][FEATURE] Migrate legacy post ops and zero points on runtime data pointers
ONEDNN 3.2 migration squashed commits:
- Fix depthwise nwc conv.
- Fix deconv 3D post ops segmentation fault issue. (#130)
- Fix in avx2 conv + fakequant post ops + nxc last channel wrong result. (#128)
- Luwei/fix deconv 3d postops bug (#136):
  -- Fix the deconv fused with depthwise issue in cpuFuncTests.
  -- Switch to use jit_uni_depthwise_injector API.
  -- Fix potential conflicts in registers and YMM.
  -- Update with optimization.
- Fix legacyOps with stock src_zero_point in jit_avx512_core_amx.
- Fix incorrect offset to rsp.
- Preserve bf16emu scratch register when it conflicts with legacy post ops.
- Fix per-OC legacyPostOps for jit_avx512_dw_conv_fwd_kernel_bf16.
- Fix segmentation fault caused by dest scale.
1 parent 9e1e154 commit b08ed65

File tree

91 files changed

+1655
-936
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+1655
-936
lines changed

include/oneapi/dnnl/dnnl.h

+9-19
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
#include "oneapi/dnnl/dnnl_config.h"
2525
#include "oneapi/dnnl/dnnl_types.h"
2626
#include "oneapi/dnnl/dnnl_version.h"
27+
#include <stdbool.h>
2728

2829
#ifdef __cplusplus
2930
extern "C" {
@@ -354,23 +355,14 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
354355
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
355356
dnnl_primitive_attr_t attr, int arg, int mask);
356357

357-
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_compensations(
358-
const_dnnl_primitive_attr_t attr, int *count, int *mask, const int32_t **compensations);
359-
360358
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_compensations(
361-
dnnl_primitive_attr_t attr, int count, int mask, const int32_t *compensations);
362-
363-
dnnl_status_t DNNL_API dnnl_primitive_attr_get_input_zero_points(
364-
const_dnnl_primitive_attr_t attr, int *count, int *mask, const uint8_t **zero_points);
359+
dnnl_primitive_attr_t attr, int count, int mask);
365360

366361
dnnl_status_t DNNL_API dnnl_primitive_attr_set_input_zero_points(
367-
dnnl_primitive_attr_t attr, int count, int mask, const uint8_t *zero_points);
368-
369-
dnnl_status_t DNNL_API dnnl_primitive_attr_get_weights_zero_points(
370-
const_dnnl_primitive_attr_t attr, int *count, int *mask, const float **zero_points);
362+
dnnl_primitive_attr_t attr, int count, int mask);
371363

372364
dnnl_status_t DNNL_API dnnl_primitive_attr_set_weights_zero_points(
373-
dnnl_primitive_attr_t attr, int count, int mask, const float *zero_points);
365+
dnnl_primitive_attr_t attr, int count, int mask);
374366

375367
/// Returns primitive attributes post-ops.
376368
///
@@ -578,8 +570,7 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
578570
///
579571
/// The kind of this post operation is #dnnl_convolution.
580572
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_conv(
581-
dnnl_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt,
582-
const float* weights_data, const float* biases_data);
573+
dnnl_post_ops_t post_ops, int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt);
583574

584575
/// Appends a binary post-op.
585576
///
@@ -661,14 +652,13 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
661652
const_dnnl_post_ops_t post_ops, int index, int *mask);
662653

663654
dnnl_status_t DNNL_API dnnl_post_ops_append_depthwise(
664-
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
665-
const float* weights_data, const float* biases_data);
655+
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, size_t offset_size, const size_t* offset);
666656

667657
dnnl_status_t DNNL_API dnnl_post_ops_append_quantization(
668658
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg,
669-
const void* crop_low, const void* crop_high,
670-
const void* input_scale, const void* input_shift,
671-
const void* output_scale, const void* output_shift);
659+
size_t per_channel_size, const bool* per_channel,
660+
size_t all_default_size, const bool* all_default,
661+
size_t offset_size, const size_t* offset);
672662

673663
dnnl_status_t DNNL_API dnnl_post_ops_append_binarization(
674664
dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg, const float* weights_data, const float* output_mask);

include/oneapi/dnnl/dnnl.hpp

+20-65
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
#include <memory>
3030
#include <string>
3131
#include <vector>
32+
#include <array>
3233
#include <unordered_map>
3334

3435
#include "oneapi/dnnl/dnnl.h"
@@ -148,6 +149,10 @@ struct primitive : public handle<dnnl_primitive_t> {
148149
layer_normalization = dnnl_layer_normalization,
149150
/// A group normalization primitive
150151
group_normalization = dnnl_group_normalization,
152+
153+
depthwise = dnnl_depthwise,
154+
quantization = dnnl_quantization,
155+
binarization = dnnl_binarization,
151156
};
152157

153158
using handle::handle;
@@ -168,7 +173,7 @@ struct primitive : public handle<dnnl_primitive_t> {
168173
const std::vector<uint8_t> &cache_blob);
169174

170175
/// Constructs a primitive from a primitive descriptor.
171-
///
176+
///src/common/deconvolution_pd.hpp
172177
/// @param pd Primitive descriptor.
173178
primitive(const primitive_desc &pd);
174179

@@ -3615,10 +3620,9 @@ struct post_ops : public handle<dnnl_post_ops_t> {
36153620
"could not append a binary post-op");
36163621
}
36173622

3618-
void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt,
3619-
const float* weights_data, const float* biases_data) {
3623+
void append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w, dnnl_data_type_t in_dt) {
36203624
error::wrap_c_api(dnnl_post_ops_append_dw_conv(get(),
3621-
in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt, weights_data, biases_data),
3625+
in_h, in_w, ker_h, ker_w, str_h, str_w, in_dt),
36223626
"could not append dw conv");
36233627
}
36243628

@@ -3707,19 +3711,15 @@ struct post_ops : public handle<dnnl_post_ops_t> {
37073711
"could not get parameters of a binary post-op");
37083712
}
37093713

3710-
void append_depthwise(algorithm alg, const float* weights_data,
3711-
const float* biases_data) {
3712-
error::wrap_c_api(dnnl_post_ops_append_depthwise(get(),
3713-
convert_to_c(alg), weights_data, biases_data),
3714+
void append_depthwise(algorithm alg, const std::array<size_t, 2>& offset) {
3715+
error::wrap_c_api(dnnl_post_ops_append_depthwise(get(), convert_to_c(alg), offset.size(), offset.data()),
37143716
"could not append depthwise");
37153717
}
37163718

3717-
void append_quantization(algorithm alg,
3718-
const void* crop_low, const void* crop_high,
3719-
const void* input_scale, const void* input_shift,
3720-
const void* output_scale, const void* output_shift) {
3721-
error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), crop_low, crop_high,
3722-
input_scale, input_shift, output_scale, output_shift),
3719+
void append_quantization(algorithm alg, const std::array<bool, 6>& per_channel, const std::array<bool, 6>& all_default,
3720+
const std::array<size_t, 6>& offset) {
3721+
error::wrap_c_api(dnnl_post_ops_append_quantization(get(), convert_to_c(alg), per_channel.size(), per_channel.data(),
3722+
all_default.size(), all_default.data(), offset.size(), offset.data()),
37233723
"could not append quantization");
37243724
}
37253725

@@ -3832,66 +3832,21 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
38323832
"could not set zero points primitive attribute");
38333833
}
38343834

3835-
void get_output_compensations(int &mask, std::vector<int32_t> &compensations) const
3835+
void set_output_compensations(dnnl_dim_t count, int mask)
38363836
{
3837-
int count, c_mask;
3838-
const int32_t *c_compensations;
3839-
error::wrap_c_api(dnnl_primitive_attr_get_output_compensations(get(),
3840-
&count, &c_mask, &c_compensations),
3841-
"could not get int output compensations");
3842-
compensations.resize(count);
3843-
3844-
mask = c_mask;
3845-
for (int c = 0; c < count; ++c)
3846-
compensations[c] = c_compensations[c];
3847-
}
3848-
3849-
void set_output_compensations(int mask, const std::vector<int32_t> &compensations)
3850-
{
3851-
error::wrap_c_api(dnnl_primitive_attr_set_output_compensations(get(),
3852-
(int)compensations.size(), mask, &compensations[0]),
3837+
error::wrap_c_api(dnnl_primitive_attr_set_output_compensations(get(), count, mask),
38533838
"could not set int output compensations");
38543839
}
38553840

3856-
void get_input_zero_points(int &mask, std::vector<uint8_t> &zero_points) const
3841+
void set_input_zero_points(dnnl_dim_t count, int mask)
38573842
{
3858-
int count, c_mask;
3859-
const uint8_t *c_zero_points;
3860-
error::wrap_c_api(dnnl_primitive_attr_get_input_zero_points(get(),
3861-
&count, &c_mask, &c_zero_points),
3862-
"could not get int input zero_points");
3863-
zero_points.resize(count);
3864-
3865-
mask = c_mask;
3866-
for (int c = 0; c < count; ++c)
3867-
zero_points[c] = c_zero_points[c];
3868-
}
3869-
3870-
void set_input_zero_points(int mask, const std::vector<uint8_t> &zero_points)
3871-
{
3872-
error::wrap_c_api(dnnl_primitive_attr_set_input_zero_points(get(),
3873-
(int)zero_points.size(), mask, &zero_points[0]),
3843+
error::wrap_c_api(dnnl_primitive_attr_set_input_zero_points(get(), count, mask),
38743844
"could not set int input zero_points");
38753845
}
38763846

3877-
void get_weights_zero_points(int &mask, std::vector<int8_t> &zero_points) const
3878-
{
3879-
int count, c_mask;
3880-
const float *c_zero_points;
3881-
error::wrap_c_api(dnnl_primitive_attr_get_weights_zero_points(get(),
3882-
&count, &c_mask, &c_zero_points),
3883-
"could not get int weights zero_points");
3884-
zero_points.resize(count);
3885-
3886-
mask = c_mask;
3887-
for (int c = 0; c < count; ++c)
3888-
zero_points[c] = c_zero_points[c];
3889-
}
3890-
3891-
void set_weights_zero_points(int mask, const std::vector<float> &zero_points)
3847+
void set_weights_zero_points(dnnl_dim_t count, int mask)
38923848
{
3893-
error::wrap_c_api(dnnl_primitive_attr_set_weights_zero_points(get(),
3894-
(int)zero_points.size(), mask, &zero_points[0]),
3849+
error::wrap_c_api(dnnl_primitive_attr_set_weights_zero_points(get(), count, mask),
38953850
"could not set int weights zero_points");
38963851
}
38973852

src/common/convolution_pd.hpp

+6-6
Original file line number | Diff line number | Diff line change
@@ -285,9 +285,8 @@ struct convolution_fwd_pd_t : public convolution_pd_t {
285285
}
286286

287287
int n_inputs() const override {
288-
// todo: [antonvor] uncomment when new behavior of dw convolution fusing from oneDNN 1.6 will be supported
289-
return 2 + with_bias() /* + attr_post_op_dw_inputs() */ + n_binary_po_inputs()
290-
+ n_prelu_po_inputs();
288+
return 2 + with_bias() + attr_post_op_dw_inputs() + n_binary_po_inputs()
289+
+ n_prelu_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs();
291290
}
292291

293292
int n_outputs() const override { return 1; }
@@ -317,8 +316,7 @@ struct convolution_fwd_pd_t : public convolution_pd_t {
317316
const auto &po = attr_.post_ops_;
318317
int conv = po.find(primitive_kind::convolution);
319318
if (conv == -1) return 0;
320-
return po.entry_[conv].depthwise_conv.bias_dt == data_type::undef ? 1
321-
: 2;
319+
return 2;
322320
}
323321
};
324322

@@ -366,7 +364,9 @@ struct convolution_bwd_data_pd_t : public convolution_pd_t {
366364
return &glob_zero_md;
367365
}
368366

369-
int n_inputs() const override { return 2 + with_bias(); }
367+
int n_inputs() const override {
368+
return 2 + with_bias() + n_depthwise_po_inputs() + n_quantization_po_inputs();
369+
}
370370
int n_outputs() const override { return 1; }
371371

372372
virtual bool support_bias() const { return false; }

src/common/deconvolution_pd.hpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -232,7 +232,7 @@ struct deconvolution_fwd_pd_t : public deconvolution_pd_t {
232232
}
233233

234234
int n_inputs() const override {
235-
return 2 + with_bias() + n_prelu_po_inputs() + n_binary_po_inputs();
235+
return 2 + with_bias() + n_prelu_po_inputs() + n_binary_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs();
236236
}
237237
int n_outputs() const override { return 1; }
238238

src/common/pooling_pd.hpp

+1-1
Original file line number | Diff line number | Diff line change
@@ -202,7 +202,7 @@ struct pooling_fwd_pd_t : public pooling_pd_t {
202202
: &glob_zero_md;
203203
}
204204

205-
int n_inputs() const override { return 1 + n_binary_po_inputs(); }
205+
int n_inputs() const override { return 1 + n_binary_po_inputs() + n_depthwise_po_inputs() + n_quantization_po_inputs(); }
206206
int n_outputs() const override {
207207
return 1 + (!types::is_zero_md(workspace_md()));
208208
}

0 commit comments

Comments
 (0)