|
1 | 1 | /*******************************************************************************
|
2 | 2 | * Copyright 2021 Intel Corporation
|
3 |
| -* Copyright 2021 FUJITSU LIMITED |
| 3 | +* Copyright 2021-2024 FUJITSU LIMITED |
4 | 4 | *
|
5 | 5 | * Licensed under the Apache License, Version 2.0 (the "License");
|
6 | 6 | * you may not use this file except in compliance with the License.
|
@@ -95,180 +95,206 @@ struct jit_uni_dw_convolution_fwd_t : public primitive_t {
|
95 | 95 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
|
96 | 96 |
|
97 | 97 | std::unique_ptr<jit_uni_dw_conv_fwd_kernel<isa, src_type>> kernel_;
|
98 |
| -}; |
99 |
| - |
100 |
| -using jit_sve_512_dw_convolution_fwd_t |
101 |
| - = jit_uni_dw_convolution_fwd_t<sve_512, data_type::f32>; |
102 |
| - |
103 |
| -template <cpu_isa_t isa, data_type_t diff_dst_type, |
104 |
| - data_type_t diff_src_type = diff_dst_type> |
105 |
| -struct jit_uni_dw_convolution_bwd_data_t : public primitive_t { |
106 |
| - struct pd_t : public cpu_convolution_bwd_data_pd_t { |
107 |
| - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
108 |
| - const convolution_fwd_pd_t *hint_fwd_pd) |
109 |
| - : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} |
110 |
| - |
111 |
| - DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""), |
112 |
| - jit_uni_dw_convolution_bwd_data_t); |
113 |
| - |
114 |
| - status_t init(engine_t *engine) { |
115 |
| - bool ok = true && desc()->prop_kind == prop_kind::backward_data |
116 |
| - && set_default_alg_kind(alg_kind::convolution_direct) |
117 |
| - && expect_data_types(diff_src_type, diff_dst_type, |
118 |
| - data_type::undef, diff_dst_type, data_type::f32) |
119 |
| - && attr()->has_default_values() && !has_zero_dim_memory() |
120 |
| - && set_default_formats(); |
121 |
| - |
122 |
| - if (!ok) return status::unimplemented; |
123 |
| - |
124 |
| - status_t status = jit_uni_dw_conv_bwd_data_kernel<isa, |
125 |
| - diff_dst_type>::init_conf(jcp_, *desc(), *diff_src_md(), |
126 |
| - *weights_md(), *diff_dst_md()); |
127 |
| - if (status != status::success) return status; |
128 |
| - |
129 |
| - auto scratchpad = scratchpad_registry().registrar(); |
130 |
| - jit_uni_dw_conv_bwd_data_kernel<isa, |
131 |
| - diff_dst_type>::init_scratchpad(scratchpad, jcp_); |
| 98 | + using jit_sve_512_dw_convolution_fwd_t |
| 99 | + = jit_uni_dw_convolution_fwd_t<sve_512, data_type::f32>; |
| 100 | + using jit_sve_256_dw_convolution_fwd_t |
| 101 | + = jit_uni_dw_convolution_fwd_t<sve_256, data_type::f32>; |
| 102 | + |
| 103 | + template <cpu_isa_t isa, data_type_t diff_dst_type, |
| 104 | + data_type_t diff_src_type = diff_dst_type> |
| 105 | + struct jit_uni_dw_convolution_bwd_data_t : public primitive_t { |
| 106 | + struct pd_t : public cpu_convolution_bwd_data_pd_t { |
| 107 | + pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
| 108 | + const convolution_fwd_pd_t *hint_fwd_pd) |
| 109 | + : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) |
| 110 | + , jcp_() {} |
| 111 | + |
| 112 | + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""), |
| 113 | + jit_uni_dw_convolution_bwd_data_t); |
| 114 | + |
| 115 | + status_t init(engine_t *engine) { |
| 116 | + bool ok = true && desc()->prop_kind == prop_kind::backward_data |
| 117 | + && set_default_alg_kind(alg_kind::convolution_direct) |
| 118 | + && expect_data_types(diff_src_type, diff_dst_type, |
| 119 | + data_type::undef, diff_dst_type, data_type::f32) |
| 120 | + && attr()->has_default_values() |
| 121 | + && !has_zero_dim_memory() && set_default_formats(); |
| 122 | + |
| 123 | + if (!ok) return status::unimplemented; |
| 124 | + |
| 125 | + status_t status = jit_uni_dw_conv_bwd_data_kernel<isa, |
| 126 | + diff_dst_type>::init_conf(jcp_, *desc(), *diff_src_md(), |
| 127 | + *weights_md(), *diff_dst_md()); |
| 128 | + if (status != status::success) return status; |
| 129 | + |
| 130 | + auto scratchpad = scratchpad_registry().registrar(); |
| 131 | + jit_uni_dw_conv_bwd_data_kernel<isa, |
| 132 | + diff_dst_type>::init_scratchpad(scratchpad, jcp_); |
| 133 | + |
| 134 | + return status::success; |
| 135 | + } |
| 136 | + |
| 137 | + jit_conv_conf_t jcp_; |
| 138 | + |
| 139 | + protected: |
| 140 | + bool set_default_formats() { |
| 141 | + |
| 142 | + using namespace format_tag; |
| 143 | + format_tag_t dat_tag, wei_tag; |
| 144 | + switch (isa) { |
| 145 | + case sve_512: |
| 146 | + dat_tag = nChw16c; |
| 147 | + wei_tag = Goihw16g; |
| 148 | + break; |
| 149 | + case sve_256: |
| 150 | + dat_tag = nChw8c; |
| 151 | + wei_tag = Goihw8g; |
| 152 | + break; |
| 153 | + default: return false; |
| 154 | + } |
| 155 | + return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
| 156 | + } |
| 157 | + }; |
| 158 | + |
| 159 | + jit_uni_dw_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
| 160 | + |
| 161 | + typedef typename prec_traits<diff_src_type>::type diff_src_data_t; |
| 162 | + typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; |
| 163 | + typedef typename prec_traits<diff_dst_type>::type wei_data_t; |
| 164 | + |
| 165 | + status_t init(engine_t *engine) override { |
| 166 | + CHECK(safe_ptr_assign(kernel_, |
| 167 | + new jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>( |
| 168 | + pd()->jcp_))); |
| 169 | + return kernel_->create_kernel(); |
| 170 | + } |
132 | 171 |
|
| 172 | + status_t execute(const exec_ctx_t &ctx) const override { |
| 173 | + execute_backward_data(ctx); |
133 | 174 | return status::success;
|
134 | 175 | }
|
135 | 176 |
|
136 |
| - jit_conv_conf_t jcp_; |
137 |
| - |
138 |
| - protected: |
139 |
| - bool set_default_formats() { |
140 |
| - using namespace format_tag; |
| 177 | + private: |
| 178 | + void execute_backward_data(const exec_ctx_t &ctx) const; |
| 179 | + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
141 | 180 |
|
142 |
| - auto dat_tag = nChw16c; |
143 |
| - auto wei_tag = Goihw16g; |
144 |
| - |
145 |
| - return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
146 |
| - } |
| 181 | + std::unique_ptr<jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>> |
| 182 | + kernel_; |
147 | 183 | };
|
148 | 184 |
|
149 |
| - jit_uni_dw_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
150 |
| - |
151 |
| - typedef typename prec_traits<diff_src_type>::type diff_src_data_t; |
152 |
| - typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; |
153 |
| - typedef typename prec_traits<diff_dst_type>::type wei_data_t; |
154 |
| - |
155 |
| - status_t init(engine_t *engine) override { |
156 |
| - CHECK(safe_ptr_assign(kernel_, |
157 |
| - new jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>( |
158 |
| - pd()->jcp_))); |
159 |
| - return kernel_->create_kernel(); |
160 |
| - } |
161 |
| - |
162 |
| - status_t execute(const exec_ctx_t &ctx) const override { |
163 |
| - execute_backward_data(ctx); |
164 |
| - return status::success; |
165 |
| - } |
166 |
| - |
167 |
| -private: |
168 |
| - void execute_backward_data(const exec_ctx_t &ctx) const; |
169 |
| - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
170 |
| - |
171 |
| - std::unique_ptr<jit_uni_dw_conv_bwd_data_kernel<isa, diff_dst_type>> |
172 |
| - kernel_; |
173 |
| -}; |
174 |
| - |
175 |
| -using jit_sve_512_dw_convolution_bwd_data_t |
176 |
| - = jit_uni_dw_convolution_bwd_data_t<sve_512, data_type::f32>; |
177 |
| - |
178 |
| -template <cpu_isa_t isa, data_type_t src_type, |
179 |
| - data_type_t diff_weights_type = src_type> |
180 |
| -struct jit_uni_dw_convolution_bwd_weights_t : public primitive_t { |
181 |
| - struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
182 |
| - pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
183 |
| - const convolution_fwd_pd_t *hint_fwd_pd) |
184 |
| - : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
185 |
| - , jcp_() {} |
186 |
| - using jit_uni_dw_convolution_bwd_weights |
187 |
| - = jit_uni_dw_convolution_bwd_weights_t<isa, src_type, |
188 |
| - diff_weights_type>; |
189 |
| - DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""), |
190 |
| - jit_uni_dw_convolution_bwd_weights); |
191 |
| - |
192 |
| - status_t init(engine_t *engine) { |
193 |
| - bool ok = true && desc()->prop_kind == prop_kind::backward_weights |
194 |
| - && set_default_alg_kind(alg_kind::convolution_direct) |
195 |
| - && expect_data_types(src_type, diff_weights_type, |
196 |
| - data_type::undef, src_type, data_type::f32) |
197 |
| - && IMPLICATION(this->with_bias(), |
198 |
| - utils::one_of( |
199 |
| - this->desc()->diff_bias_desc.data_type, |
200 |
| - data_type::f32, data_type::bf16)) |
201 |
| - && attr()->has_default_values() && !has_zero_dim_memory() |
202 |
| - && set_default_formats(); |
203 |
| - if (!ok) return status::unimplemented; |
204 |
| - |
205 |
| - const int max_threads |
206 |
| - = dnnl_in_parallel() ? 1 : dnnl_get_max_threads(); |
207 |
| - |
208 |
| - status_t status = jit_uni_dw_conv_bwd_weights_kernel<isa, |
209 |
| - src_type>::init_conf(jcp_, *desc(), *src_md(), |
210 |
| - *diff_weights_md(), *diff_dst_md(), max_threads); |
211 |
| - if (status != status::success) return status; |
212 |
| - |
213 |
| - auto scratchpad = scratchpad_registry().registrar(); |
214 |
| - jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>::init_scratchpad( |
215 |
| - scratchpad, jcp_); |
216 |
| - |
| 185 | + using jit_sve_512_dw_convolution_bwd_data_t |
| 186 | + = jit_uni_dw_convolution_bwd_data_t<sve_512, data_type::f32>; |
| 187 | + using jit_sve_256_dw_convolution_bwd_data_t |
| 188 | + = jit_uni_dw_convolution_bwd_data_t<sve_256, data_type::f32>; |
| 189 | + |
| 190 | + template <cpu_isa_t isa, data_type_t src_type, |
| 191 | + data_type_t diff_weights_type = src_type> |
| 192 | + struct jit_uni_dw_convolution_bwd_weights_t : public primitive_t { |
| 193 | + struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
| 194 | + pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
| 195 | + const convolution_fwd_pd_t *hint_fwd_pd) |
| 196 | + : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
| 197 | + , jcp_() {} |
| 198 | + using jit_uni_dw_convolution_bwd_weights |
| 199 | + = jit_uni_dw_convolution_bwd_weights_t<isa, src_type, |
| 200 | + diff_weights_type>; |
| 201 | + DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_dw:", jcp_.isa, ""), |
| 202 | + jit_uni_dw_convolution_bwd_weights); |
| 203 | + |
| 204 | + status_t init(engine_t *engine) { |
| 205 | + bool ok = true |
| 206 | + && desc()->prop_kind == prop_kind::backward_weights |
| 207 | + && set_default_alg_kind(alg_kind::convolution_direct) |
| 208 | + && expect_data_types(src_type, diff_weights_type, |
| 209 | + data_type::undef, src_type, data_type::f32) |
| 210 | + && IMPLICATION(this->with_bias(), |
| 211 | + utils::one_of( |
| 212 | + this->desc()->diff_bias_desc.data_type, |
| 213 | + data_type::f32, data_type::bf16)) |
| 214 | + && attr()->has_default_values() |
| 215 | + && !has_zero_dim_memory() && set_default_formats(); |
| 216 | + if (!ok) return status::unimplemented; |
| 217 | + |
| 218 | + const int max_threads |
| 219 | + = dnnl_in_parallel() ? 1 : dnnl_get_max_threads(); |
| 220 | + |
| 221 | + status_t status = jit_uni_dw_conv_bwd_weights_kernel<isa, |
| 222 | + src_type>::init_conf(jcp_, *desc(), *src_md(), |
| 223 | + *diff_weights_md(), *diff_dst_md(), max_threads); |
| 224 | + if (status != status::success) return status; |
| 225 | + |
| 226 | + auto scratchpad = scratchpad_registry().registrar(); |
| 227 | + jit_uni_dw_conv_bwd_weights_kernel<isa, |
| 228 | + src_type>::init_scratchpad(scratchpad, jcp_); |
| 229 | + |
| 230 | + return status::success; |
| 231 | + } |
| 232 | + |
| 233 | + jit_conv_conf_t jcp_; |
| 234 | + |
| 235 | + protected: |
| 236 | + bool set_default_formats() { |
| 237 | + using namespace format_tag; |
| 238 | + format_tag_t dat_tag, wei_tag; |
| 239 | + switch (isa) { |
| 240 | + case sve_512: |
| 241 | + dat_tag = nChw16c; |
| 242 | + wei_tag = Goihw16g; |
| 243 | + break; |
| 244 | + case sve_256: |
| 245 | + dat_tag = nChw8c; |
| 246 | + wei_tag = Goihw8g; |
| 247 | + break; |
| 248 | + default: return false; |
| 249 | + } |
| 250 | + |
| 251 | + return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
| 252 | + } |
| 253 | + }; |
| 254 | + |
| 255 | + jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); |
| 256 | + |
| 257 | + typedef typename prec_traits<data_type::bf16>::type bf16_data_t; |
| 258 | + typedef typename prec_traits<data_type::f32>::type f32_data_t; |
| 259 | + typedef typename prec_traits<src_type>::type src_data_t; |
| 260 | + typedef typename prec_traits<src_type>::type diff_dst_data_t; |
| 261 | + typedef typename prec_traits<diff_weights_type>::type |
| 262 | + diff_weights_data_t; |
| 263 | + |
| 264 | + status_t init(engine_t *engine) override { |
| 265 | + CHECK(safe_ptr_assign(kernel_, |
| 266 | + new jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>( |
| 267 | + pd()->jcp_))); |
| 268 | + CHECK(kernel_->create_kernel()); |
| 269 | + |
| 270 | + if (pd()->jcp_.nthr_mb > 1) { |
| 271 | + CHECK(safe_ptr_assign(acc_ker_, |
| 272 | + new cpu_accumulator_1d_t<data_type::f32, isa>())); |
| 273 | + CHECK(acc_ker_->create_kernel()); |
| 274 | + } |
217 | 275 | return status::success;
|
218 | 276 | }
|
219 | 277 |
|
220 |
| - jit_conv_conf_t jcp_; |
221 |
| - |
222 |
| - protected: |
223 |
| - bool set_default_formats() { |
224 |
| - using namespace format_tag; |
225 |
| - |
226 |
| - auto dat_tag = isa == sve_512 ? nChw16c : nChw8c; |
227 |
| - auto wei_tag = isa == sve_512 ? Goihw16g : Goihw8g; |
228 |
| - |
229 |
| - return set_default_formats_common(dat_tag, wei_tag, dat_tag); |
230 |
| - } |
231 |
| - }; |
232 |
| - |
233 |
| - jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd); |
234 |
| - |
235 |
| - typedef typename prec_traits<data_type::bf16>::type bf16_data_t; |
236 |
| - typedef typename prec_traits<data_type::f32>::type f32_data_t; |
237 |
| - typedef typename prec_traits<src_type>::type src_data_t; |
238 |
| - typedef typename prec_traits<src_type>::type diff_dst_data_t; |
239 |
| - typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t; |
240 |
| - |
241 |
| - status_t init(engine_t *engine) override { |
242 |
| - CHECK(safe_ptr_assign(kernel_, |
243 |
| - new jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>( |
244 |
| - pd()->jcp_))); |
245 |
| - CHECK(kernel_->create_kernel()); |
246 |
| - |
247 |
| - if (pd()->jcp_.nthr_mb > 1) { |
248 |
| - CHECK(safe_ptr_assign( |
249 |
| - acc_ker_, new cpu_accumulator_1d_t<data_type::f32>())); |
250 |
| - CHECK(acc_ker_->create_kernel()); |
| 278 | + status_t execute(const exec_ctx_t &ctx) const override { |
| 279 | + execute_backward_weights(ctx); |
| 280 | + execute_reduction(ctx); |
| 281 | + return status::success; |
251 | 282 | }
|
252 |
| - return status::success; |
253 |
| - } |
254 | 283 |
|
255 |
| - status_t execute(const exec_ctx_t &ctx) const override { |
256 |
| - execute_backward_weights(ctx); |
257 |
| - execute_reduction(ctx); |
258 |
| - return status::success; |
259 |
| - } |
| 284 | + private: |
| 285 | + void execute_backward_weights(const exec_ctx_t &ctx) const; |
| 286 | + void execute_reduction(const exec_ctx_t &ctx) const; |
| 287 | + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
260 | 288 |
|
261 |
| -private: |
262 |
| - void execute_backward_weights(const exec_ctx_t &ctx) const; |
263 |
| - void execute_reduction(const exec_ctx_t &ctx) const; |
264 |
| - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
265 |
| - |
266 |
| - std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_; |
267 |
| - std::unique_ptr<jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>> kernel_; |
268 |
| -}; |
| 289 | + std::unique_ptr<cpu_accumulator_1d_t<data_type::f32, isa>> acc_ker_; |
| 290 | + std::unique_ptr<jit_uni_dw_conv_bwd_weights_kernel<isa, src_type>> |
| 291 | + kernel_; |
| 292 | + }; |
269 | 293 |
|
270 |
| -using jit_sve_512_dw_convolution_bwd_weights_t |
271 |
| - = jit_uni_dw_convolution_bwd_weights_t<sve_512, data_type::f32>; |
| 294 | + using jit_sve_512_dw_convolution_bwd_weights_t |
| 295 | + = jit_uni_dw_convolution_bwd_weights_t<sve_512, data_type::f32>; |
| 296 | + using jit_sve_256_dw_convolution_bwd_weights_t |
| 297 | + = jit_uni_dw_convolution_bwd_weights_t<sve_256, data_type::f32>; |
272 | 298 | } // namespace aarch64
|
273 | 299 | } // namespace cpu
|
274 | 300 | } // namespace impl
|
|
0 commit comments