|
| 1 | +/******************************************************************************* |
| 2 | +* Copyright 2024 Intel Corporation |
| 3 | +* |
| 4 | +* Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +* you may not use this file except in compliance with the License. |
| 6 | +* You may obtain a copy of the License at |
| 7 | +* |
| 8 | +* http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +* |
| 10 | +* Unless required by applicable law or agreed to in writing, software |
| 11 | +* distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +* See the License for the specific language governing permissions and |
| 14 | +* limitations under the License. |
| 15 | +*******************************************************************************/ |
| 16 | + |
| 17 | +#ifndef GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP |
| 18 | +#define GPU_GENERIC_CONVOLUTION_DECONVOLUTION_HPP |
| 19 | + |
| 20 | +#include "common/c_types_map.hpp" |
| 21 | +#include "common/primitive.hpp" |
| 22 | +#include "common/primitive_desc_iterator.hpp" |
| 23 | +#include "common/type_helpers.hpp" |
| 24 | +#include "common/utils.hpp" |
| 25 | +#include "gpu/gpu_deconvolution_pd.hpp" |
| 26 | +#include "gpu/gpu_primitive.hpp" |
| 27 | + |
| 28 | +namespace dnnl { |
| 29 | +namespace impl { |
| 30 | +namespace gpu { |
| 31 | +namespace generic { |
| 32 | + |
| 33 | +static status_t weights_axes_permutation( |
| 34 | + memory_desc_t *o_md, const memory_desc_t *i_md, bool with_groups) { |
| 35 | + int perm[DNNL_MAX_NDIMS] {}; // deconv to conv weight permutation |
| 36 | + for (int d = 0; d < DNNL_MAX_NDIMS; ++d) |
| 37 | + perm[d] = d; |
| 38 | + nstl::swap(perm[0 + with_groups], perm[1 + with_groups]); |
| 39 | + |
| 40 | + return memory_desc_permute_axes(*o_md, *i_md, perm); |
| 41 | +} |
| 42 | + |
| 43 | +static status_t conv_descr_create( |
| 44 | + const deconvolution_desc_t *dd, convolution_desc_t *cd) { |
| 45 | + using namespace prop_kind; |
| 46 | + alg_kind_t alg_kind = alg_kind::convolution_direct; |
| 47 | + |
| 48 | + const memory_desc_t *src_md, *dst_md, *d_weights_d; |
| 49 | + prop_kind_t prop_kind; |
| 50 | + |
| 51 | + switch (dd->prop_kind) { |
| 52 | + case forward: |
| 53 | + case forward_inference: |
| 54 | + prop_kind = backward_data; |
| 55 | + src_md = &dd->dst_desc; |
| 56 | + dst_md = &dd->src_desc; |
| 57 | + d_weights_d = &dd->weights_desc; |
| 58 | + break; |
| 59 | + case backward_data: |
| 60 | + prop_kind = forward_training; |
| 61 | + src_md = &dd->diff_dst_desc; |
| 62 | + dst_md = &dd->diff_src_desc; |
| 63 | + d_weights_d = &dd->weights_desc; |
| 64 | + break; |
| 65 | + case backward_weights: |
| 66 | + prop_kind = dd->prop_kind; |
| 67 | + src_md = &dd->diff_dst_desc; |
| 68 | + dst_md = &dd->src_desc; |
| 69 | + d_weights_d = &dd->diff_weights_desc; |
| 70 | + break; |
| 71 | + default: assert(!"unknown prop kind"); return status::invalid_arguments; |
| 72 | + } |
| 73 | + |
| 74 | + // Create weights desc for convolution |
| 75 | + memory_desc_t c_weights_d; |
| 76 | + const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; |
| 77 | + CHECK(weights_axes_permutation(&c_weights_d, d_weights_d, with_groups)); |
| 78 | + |
| 79 | + return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, |
| 80 | + prop_kind != backward_weights ? &dd->bias_desc : nullptr, dst_md, |
| 81 | + dd->strides, dd->dilates, dd->padding[0], dd->padding[1]); |
| 82 | +} |
| 83 | + |
// Forward deconvolution implemented on top of a nested convolution
// primitive: deconv fwd is computed as conv backward-by-data with the src/dst
// roles swapped (see conv_descr_create()).
struct convolution_deconvolution_fwd_t : public gpu::primitive_t {
    using gpu::primitive_t::primitive_t;
    struct pd_t : public gpu_deconvolution_fwd_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(name_.c_str(), convolution_deconvolution_fwd_t);

        // Creates the pd of the underlying convolution: translates the
        // deconv descriptor, copies the attributes, and takes the first
        // convolution implementation the iterator yields.
        status_t init_convolution(impl::engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);

            return (conv_pd_) ? status::success : status::unimplemented;
        }

        status_t init(impl::engine_t *engine) {
            using namespace format_tag;
            using sm = primitive_attr_t::skip_mask_t;

            // Attributes this implementation can forward to the nested conv.
            const auto attr_skip_mask = sm::post_ops | sm::zero_points_runtime
                    | sm::scales_runtime;

            VDISPATCH_DECONVOLUTION(is_fwd(), VERBOSE_BAD_PROPKIND);
            VDISPATCH_DECONVOLUTION(
                    desc()->alg_kind == alg_kind::deconvolution_direct,
                    VERBOSE_BAD_ALGORITHM);
            VDISPATCH_DECONVOLUTION(attr()->has_default_values(attr_skip_mask),
                    VERBOSE_UNSUPPORTED_ATTR);
            // Accepted data-type combinations: all-f32, all-f64,
            // f16/f32/bf16 src+weights with f16/u8/s8 dst, bf16 src+weights
            // with f32/bf16 dst, f16 src+weights with f32/f16 dst, and s8
            // weights with u8/s8 src and any non-f64 dst.
            VDISPATCH_DECONVOLUTION(
                    (utils::everyone_is(data_type::f32,
                             desc()->src_desc.data_type,
                             desc()->weights_desc.data_type,
                             desc()->dst_desc.data_type)
                            || (utils::everyone_is(data_type::f64,
                                    desc()->src_desc.data_type,
                                    desc()->weights_desc.data_type,
                                    desc()->dst_desc.data_type))
                            || ((utils::everyone_is(data_type::f16,
                                         desc()->src_desc.data_type,
                                         desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::f32,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type)
                                        || utils::everyone_is(data_type::bf16,
                                                desc()->src_desc.data_type,
                                                desc()->weights_desc.data_type))
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f16, data_type::u8,
                                            data_type::s8))
                            || (utils::everyone_is(data_type::bf16,
                                        desc()->src_desc.data_type,
                                        desc()->weights_desc.data_type)
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f32, data_type::bf16))
                            || (utils::everyone_is(data_type::f16,
                                        desc()->src_desc.data_type,
                                        desc()->weights_desc.data_type)
                                    && utils::one_of(desc()->dst_desc.data_type,
                                            data_type::f32, data_type::f16))
                            || (desc()->weights_desc.data_type == data_type::s8
                                    && utils::one_of(desc()->src_desc.data_type,
                                            data_type::u8, data_type::s8)
                                    && desc()->dst_desc.data_type
                                            != data_type::f64)),
                    VERBOSE_UNSUPPORTED_DT);

            VDISPATCH_DECONVOLUTION_SC(
                    init_convolution(engine), "init_convolution()");
            // Propagate the formats the convolution chose back into this
            // pd's "any" descriptors. Note the role swap: the conv's
            // diff_dst/diff_src correspond to the deconv's src/dst.
            if (weights_md_.format_kind == format_kind::any) {
                VDISPATCH_DECONVOLUTION_SC(
                        weights_axes_permutation(&weights_md_,
                                conv_pd_->weights_md(), with_groups()),
                        "weights_axes_permutation()");
            }
            if (src_md_.format_kind == format_kind::any)
                src_md_ = *conv_pd_->diff_dst_md();
            if (dst_md_.format_kind == format_kind::any)
                dst_md_ = *conv_pd_->diff_src_md();
            if (bias_md_.format_kind == format_kind::any) {
                VDISPATCH_DECONVOLUTION_SC(memory_desc_init_by_tag(bias_md_, x),
                        VERBOSE_UNSUPPORTED_TAG);
            }
            init_name();
            init_scratchpad();
            VDISPATCH_DECONVOLUTION_SC(attr_.set_default_formats(dst_md(0)),
                    VERBOSE_UNSUPPORTED_ATTR);

            return status::success;
        }

        // Primitive descriptor of the nested convolution.
        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
        std::string name_ = "conv:any";

        // Verbose name is "conv:any+<nested conv impl name>".
        void init_name() {
            name_.append("+");
            name_.append(conv_pd_->name());
        }

        // Book the nested convolution's scratchpad inside this primitive's.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(impl::engine_t *engine) override {
        return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        const auto &args = ctx.args();
        exec_args_t conv_args;
        // Remap deconv args onto the bwd-data conv:
        // deconv src -> conv diff_dst, deconv dst -> conv diff_src.
        conv_args[DNNL_ARG_DIFF_DST] = args.at(DNNL_ARG_SRC);
        conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
        conv_args[DNNL_ARG_DIFF_SRC] = args.at(DNNL_ARG_DST);
        if (pd()->with_bias())
            conv_args[DNNL_ARG_BIAS] = args.at(DNNL_ARG_BIAS);

        // Forward binary/prelu post-op tensors unchanged.
        for (int idx = 0; idx < pd()->attr()->post_ops_.len(); ++idx) {
            if (pd()->attr()->post_ops_.entry_[idx].is_binary()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_SRC_1);
            } else if (pd()->attr()->post_ops_.entry_[idx].is_prelu()) {
                conv_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                        | DNNL_ARG_WEIGHTS]
                        = args.at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | DNNL_ARG_WEIGHTS);
            }
        }
        // Runtime zero points and scales are optional; forward only if the
        // user supplied them.
        const auto z_src = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC;
        const auto z_dst = DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST;
        if (args.find(z_src) != args.end()) conv_args[z_src] = args.at(z_src);
        if (args.find(z_dst) != args.end()) conv_args[z_dst] = args.at(z_dst);

        for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}) {
            int key = DNNL_ARG_ATTR_SCALES | arg;
            if (args.find(key) != args.end()) conv_args[key] = args.at(key);
        }

        // NOTE(review): unlike the bwd-data variant below, DNNL_ARG_SCRATCHPAD
        // is not forwarded here — confirm the user-mode scratchpad is not
        // needed by the nested conv on this path.
        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Executing the convolution kernel
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    // Nested convolution primitive that performs the actual computation.
    std::shared_ptr<impl::primitive_t> conv_p_;
};
| 249 | + |
// Backward-by-data deconvolution implemented via a nested *forward*
// convolution: the deconv's diff_dst plays the conv src role and the deconv's
// diff_src the conv dst role (see conv_descr_create()).
struct convolution_deconvolution_bwd_data_t : public gpu::primitive_t {
    using gpu::primitive_t::primitive_t;
    struct pd_t : public gpu_deconvolution_bwd_data_pd_t {
        pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
                const deconvolution_fwd_pd_t *hint_fwd_pd)
            : gpu_deconvolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
            , conv_pd_(nullptr) {}

