@@ -32,22 +32,16 @@ namespace gpu {
32
32
namespace generic {
33
33
namespace sycl {
34
34
35
- static bool check_convolution_data_types (const memory_desc_wrapper &src0,
35
+ inline bool check_convolution_data_types (const memory_desc_wrapper &src0,
36
36
const memory_desc_wrapper &src1, const memory_desc_wrapper &dst) {
37
- using namespace data_type ;
38
-
39
- const auto src0_dt = src0.data_type ();
40
- const auto src1_dt = src1.data_type ();
41
- const auto dst_dt = dst.data_type ();
42
-
43
- for (auto t : {src0_dt, src1_dt, dst_dt}) {
44
- if (!utils::one_of (t, f32, bf16, f16, s32, s8, u8)) return false ;
37
+ for (const auto &mdw : {src0, src1, dst}) {
38
+ if (!is_supported_type (mdw.data_type ())) return false ;
45
39
}
46
40
47
41
return true ;
48
42
}
49
43
50
- static bool check_convolution_formats (const memory_desc_wrapper &src0,
44
+ inline bool check_convolution_formats (const memory_desc_wrapper &src0,
51
45
const memory_desc_wrapper &src1, const memory_desc_wrapper &dst) {
52
46
using namespace format_tag ;
53
47
@@ -57,7 +51,7 @@ static bool check_convolution_formats(const memory_desc_wrapper &src0,
57
51
return true ;
58
52
}
59
53
60
- static bool check_convolution_work_amount (
54
+ inline bool check_convolution_work_amount (
61
55
const memory_desc_wrapper &weights, dim_t OC) {
62
56
auto elems = weights.nelems ();
63
57
auto work_per_output = elems / OC;
@@ -66,6 +60,18 @@ static bool check_convolution_work_amount(
66
60
return work_per_output < 200000 ;
67
61
}
68
62
63
+ inline bool check_convolution_scales_types (const primitive_attr_t *attr) {
64
+ const std::vector<int > supported_args
65
+ = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
66
+
67
+ const auto &scales = attr->scales_ ;
68
+ for (auto arg : supported_args) {
69
+ auto dt = scales.get (arg).data_type_ ;
70
+ if (!is_supported_type (dt)) { return false ; }
71
+ }
72
+ return true ;
73
+ }
74
+
69
75
struct ref_convolution_fwd_t : public gpu ::generic::sycl::primitive_t {
70
76
using gpu::generic::sycl::primitive_t ::primitive_t ;
71
77
@@ -92,7 +98,8 @@ struct ref_convolution_fwd_t : public gpu::generic::sycl::primitive_t {
92
98
| sm::zero_points_runtime | sm::post_ops
93
99
| sm::sum_dt)
94
100
&& IMPLICATION (!attr ()->scales_ .has_default_values (),
95
- attr_scales_ok ())
101
+ attr_scales_ok ()
102
+ && check_convolution_scales_types (attr ()))
96
103
&& sycl_post_ops_t::post_ops_ok (attr (), false );
97
104
if (!ok) return status::unimplemented;
98
105
@@ -148,7 +155,8 @@ struct ref_convolution_bwd_data_t : public gpu::generic::sycl::primitive_t {
148
155
&& attr ()->has_default_values (sm::scales_runtime
149
156
| sm::zero_points_runtime | sm::sum_dt)
150
157
&& IMPLICATION (!attr ()->scales_ .has_default_values (),
151
- attr_scales_ok ());
158
+ attr_scales_ok ()
159
+ && check_convolution_scales_types (attr ()));
152
160
if (!ok) return status::unimplemented;
153
161
154
162
return init_conf ();
@@ -203,7 +211,8 @@ struct ref_convolution_bwd_weights_t : public gpu::generic::sycl::primitive_t {
203
211
&& attr ()->has_default_values (sm::scales_runtime
204
212
| sm::zero_points_runtime | sm::sum_dt)
205
213
&& IMPLICATION (!attr ()->scales_ .has_default_values (),
206
- attr_scales_ok ());
214
+ attr_scales_ok ()
215
+ && check_convolution_scales_types (attr ()));
207
216
if (!ok) return status::unimplemented;
208
217
209
218
return init_conf ();
0 commit comments