Skip to content

Commit 5f18f4d

Browse files
t4c1mgouicem
authored andcommitted
generic: sycl: add missing type checks on scales
1 parent f2eb2bd commit 5f18f4d

9 files changed

+114
-48
lines changed

src/gpu/generic/sycl/ref_binary.hpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t {
5454
&& check_formats(src0_d, src1_d, dst_d)
5555
&& attr()->has_default_values(
5656
sm::scales_runtime | sm::post_ops)
57-
&& IMPLICATION(!attr()->scales_.has_default_values(),
58-
check_scales_mask())
57+
&& IMPLICATION(
58+
!attr()->scales_.has_default_values(), scales_ok())
5959
&& sycl_post_ops_t::post_ops_ok(attr())
6060
&& md_dims_in_range(src_md(0))
6161
&& md_dims_in_range(src_md(1))
@@ -70,10 +70,17 @@ struct ref_binary_t : public gpu::generic::sycl::primitive_t {
7070
private:
7171
status_t init_conf();
7272

73-
bool check_scales_mask() const {
73+
bool scales_ok() const {
7474
const std::vector<int> supported_args
7575
= {DNNL_ARG_SRC_0, DNNL_ARG_SRC_1};
76-
return attr_scales_ok(supported_args);
76+
77+
const auto &scales = attr()->scales_;
78+
bool dt_ok = true;
79+
for (auto arg : supported_args) {
80+
auto &s = scales.get(arg);
81+
dt_ok = dt_ok && is_supported_type(s.data_type_);
82+
}
83+
return dt_ok && attr_scales_ok(supported_args);
7784
}
7885

7986
static bool check_data_types(const memory_desc_wrapper &src0,

src/gpu/generic/sycl/ref_convolution.hpp

+23-14
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,16 @@ namespace gpu {
3232
namespace generic {
3333
namespace sycl {
3434

35-
static bool check_convolution_data_types(const memory_desc_wrapper &src0,
35+
inline bool check_convolution_data_types(const memory_desc_wrapper &src0,
3636
const memory_desc_wrapper &src1, const memory_desc_wrapper &dst) {
37-
using namespace data_type;
38-
39-
const auto src0_dt = src0.data_type();
40-
const auto src1_dt = src1.data_type();
41-
const auto dst_dt = dst.data_type();
42-
43-
for (auto t : {src0_dt, src1_dt, dst_dt}) {
44-
if (!utils::one_of(t, f32, bf16, f16, s32, s8, u8)) return false;
37+
for (const auto &mdw : {src0, src1, dst}) {
38+
if (!is_supported_type(mdw.data_type())) return false;
4539
}
4640

4741
return true;
4842
}
4943

50-
static bool check_convolution_formats(const memory_desc_wrapper &src0,
44+
inline bool check_convolution_formats(const memory_desc_wrapper &src0,
5145
const memory_desc_wrapper &src1, const memory_desc_wrapper &dst) {
5246
using namespace format_tag;
5347

@@ -57,7 +51,7 @@ static bool check_convolution_formats(const memory_desc_wrapper &src0,
5751
return true;
5852
}
5953

60-
static bool check_convolution_work_amount(
54+
inline bool check_convolution_work_amount(
6155
const memory_desc_wrapper &weights, dim_t OC) {
6256
auto elems = weights.nelems();
6357
auto work_per_output = elems / OC;
@@ -66,6 +60,18 @@ static bool check_convolution_work_amount(
6660
return work_per_output < 200000;
6761
}
6862

63+
inline bool check_convolution_scales_types(const primitive_attr_t *attr) {
64+
const std::vector<int> supported_args
65+
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
66+
67+
const auto &scales = attr->scales_;
68+
for (auto arg : supported_args) {
69+
auto dt = scales.get(arg).data_type_;
70+
if (!is_supported_type(dt)) { return false; }
71+
}
72+
return true;
73+
}
74+
6975
struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
7076
using gpu::generic::sycl::primitive_t::primitive_t;
7177

@@ -92,7 +98,8 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
9298
| sm::zero_points_runtime | sm::post_ops
9399
| sm::sum_dt)
94100
&& IMPLICATION(!attr()->scales_.has_default_values(),
95-
attr_scales_ok())
101+
attr_scales_ok()
102+
&& check_convolution_scales_types(attr()))
96103
&& sycl_post_ops_t::post_ops_ok(attr(), false);
97104
if (!ok) return status::unimplemented;
98105

@@ -148,7 +155,8 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
148155
&& attr()->has_default_values(sm::scales_runtime
149156
| sm::zero_points_runtime | sm::sum_dt)
150157
&& IMPLICATION(!attr()->scales_.has_default_values(),
151-
attr_scales_ok());
158+
attr_scales_ok()
159+
&& check_convolution_scales_types(attr()));
152160
if (!ok) return status::unimplemented;
153161

154162
return init_conf();
@@ -203,7 +211,8 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
203211
&& attr()->has_default_values(sm::scales_runtime
204212
| sm::zero_points_runtime | sm::sum_dt)
205213
&& IMPLICATION(!attr()->scales_.has_default_values(),
206-
attr_scales_ok());
214+
attr_scales_ok()
215+
&& check_convolution_scales_types(attr()));
207216
if (!ok) return status::unimplemented;
208217

209218
return init_conf();

src/gpu/generic/sycl/ref_layer_normalizations.hpp

+21-9
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,31 @@ struct ref_layer_normalization_fwd_t : public gpu::generic::sycl::primitive_t {
5454

5555
const bool ok = is_fwd()
5656
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
57-
&& utils::one_of(
58-
src_md(0)->data_type, f32, bf16, f16, s8, u8)
59-
&& utils::one_of(
60-
dst_md(0)->data_type, f32, bf16, f16, s8, u8)
61-
&& stat_md()->data_type == f32
57+
&& is_supported_type(src_md(0)->data_type)
58+
&& is_supported_type(dst_md(0)->data_type)
59+
&& is_supported_type(stat_md()->data_type)
6260
&& check_scale_shift_data_type({f32, bf16, f16})
6361
&& attr()->has_default_values(sm::scales_runtime)
62+
&& IMPLICATION(
63+
!attr()->scales_.has_default_values(), scales_ok())
6464
&& attr_scales_ok() && set_default_formats_common()
6565
&& md_dims_in_range(src_md());
6666
if (!ok) return status::unimplemented;
6767
return init_conf();
6868
}
6969

70+
bool scales_ok() const {
71+
const std::vector<int> supported_args
72+
= {DNNL_ARG_SRC, DNNL_ARG_DST};
73+
74+
const auto &scales = attr()->scales_;
75+
for (auto arg : supported_args) {
76+
auto dt = scales.get(arg).data_type_;
77+
if (!is_supported_type(dt)) { return false; }
78+
}
79+
return true;
80+
}
81+
7082
status_t init_conf();
7183
sycl_layer_normalization_conf_t conf_;
7284
};
@@ -105,10 +117,10 @@ struct ref_layer_normalization_bwd_t : public gpu::generic::sycl::primitive_t {
105117
const bool ok = !is_fwd()
106118
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
107119
&& (diff_dst_md(0)->format_desc.blocking.inner_nblks == 0)
108-
&& utils::one_of(src_md(0)->data_type, f32, bf16)
109-
&& utils::one_of(diff_dst_md(0)->data_type, f32, bf16)
110-
&& utils::one_of(diff_src_md(0)->data_type, f32, bf16)
111-
&& stat_md()->data_type == f32
120+
&& is_supported_type(src_md(0)->data_type)
121+
&& is_supported_type(diff_dst_md(0)->data_type)
122+
&& is_supported_type(diff_src_md(0)->data_type)
123+
&& is_supported_type(stat_md()->data_type)
112124
&& check_scale_shift_data_type({f32, bf16, f16})
113125
&& attr()->has_default_values()
114126
&& set_default_formats_common()

src/gpu/generic/sycl/ref_matmul.hpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,14 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
108108
}
109109

110110
bool scales_ok() const {
111-
using namespace data_type;
112111
const std::vector<int> supported_args
113112
= {DNNL_ARG_SRC_0, DNNL_ARG_WEIGHTS_0, DNNL_ARG_DST};
114113

115114
const auto &scales = attr()->scales_;
116115
bool dt_ok = true;
117116
for (auto arg : supported_args) {
118117
auto &s = scales.get(arg);
119-
dt_ok = dt_ok
120-
&& utils::one_of(s.data_type_, s8, s32, f32, f16, bf16);
118+
dt_ok = dt_ok && is_supported_type(s.data_type_);
121119
}
122120
return dt_ok && attr_scales_ok(supported_args);
123121
}

src/gpu/generic/sycl/ref_prelu.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ struct ref_prelu_fwd_t : public gpu::generic::sycl::primitive_t {
5454
const bool ok = is_fwd() && set_default_formats()
5555
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
5656
&& (weights_md(0)->format_desc.blocking.inner_nblks == 0)
57+
&& check_data_types(data_d, weights_d, dst_d)
5758
&& md_dims_in_range(src_md())
5859
&& md_dims_in_range(weights_md());
5960

@@ -63,6 +64,15 @@ struct ref_prelu_fwd_t : public gpu::generic::sycl::primitive_t {
6364

6465
status_t init_conf();
6566
sycl_prelu_conf_t conf_;
67+
68+
static bool check_data_types(const memory_desc_wrapper &src,
69+
const memory_desc_wrapper &wei,
70+
const memory_desc_wrapper &dst) {
71+
for (const auto &mdw : {src, wei, dst}) {
72+
if (!is_supported_type(mdw.data_type())) return false;
73+
}
74+
return true;
75+
}
6676
};
6777

6878
status_t init(impl::engine_t *engine) override;
@@ -97,6 +107,7 @@ struct ref_prelu_bwd_t : public gpu::generic::sycl::primitive_t {
97107
&& (weights_md(0)->format_desc.blocking.inner_nblks == 0)
98108
&& diff_src_md(0)->data_type == src_md(0)->data_type
99109
&& diff_weights_md(0)->data_type == weights_md(0)->data_type
110+
&& check_data_types(data_d, weights_d, diff_dst_d)
100111
&& md_dims_in_range(diff_src_md())
101112
&& md_dims_in_range(weights_md());
102113

@@ -113,6 +124,16 @@ struct ref_prelu_bwd_t : public gpu::generic::sycl::primitive_t {
113124
status_t init_reduction(impl::engine_t *engine);
114125
void init_scratchpad();
115126

127+
static bool check_data_types(const memory_desc_wrapper &src,
128+
const memory_desc_wrapper &wei,
129+
const memory_desc_wrapper &dst) {
130+
for (const auto &mdw : {src, wei, dst}) {
131+
if (!is_supported_type(mdw.data_type())) return false;
132+
}
133+
134+
return true;
135+
}
136+
116137
sycl_prelu_conf_t conf_;
117138
bool reduce_diff_weights_ = false;
118139
memory_desc_t scratch_md_;

src/gpu/generic/sycl/ref_reorder.hpp

+16-7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ struct ref_reorder_t : public gpu::generic::sycl::primitive_t {
5454
&& check_formats(src_d, dst_d)
5555
&& attr()->has_default_values(
5656
sm::scales_runtime | sm::post_ops)
57+
&& IMPLICATION(
58+
!attr()->scales_.has_default_values(), scales_ok())
5759
&& sycl_post_ops_t::post_ops_ok(attr())
5860
&& md_dims_in_range(dst_md());
5961
if (!ok) return status::unimplemented;
@@ -70,13 +72,8 @@ struct ref_reorder_t : public gpu::generic::sycl::primitive_t {
7072

7173
static bool check_data_types(const memory_desc_wrapper &src,
7274
const memory_desc_wrapper &dst) {
73-
using namespace data_type;
74-
75-
const auto src_dt = src.data_type();
76-
const auto dst_dt = dst.data_type();
77-
78-
for (auto t : {src_dt, dst_dt}) {
79-
if (!utils::one_of(t, f32, bf16, f16, s8, u8)) return false;
75+
for (const auto &mdw : {src, dst}) {
76+
if (!is_supported_type(mdw.data_type())) return false;
8077
}
8178

8279
return true;
@@ -91,6 +88,18 @@ struct ref_reorder_t : public gpu::generic::sycl::primitive_t {
9188
}
9289
return true;
9390
}
91+
92+
bool scales_ok() const {
93+
const std::vector<int> supported_args
94+
= {DNNL_ARG_SRC, DNNL_ARG_DST};
95+
96+
const auto &scales = attr()->scales_;
97+
for (auto arg : supported_args) {
98+
auto dt = scales.get(arg).data_type_;
99+
if (!is_supported_type(dt)) { return false; }
100+
}
101+
return true;
102+
}
94103
};
95104

96105
status_t init(impl::engine_t *engine) override;

src/gpu/generic/sycl/ref_resampling.hpp

+5-7
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,14 @@ struct ref_resampling_fwd_t : public gpu::generic::sycl::primitive_t {
4141
DECLARE_COMMON_PD_T("dpcpp:ref:any", ref_resampling_fwd_t);
4242

4343
status_t init(impl::engine_t *engine) {
44-
using namespace data_type;
4544
using namespace prop_kind;
4645
using namespace alg_kind;
4746
using sm = primitive_attr_t::skip_mask_t;
4847
const memory_desc_wrapper src_d(src_md(0));
4948
const memory_desc_wrapper dst_d(dst_md(0));
5049

51-
const bool ok = is_fwd()
52-
&& utils::one_of(
53-
src_md(0)->data_type, f32, bf16, f16, s32, s8, u8)
54-
&& utils::one_of(
55-
dst_md(0)->data_type, f32, bf16, f16, s32, s8, u8)
50+
const bool ok = is_fwd() && is_supported_type(src_md(0)->data_type)
51+
&& is_supported_type(dst_md(0)->data_type)
5652
&& attr()->has_default_values(sm::post_ops)
5753
&& set_default_params() == status::success
5854
&& attr_.set_default_formats(dst_md(0)) == status::success
@@ -92,7 +88,9 @@ struct ref_resampling_bwd_t : public gpu::generic::sycl::primitive_t {
9288
const memory_desc_wrapper diff_dst_d(diff_dst_md(0));
9389
const memory_desc_wrapper diff_src_d(diff_src_md(0));
9490

95-
bool ok = !is_fwd() && set_default_params() == status::success
91+
bool ok = !is_fwd() && is_supported_type(src_md(0)->data_type)
92+
&& is_supported_type(dst_md(0)->data_type)
93+
&& set_default_params() == status::success
9694
&& (src_md(0)->format_desc.blocking.inner_nblks == 0)
9795
&& (diff_dst_md(0)->format_desc.blocking.inner_nblks == 0)
9896
&& attr()->has_default_values()

src/gpu/generic/sycl/ref_sum.hpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -40,27 +40,34 @@ struct ref_sum_t : public gpu::generic::sycl::primitive_t {
4040
DECLARE_SUM_PD_T("dpcpp:ref:any", ref_sum_t);
4141

4242
status_t init(impl::engine_t *engine) {
43-
using namespace data_type;
4443
using namespace format_tag;
4544

4645
const memory_desc_wrapper dst_d(dst_md());
47-
if (!utils::one_of(dst_d.data_type(), f32, bf16, f16, s8, u8))
46+
if (!is_supported_type(dst_d.data_type()))
4847
return status::unimplemented;
4948
// Block formats are not yet supported
5049
// Dimensions can not be > 6
5150
if (!dst_d.is_plain() || dst_d.ndims() > xpu::sycl::md_t::max_dims)
5251
return status::unimplemented;
5352

5453
const int n = n_inputs();
54+
const auto &scales = attr()->scales_;
5555
for (auto i = 0; i < n; ++i) {
5656
const memory_desc_wrapper src_d(src_md(i));
57-
if (!utils::one_of(src_d.data_type(), f32, bf16, f16, s8, u8))
57+
if (!is_supported_type(src_d.data_type())) {
5858
return status::unimplemented;
59+
}
5960
// Block formats are not yet supported
6061
// Dimensions can not be > 6
6162
if (!src_d.is_plain()
62-
|| src_d.ndims() > xpu::sycl::md_t::max_dims)
63+
|| src_d.ndims() > xpu::sycl::md_t::max_dims) {
6364
return status::unimplemented;
65+
}
66+
if (!attr()->scales_.has_default_values()
67+
&& !is_supported_type(
68+
scales.get(DNNL_ARG_SRC + i).data_type_)) {
69+
return status::unimplemented;
70+
}
6471
}
6572

6673
const bool ok = set_default_params() == status::success

src/gpu/generic/sycl/sycl_io_helper.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ namespace gpu {
2828
namespace generic {
2929
namespace sycl {
3030

31+
inline bool is_supported_type(data_type_t dt) {
32+
using namespace data_type;
33+
return utils::one_of(dt, f32, f16, bf16, s32, s8, u8);
34+
}
35+
3136
inline int load_int_value(data_type_t dt, const void *ptr, dim_t idx) {
3237
#define CASE(dt) \
3338
case dt: \

0 commit comments

Comments
 (0)