        pd_t(const pd_t &other) = default;

        ~pd_t() = default;

        DECLARE_COMMON_PD_T(
                name_.c_str(), convolution_deconvolution_bwd_data_t);

        // Creates the pd of the underlying convolution: translates the
        // deconv descriptor, copies the attributes, and takes the first
        // convolution implementation the iterator yields.
        status_t init_convolution(impl::engine_t *engine) {
            convolution_desc_t cd;
            CHECK(conv_descr_create(desc(), &cd));
            primitive_attr_t conv_attr(*attr());
            if (!conv_attr.is_initialized()) return status::out_of_memory;
            primitive_desc_iterator_t it(
                    engine, (op_desc_t *)&cd, &conv_attr, nullptr);
            if (!it.is_initialized()) return status::out_of_memory;
            conv_pd_ = *(++it);
            return (conv_pd_) ? status::success : status::unimplemented;
        }

        status_t init(impl::engine_t *engine) {
            VDISPATCH_DECONVOLUTION(
                    desc()->prop_kind == prop_kind::backward_data,
                    VERBOSE_BAD_PROPKIND);

            // Accepted data-type combinations: all-f32, all-f64, or
            // f16/bf16 weights+diff_dst (diff_src checked separately below).
            VDISPATCH_DECONVOLUTION(
                    (utils::everyone_is(data_type::f32,
                             desc()->diff_src_desc.data_type,
                             desc()->weights_desc.data_type,
                             desc()->diff_dst_desc.data_type)
                            || (utils::everyone_is(data_type::f64,
                                    desc()->diff_src_desc.data_type,
                                    desc()->weights_desc.data_type,
                                    desc()->diff_dst_desc.data_type))
                            || utils::everyone_is(data_type::f16,
                                    desc()->weights_desc.data_type,
                                    desc()->diff_dst_desc.data_type)
                            || utils::everyone_is(data_type::bf16,
                                    desc()->weights_desc.data_type,
                                    desc()->diff_dst_desc.data_type)),
                    VERBOSE_UNSUPPORTED_DT);

            VDISPATCH_DECONVOLUTION(
                    utils::one_of(desc()->diff_src_desc.data_type,
                            data_type::bf16, data_type::f16, data_type::f32,
                            data_type::f64),
                    VERBOSE_UNSUPPORTED_DT);
            VDISPATCH_DECONVOLUTION(
                    desc()->alg_kind == alg_kind::deconvolution_direct,
                    VERBOSE_BAD_ALGORITHM);
            // Unlike the fwd implementation, no attributes (post-ops, scales,
            // zero points) are supported on the bwd-data path.
            VDISPATCH_DECONVOLUTION(
                    attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);

            VDISPATCH_DECONVOLUTION_SC(
                    init_convolution(engine), "init_convolution()");
            // Propagate the formats the convolution chose back into this
            // pd's "any" descriptors. Note the role swap: the conv's
            // dst/src correspond to the deconv's diff_src/diff_dst.
            if (weights_md_.format_kind == format_kind::any)
                VDISPATCH_DECONVOLUTION_SC(
                        weights_axes_permutation(&weights_md_,
                                conv_pd_->weights_md(), with_groups()),
                        "weights_axes_permutation()");
            if (diff_src_md_.format_kind == format_kind::any)
                diff_src_md_ = *conv_pd_->dst_md();
            if (diff_dst_md_.format_kind == format_kind::any)
                diff_dst_md_ = *conv_pd_->src_md();

            init_name();
            init_scratchpad();

            return status::success;
        }

