Commit 42be8d5 (1 parent: 7f5e066)

gpu: nvidia: Refactor to native parameters for matmul (#2111)

10 files changed: +801 -774 lines
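
This commit moves the cuDNN/cuBLAS matmul configuration out of the mutable matmul_impl_ object and into a cublas_params descriptor that the primitive descriptor (pd) owns and shares with the executors, so execute() no longer initializes or cleans up parameters on every call. A minimal sketch of that ownership pattern, assuming simplified stand-in names (params_t, executor_t, primitive_t here are illustrative, not the actual oneDNN classes):

#include <memory>

// Hypothetical stand-ins for cublas_params, the executor, and the primitive;
// simplified to show the ownership pattern only, not the real oneDNN API.
struct params_t {
    bool has_runtime_params = false; // true if shapes arrive at execute() time
    bool with_bias = false;
    // ... matmul configuration derived from the memory descriptors ...
};

struct executor_t {
    // The executor receives the shared, pre-computed parameters on each call
    // instead of reading state that execute() previously mutated on the impl.
    int execute(const std::shared_ptr<const params_t> &params) {
        // ... set up and launch the GEMM according to *params ...
        return 0;
    }
};

struct primitive_t {
    // Built once at primitive-descriptor creation and shared; execute() never
    // mutates it, which keeps concurrent executions of one primitive safe.
    std::shared_ptr<const params_t> params_ = std::make_shared<params_t>();
    executor_t executor_;

    int execute() const { return executor_.execute(params_); }
};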

src/gpu/nvidia/cudnn_matmul.cpp (+12 -56)
@@ -41,31 +41,17 @@ status_t cudnn_matmul_t::execute(const exec_ctx_t &ctx) const {
     const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
     const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
 
-    status_t status;
-    size_t bias_scratchpad_size
-            = 0; // To avoid extra allocation in an executor.
-
-    bool has_runtime_args = matmul_impl_->has_runtime_params();
-    if (has_runtime_args) {
-        // Initialise all runtime parameters
-        status = matmul_impl_->init_parameters(src_d, weights_d, dst_d, bias_d);
-        if (status != status::success) return status;
-
-        bias_scratchpad_size = matmul_impl_->bias_scratch_size();
-    }
-
     nvidia::stream_t *cuda_stream
             = utils::downcast<nvidia::stream_t *>(ctx.stream());
 
-    status = executor_->execute(
-            ctx, ctx.stream()->engine(), matmul_impl_, bias_scratchpad_size);
+    status_t status = executor_->execute(ctx, ctx.stream()->engine(),
+            matmul_impl_, pd()->params_, src_d, weights_d, dst_d, bias_d);
 
-    if (has_runtime_args) {
+    if (pd()->params_->has_runtime_params_) {
         auto &evts = cuda_stream->sycl_ctx().get_sycl_deps().events;
         for (auto e : evts) {
             e.wait();
         }
-        matmul_impl_->cleanup();
     }
     return status;
 }
@@ -76,32 +62,6 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
     const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
     const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
     const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
-    const auto bias_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
-
-    // To avoid extra allocation in an executor.
-    size_t algo_scratchpad_size = 0;
-    size_t bias_scratchpad_size = 0;
-    size_t block_a_scratchpad_size = 0;
-    size_t block_b_scratchpad_size = 0;
-    size_t block_c_scratchpad_size = 0;
-    size_t src_scale_scratchpad_size = 0;
-    size_t wei_scale_scratchpad_size = 0;
-
-    bool has_runtime_args = matmul_impl_->has_runtime_params();
-    if (has_runtime_args) {
-        // Initialise all runtime parameters
-        auto engine = ctx.stream()->engine();
-        CHECK(matmul_impl_->init_parameters(
-                src_d, weights_d, dst_d, bias_d, engine));
-
-        algo_scratchpad_size = matmul_impl_->algo_scratch_size();
-        bias_scratchpad_size = matmul_impl_->bias_scratch_size();
-        block_a_scratchpad_size = matmul_impl_->block_a_scratch_size();
-        block_b_scratchpad_size = matmul_impl_->block_b_scratch_size();
-        block_c_scratchpad_size = matmul_impl_->block_c_scratch_size();
-        src_scale_scratchpad_size = matmul_impl_->src_scale_size();
-        wei_scale_scratchpad_size = matmul_impl_->wei_scale_size();
-    }
 
     nvidia::stream_t *cuda_stream
             = utils::downcast<nvidia::stream_t *>(ctx.stream());
@@ -117,8 +77,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
             != ctx.args().end();
 
     if (has_src_scales
-            && (matmul_impl_->multi_src_scale()
-                    || matmul_impl_->scale_type() == CUDA_R_32I)) {
+            && (pd()->params_->multi_src_scale_
+                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
         // src scale sycl binary
         exec_args_t src_scale_binary_args;
         src_scale_binary_args[DNNL_ARG_SRC_0]
@@ -141,8 +101,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
         CHECK(src_scale_binary_->execute(binary_ctx));
     }
     if (has_wei_scales
-            && (matmul_impl_->multi_wei_scale()
-                    || matmul_impl_->scale_type() == CUDA_R_32I)) {
+            && (pd()->params_->multi_wei_scale_
+                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
         // wei scale sycl binary
         exec_args_t wei_scale_binary_args;
         wei_scale_binary_args[DNNL_ARG_SRC_0]
@@ -167,11 +127,9 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
     }
 
     CHECK(executor_->execute(ctx, ctx.stream()->engine(), matmul_impl_,
-            algo_scratchpad_size, bias_scratchpad_size, block_a_scratchpad_size,
-            block_b_scratchpad_size, block_c_scratchpad_size,
-            src_scale_scratchpad_size, wei_scale_scratchpad_size));
+            pd()->params_, src_d, weights_d, dst_d));
 
-    if (matmul_impl_->with_bias()) {
+    if (pd()->params_->with_bias_) {
         // bias sycl binary
         exec_args_t binary_args;
         std::unique_ptr<memory_t, memory_deleter_t> scratch_mem;
@@ -198,8 +156,8 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
     }
 
    if (has_dst_scales
-            && (matmul_impl_->multi_dst_scale()
-                    || matmul_impl_->scale_type() == CUDA_R_32I)) {
+            && (pd()->params_->multi_dst_scale_
+                    || pd()->params_->acc_type_ == CUDA_R_32I)) {
         // dst scale sycl binary
         exec_args_t dst_scale_binary_args;
         dst_scale_binary_args[DNNL_ARG_SRC_0]
@@ -213,13 +171,11 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
         CHECK(dst_scale_binary_->execute(binary_ctx));
    }
 
-    if (has_runtime_args) {
+    if (pd()->params_->has_runtime_params_) {
         auto &evts = cuda_stream->sycl_ctx().get_sycl_deps().events;
         for (auto e : evts) {
             e.wait();
         }
-
-        matmul_impl_->rt_cleanup();
     }
 
     return status::success;
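
Note that the runtime-params path still waits on the stream's recorded SYCL events before returning; only the explicit cleanup()/rt_cleanup() calls are gone, since the executor no longer leaves per-execution state behind on matmul_impl_. A minimal sketch of that host-side synchronization in plain SYCL (illustrative; the real code reaches the events through cuda_stream->sycl_ctx().get_sycl_deps().events):

#include <sycl/sycl.hpp>
#include <vector>

// Block the host until every event recorded on the stream has completed.
void wait_on_deps(const std::vector<sycl::event> &events) {
    for (const auto &e : events)
        e.wait(); // blocking host-side wait on each dependency
}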

src/gpu/nvidia/cudnn_matmul.hpp (+19 -12)
@@ -20,18 +20,21 @@
 
 #include "gpu/gpu_matmul_pd.hpp"
 
-#include "gpu/nvidia/cudnn_matmul_base.hpp"
+#include "common/primitive.hpp"
+#include "common/primitive_desc_iterator.hpp"
+#include "gpu/gpu_primitive.hpp"
 #include "gpu/nvidia/cudnn_matmul_executor.hpp"
 #include "gpu/nvidia/cudnn_matmul_impl.hpp"
+#include "gpu/nvidia/cudnn_matmul_lt_impl.hpp"
 #include "gpu/nvidia/sycl_cuda_utils.hpp"
 
 namespace dnnl {
 namespace impl {
 namespace gpu {
 namespace nvidia {
 
-struct cudnn_matmul_t : cudnn_matmul_base_t {
-    using cudnn_matmul_base_t::cudnn_matmul_base_t;
+struct cudnn_matmul_t : public gpu::primitive_t {
+    using primitive_t::primitive_t;
 
     struct pd_t : public gpu_matmul_pd_t {
         using gpu_matmul_pd_t::gpu_matmul_pd_t;
@@ -79,12 +82,15 @@ struct cudnn_matmul_t : cudnn_matmul_base_t {
 
             if (src_md()->ndims > 3) return status::unimplemented;
 
-            return status::success;
-        }
+            params_ = std::make_shared<cublas_params>();
+            CHECK(params_->init(src_md(), weights_md(), dst_md(), weights_md(1),
+                    attr(), batched(), with_bias()));
 
-        size_t scratchpad_size(const memory_desc_t *dst_md) const {
-            const auto dst_nelems = memory_desc_wrapper(dst_md).nelems(true);
-            return dst_nelems * sizeof(float);
+            if (!params_->has_runtime_params_) {
+                auto scratchpad = scratchpad_registry().registrar();
+                params_->init_scratchpad(dst_md(), scratchpad);
+            }
+            return status::success;
         }
 
         bool scales_ok() const {
@@ -116,21 +122,22 @@ struct cudnn_matmul_t : cudnn_matmul_base_t {
             }
             return true;
         }
+
+        std::shared_ptr<cublas_params> params_;
     };
 
     status_t init(impl::engine_t *engine) override {
         matmul_impl_.reset(new cudnn_matmul_impl_t());
-        auto status = matmul_impl_->init((matmul_pd_t *)pd());
-        if (status != status::success) return status;
 
-        bool has_runtime_args = matmul_impl_->has_runtime_params();
+        bool has_runtime_args = pd()->params_->has_runtime_params_;
 
        if (has_runtime_args) {
            executor_.reset(new cudnn_matmul_runtime_args_exec_t);
        } else {
            executor_.reset(new cudnn_matmul_exec_t);
+            matmul_impl_->set_non_runtime_params(pd()->params_);
        }
-        return status;
+        return status::success;
    }
 
    status_t execute(const exec_ctx_t &ctx) const override;
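
The deleted scratchpad_size() helper sized a buffer as dst_nelems * sizeof(float); under the new scheme that booking happens once at pd-init time via cublas_params::init_scratchpad() whenever shapes are known up front. A hedged sketch of what such a registration can look like; the key name and the registrar_t::book() usage follow oneDNN's scratchpad-registry conventions but are assumptions, not taken from this diff:

// Sketch: book a float-typed scratchpad sized to the destination tensor,
// mirroring the deleted scratchpad_size() = dst_nelems * sizeof(float).
// key_matmul_dst_in_acc_dt is an assumed key; the real one may differ.
void init_scratchpad(const memory_desc_t *dst_md,
        memory_tracking::registrar_t &scratchpad) const {
    const auto dst_nelems = memory_desc_wrapper(dst_md).nelems(true);
    scratchpad.book(memory_tracking::names::key_matmul_dst_in_acc_dt,
            dst_nelems, sizeof(float));
}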

src/gpu/nvidia/cudnn_matmul_base.hpp (-50)

This file was deleted.
