@@ -41,31 +41,17 @@ status_t cudnn_matmul_t::execute(const exec_ctx_t &ctx) const {
     const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
     const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));

-    status_t status;
-    size_t bias_scratchpad_size
-            = 0; // To avoid extra allocation in an executor.
-
-    bool has_runtime_args = matmul_impl_->has_runtime_params();
-    if (has_runtime_args) {
-        // Initialise all runtime parameters
-        status = matmul_impl_->init_parameters(src_d, weights_d, dst_d, bias_d);
-        if (status != status::success) return status;
-
-        bias_scratchpad_size = matmul_impl_->bias_scratch_size();
-    }
-
     nvidia::stream_t *cuda_stream
             = utils::downcast<nvidia::stream_t *>(ctx.stream());

-    status = executor_->execute(
-            ctx, ctx.stream()->engine(), matmul_impl_, bias_scratchpad_size);
+    status_t status = executor_->execute(ctx, ctx.stream()->engine(),
+            matmul_impl_, pd()->params_, src_d, weights_d, dst_d, bias_d);

-    if (has_runtime_args) {
+    if (pd()->params_->has_runtime_params_) {
         auto &evts = cuda_stream->sycl_ctx().get_sycl_deps().events;
         for (auto e : evts) {
             e.wait();
         }
-        matmul_impl_->cleanup();
     }
     return status;
 }
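
Note on the hunk above: cudnn_matmul_t::execute() previously initialised runtime parameters on matmul_impl_ and tore them down with cleanup() on every call; after the change it only reads pd()->params_ and forwards the memory descriptor wrappers to the executor, so the per-call mutable phase is gone. A minimal sketch of the parameter block the primitive descriptor plausibly owns now; the struct name and the scratch-size field are guesses for illustration, only has_runtime_params_ is visible in this diff:

    #include <cstddef>

    // Hypothetical sketch inferred from the accesses in this diff; not
    // the actual oneDNN declaration.
    struct cudnn_matmul_params_t {
        bool has_runtime_params_ = false; // read by execute() above
        size_t bias_scratch_size_ = 0; // assumed: computed once at pd creation
    };

Since execute() no longer writes to matmul_impl_, the apparent point of the refactor is that concurrent executions of the same primitive stop racing on shared implementation state.
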
@@ -76,32 +62,6 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
     const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
     const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
     const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
-    const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
-
-    // To avoid extra allocation in an executor.
-    size_t algo_scratchpad_size = 0;
-    size_t bias_scratchpad_size = 0;
-    size_t block_a_scratchpad_size = 0;
-    size_t block_b_scratchpad_size = 0;
-    size_t block_c_scratchpad_size = 0;
-    size_t src_scale_scratchpad_size = 0;
-    size_t wei_scale_scratchpad_size = 0;
-
-    bool has_runtime_args = matmul_impl_->has_runtime_params();
-    if (has_runtime_args) {
-        // Initialise all runtime parameters
-        auto engine = ctx.stream()->engine();
-        CHECK(matmul_impl_->init_parameters(
-                src_d, weights_d, dst_d, bias_d, engine));
-
-        algo_scratchpad_size = matmul_impl_->algo_scratch_size();
-        bias_scratchpad_size = matmul_impl_->bias_scratch_size();
-        block_a_scratchpad_size = matmul_impl_->block_a_scratch_size();
-        block_b_scratchpad_size = matmul_impl_->block_b_scratch_size();
-        block_c_scratchpad_size = matmul_impl_->block_c_scratch_size();
-        src_scale_scratchpad_size = matmul_impl_->src_scale_size();
-        wei_scale_scratchpad_size = matmul_impl_->wei_scale_size();
-    }

     nvidia::stream_t *cuda_stream
             = utils::downcast<nvidia::stream_t *>(ctx.stream());
@@ -117,8 +77,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
             != ctx.args().end();

     if (has_src_scales
-            && (matmul_impl_->multi_src_scale()
-                    || matmul_impl_->scale_type() == CUDA_R_32I)) {
+            && (pd()->params_->multi_src_scale_
+                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
         // src scale sycl binary
         exec_args_t src_scale_binary_args;
         src_scale_binary_args[DNNL_ARG_SRC_0]
@@ -141,8 +101,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
         CHECK(src_scale_binary_->execute(binary_ctx));
     }
     if (has_wei_scales
-            && (matmul_impl_->multi_wei_scale()
-                    || matmul_impl_->scale_type() == CUDA_R_32I)) {
+            && (pd()->params_->multi_wei_scale_
+                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
         // wei scale sycl binary
         exec_args_t wei_scale_binary_args;
         wei_scale_binary_args[DNNL_ARG_SRC_0]
@@ -167,11 +127,9 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
     }

     CHECK(executor_->execute(ctx, ctx.stream()->engine(), matmul_impl_,
-            algo_scratchpad_size, bias_scratchpad_size, block_a_scratchpad_size,
-            block_b_scratchpad_size, block_c_scratchpad_size,
-            src_scale_scratchpad_size, wei_scale_scratchpad_size));
+            pd()->params_, src_d, weights_d, dst_d));

-    if (matmul_impl_->with_bias()) {
+    if (pd()->params_->with_bias_) {
         // bias sycl binary
         exec_args_t binary_args;
         std::unique_ptr<memory_t, memory_deleter_t> scratch_mem;
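
The call site above also shows the executor contract shrinking: instead of seven individually threaded scratchpad sizes, the executor now receives the shared params object plus the memory descriptor wrappers. A hedged sketch of what the new entry point might look like; every type and parameter name below is guessed from the two call sites in this diff, not copied from the real header:

    #include <memory>

    // Forward declarations standing in for the real oneDNN types.
    struct exec_ctx_t;
    struct engine_t;
    struct cublas_lt_matmul_impl_t;
    struct cublas_lt_params_t;
    struct memory_desc_wrapper;
    using status_t = int;

    // Assumed executor signature, reconstructed for illustration only.
    status_t execute(const exec_ctx_t &ctx, engine_t *engine,
            const std::shared_ptr<cublas_lt_matmul_impl_t> &matmul_impl,
            const std::shared_ptr<cublas_lt_params_t> &params,
            const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &weights_d,
            const memory_desc_wrapper &dst_d);
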
@@ -198,8 +156,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
     }

     if (has_dst_scales
-            && (matmul_impl_->multi_dst_scale()
-                    || matmul_impl_->scale_type() == CUDA_R_32I)) {
+            && (pd()->params_->multi_dst_scale_
+                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
         // dst scale sycl binary
         exec_args_t dst_scale_binary_args;
         dst_scale_binary_args[DNNL_ARG_SRC_0]
@@ -213,13 +171,11 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
         CHECK(dst_scale_binary_->execute(binary_ctx));
     }

-    if (has_runtime_args) {
+    if (pd()->params_->has_runtime_params_) {
         auto &evts = cuda_stream->sycl_ctx().get_sycl_deps().events;
         for (auto e : evts) {
             e.wait();
         }
-
-        matmul_impl_->rt_cleanup();
     }

     return status::success;
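
Taken together, the lt hunks apply one pattern: each per-execution query on matmul_impl_ (multi_src_scale(), multi_wei_scale(), multi_dst_scale(), scale_type(), with_bias()) becomes a field read on pd()->params_, and rt_cleanup() disappears along with the runtime init. A compilable sketch of the lt parameter block implied by those reads; all names are inferred from the diff rather than copied from the real header:

    #include <library_types.h> // CUDA toolkit header defining cudaDataType_t

    // Hypothetical stand-in for the params object owned by the lt
    // primitive descriptor after this change.
    struct cublas_lt_params_t {
        bool has_runtime_params_ = false;
        bool with_bias_ = false;
        bool multi_src_scale_ = false; // per-channel src scales
        bool multi_wei_scale_ = false; // per-channel weight scales
        bool multi_dst_scale_ = false; // per-channel dst scales
        cudaDataType_t acc_type_ = CUDA_R_32F; // CUDA_R_32I on the int8 path
    };

One detail worth flagging in review: scale_type() becomes acc_type_, which suggests the field always held the accumulator type rather than a scale type, making the rename a clarity fix rather than a semantic change.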