        // Primitive descriptor of the nested convolution.
        std::shared_ptr<primitive_desc_t> conv_pd_;

    private:
        std::string name_ = "conv:any";

        // Verbose name is "conv:any+<nested conv impl name>".
        void init_name() {
            name_.append("+");
            name_.append(conv_pd_->name());
        }

        // Book the nested convolution's scratchpad inside this primitive's.
        void init_scratchpad() {
            auto scratchpad = scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_nested,
                    conv_pd_->scratchpad_registry());
        }
    };

    status_t init(impl::engine_t *engine) override {
        return create_nested_primitive(conv_p_, pd()->conv_pd_, engine);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        const auto &args = ctx.args();
        exec_args_t conv_args;
        // Remap deconv args onto the fwd conv:
        // deconv diff_dst -> conv src, deconv diff_src -> conv dst.
        conv_args[DNNL_ARG_SRC] = args.at(DNNL_ARG_DIFF_DST);
        conv_args[DNNL_ARG_WEIGHTS] = args.at(DNNL_ARG_WEIGHTS);
        conv_args[DNNL_ARG_DST] = args.at(DNNL_ARG_DIFF_SRC);
        // Forward the user-mode scratchpad only when one was requested.
        if (!types::is_zero_md(pd()->scratchpad_md()))
            conv_args[DNNL_ARG_SCRATCHPAD] = args.at(DNNL_ARG_SCRATCHPAD);
        exec_ctx_t conv_ctx(ctx, std::move(conv_args));

        nested_scratchpad_t ns(ctx, key_nested, conv_p_);
        conv_ctx.set_scratchpad_grantor(ns.grantor());
        // Executing the convolution kernel
        return conv_p_->execute(conv_ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    // Nested convolution primitive that performs the actual computation.
    std::shared_ptr<impl::primitive_t> conv_p_;
};
| 370 | + |
| 371 | +} // namespace generic |
| 372 | +} // namespace gpu |
| 373 | +} // namespace impl |
| 374 | +} // namespace dnnl |
| 375 | + |
| 376 | +#endif |
0 commit comments