forked from uxlfoundation/oneDNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathref_deconvolution.hpp
124 lines (103 loc) · 4.58 KB
/
ref_deconvolution.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef GPU_SYCL_REF_DECONVOLUTION_HPP
#define GPU_SYCL_REF_DECONVOLUTION_HPP
#include "gpu/generic/sycl/ref_convolution.hpp"
#include "gpu/generic/sycl/sycl_gpu_primitive.hpp"
#include "gpu/generic/sycl/sycl_io_helper.hpp"
#include "gpu/generic/sycl/sycl_post_ops.hpp"
#include "gpu/generic/sycl/sycl_primitive_conf.hpp"
#include "gpu/generic/sycl/sycl_q10n.hpp"
#include "gpu/generic/sycl/sycl_utils.hpp"
#include "gpu/gpu_deconvolution_pd.hpp"
#include "xpu/sycl/types.hpp"
namespace dnnl {
namespace impl {
namespace gpu {
namespace generic {
namespace sycl {
struct ref_deconvolution_bwd_weights_t
: public gpu::generic::sycl::primitive_t {
using gpu::generic::sycl::primitive_t::primitive_t;
struct pd_t : public deconvolution_bwd_weights_pd_t {
using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t;
DECLARE_COMMON_PD_T("dpcpp:ref:any", ref_deconvolution_bwd_weights_t);
status_t init(impl::engine_t *engine) {
using namespace data_type;
const memory_desc_wrapper data_d(src_md());
const memory_desc_wrapper diff_weights_d(diff_weights_md());
const memory_desc_wrapper diff_dst_d(diff_dst_md());
const bool ok = desc()->prop_kind == prop_kind::backward_weights
&& check_convolution_work_amount(diff_weights_d, OC())
&& md_dims_in_range(src_md()) && set_default_formats()
&& check_convolution_data_types(
data_d, diff_weights_d, diff_dst_d)
&& check_convolution_formats(
data_d, diff_weights_d, diff_dst_d)
&& attr()->has_default_values()
&& desc()->alg_kind == alg_kind::deconvolution_direct;
if (!ok) return status::unimplemented;
return init_conf();
}
sycl_convolution_conf_t conf_;
private:
status_t init_conf();
bool set_default_formats_common_template(memory_desc_t &src_md,
format_tag_t src_tag, memory_desc_t &wei_md,
format_tag_t wei_tag, memory_desc_t &dst_md,
format_tag_t dst_tag, memory_desc_t &bia_md) {
using namespace format_tag;
#define IS_OK(f) \
do { \
if ((f) != status::success) return false; \
} while (0)
if (src_md.format_kind == format_kind::any
&& !utils::one_of(src_tag, any, undef))
IS_OK(memory_desc_init_by_tag(src_md, src_tag));
if (dst_md.format_kind == format_kind::any
&& !utils::one_of(dst_tag, any, undef))
IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
if (wei_md.format_kind == format_kind::any
&& !utils::one_of(wei_tag, any, undef))
IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
if (with_bias() && bia_md.format_kind == format_kind::any)
IS_OK(memory_desc_init_by_tag(bia_md, x));
#undef IS_OK
return true;
}
bool set_default_formats() {
using namespace format_tag;
auto dat_tag = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
auto wei_tag = with_groups()
? utils::pick(ndims() - 3, goiw, goihw, goidhw)
: utils::pick(ndims() - 3, oiw, oihw, oidhw);
return set_default_formats_common_template(src_md_, dat_tag,
diff_weights_md_, wei_tag, diff_dst_md_, dat_tag,
diff_bias_md_);
}
};
status_t init(impl::engine_t *engine) override;
status_t execute(const exec_ctx_t &ctx) const override;
private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
kernel_t kernel_;
};
} // namespace sycl
} // namespace generic
} // namespace gpu
} // namespace impl
} // namespace dnnl
#endif