Skip to content

Commit 762e317

Browse files
ShanoToni authored and vpirogov committed
generic: sycl: Adding support for RNN FWD r2l, sum & concat
1 parent 1498c83 commit 762e317

File tree

6 files changed

+89
-43
lines changed

6 files changed

+89
-43
lines changed

src/gpu/generic/sycl/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -193,4 +193,4 @@ The implementation supports forward propagation and vanilla RNN cell kind.
193193

194194
* Supported formats: `ldigo`, `ldgoi`
195195
* Supported data types: `f32`, `bf16`, `f16`, `s8`, `u8`
196-
* Supported direction: `left2right`
196+
* Supported direction: `left2right`, `right2left`, `concat`, `sum`

src/gpu/generic/sycl/rnn/cell_common.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,17 @@ using namespace rnn_utils;
2929

3030
status_t _ref_rnn_common_t::cell_execution(const cell_ctx_t &cell_struct) {
3131

32-
auto cell_layer = cell_struct.workspace.states_range(cell_struct.lay - 1,
33-
cell_struct.lay - 1, cell_struct.dir, cell_struct.dir,
34-
cell_struct.iter - 1, cell_struct.iter);
32+
auto cell_layer = cell_struct.workspace.states_range(cell_struct.lay,
33+
cell_struct.lay, cell_struct.dir, cell_struct.dir, cell_struct.iter,
34+
cell_struct.iter);
3535

36-
auto cell_iter = cell_struct.workspace.states_range(cell_struct.lay,
37-
cell_struct.lay, cell_struct.dir, cell_struct.dir,
38-
cell_struct.iter - 2, cell_struct.iter - 1);
36+
auto iter_off = cell_struct.iter == 0
37+
? (-1 * (cell_struct.rnn.n_dir - 1) * (cell_struct.rnn.n_iter + 1))
38+
- 1
39+
: cell_struct.iter - 1;
40+
auto cell_iter = cell_struct.workspace.states_range(cell_struct.lay + 1,
41+
cell_struct.lay + 1, cell_struct.dir, cell_struct.dir, iter_off,
42+
iter_off);
3943

4044
auto scratch_gates = cell_struct.scratch.gates(0);
4145

src/gpu/generic/sycl/rnn/ref_rnn.cpp

+18-11
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,6 @@ status_t _ref_rnn_common_t::pd_t::init(impl::engine_t *engine) {
127127
VDISPATCH_RNN_SC(set_weights_desc(this->weights_iter_md_, rnn_conf),
128128
"unsupported weights iter memory descriptor");
129129

130-
// Currently only run L2R
131-
VDISPATCH_RNN(this->direction() == dnnl_unidirectional_left2right,
132-
VERBOSE_BAD_ALGORITHM);
133130
// Check dimensions consistency
134131
VDISPATCH_RNN((this->SIC() == this->DHC() || (this->T() == 1)),
135132
VERBOSE_INCONSISTENT_DIM, "SIC", (int)this->SIC(), "DHC",
@@ -154,10 +151,19 @@ status_t _ref_rnn_common_t::pd_t::init(impl::engine_t *engine) {
154151
CHECK(memory_desc_init_by_tag(state_md, 5, state_dims,
155152
rnn_conf.src_data_type, format_tag::abcde));
156153

157-
copy_init_layer_conf_ = sycl_rnn_copy_conf_t {
158-
xpu::sycl::md_t(this->src_md(0)), xpu::sycl::md_t(&state_md),
159-
rnn_conf.slc, rnn_conf.n_dir, rnn_conf.n_layer, rnn_conf.n_iter,
160-
rnn_conf.mb, rnn_conf.states_ws_ld, true, true};
154+
// using is_l2r/r2l to account for bidirectional as well
155+
// if both l2r and r2l are true, case is bidirectional concat
156+
// is_sum changes behaviour from concat to sum for bidirectional case
157+
158+
bool is_l2r = !(this->desc()->direction == dnnl_unidirectional_right2left);
159+
bool is_r2l = !(this->desc()->direction == dnnl_unidirectional_left2right);
160+
bool is_sum = this->desc()->direction == dnnl_bidirectional_sum;
161+
162+
copy_init_layer_conf_
163+
= sycl_rnn_copy_conf_t {xpu::sycl::md_t(this->src_md(0)),
164+
xpu::sycl::md_t(&state_md), rnn_conf.slc, rnn_conf.n_dir,
165+
rnn_conf.n_layer, rnn_conf.n_iter, rnn_conf.mb,
166+
rnn_conf.states_ws_ld, true, true, is_l2r, is_r2l, false};
161167

162168
xpu::sycl::md_t src_iter_md = this->src_md(1)->data_type == data_type::undef
163169
? xpu::sycl::md_t()
@@ -166,20 +172,21 @@ status_t _ref_rnn_common_t::pd_t::init(impl::engine_t *engine) {
166172
copy_init_iter_conf_ = sycl_rnn_copy_conf_t {src_iter_md,
167173
xpu::sycl::md_t(&state_md), rnn_conf.sic, rnn_conf.n_dir,
168174
rnn_conf.n_layer, rnn_conf.n_iter, rnn_conf.mb,
169-
rnn_conf.states_ws_ld, false, true};
175+
rnn_conf.states_ws_ld, false, true, is_l2r, is_r2l, false};
170176

171177
copy_res_layer_conf_ = sycl_rnn_copy_conf_t {xpu::sycl::md_t(&state_md),
172178
xpu::sycl::md_t(this->dst_md(0)), rnn_conf.dhc, rnn_conf.n_dir,
173179
rnn_conf.n_layer, rnn_conf.n_iter, rnn_conf.mb,
174-
rnn_conf.states_ws_ld, true, false};
180+
rnn_conf.states_ws_ld, true, false, is_l2r, is_r2l, is_sum};
175181

176182
xpu::sycl::md_t dst_iter_md = this->dst_md(1)->data_type == data_type::undef
177183
? xpu::sycl::md_t()
178184
: xpu::sycl::md_t(this->dst_md(1));
179185

180186
copy_res_iter_conf_ = sycl_rnn_copy_conf_t {xpu::sycl::md_t(&state_md),
181187
dst_iter_md, rnn_conf.dhc, rnn_conf.n_dir, rnn_conf.n_layer,
182-
rnn_conf.n_iter, rnn_conf.mb, rnn_conf.states_ws_ld, false, false};
188+
rnn_conf.n_iter, rnn_conf.mb, rnn_conf.states_ws_ld, false, false,
189+
is_l2r, is_r2l, false};
183190

184191
sycl_rnn_bias_conf_t_ = sycl_rnn_bias_conf_t();
185192
sycl_rnn_bias_conf_t_.dst_md = xpu::sycl::md_t(this->dst_md(0));
@@ -580,7 +587,7 @@ status_t _ref_rnn_common_t::rnn_bias(const exec_ctx_t &ctx, dim_t batch,
580587

581588
auto dst_mem_arg
582589
= utils::downcast<const xpu::sycl::memory_storage_base_t *>(
583-
ws.states(lay, dir, iter - 1).get())
590+
ws.states(lay + 1, dir, iter).get())
584591
->get_out_memory_arg(ctx.stream(), cgh);
585592
ref_rnn_bias bias_kernel(pd()->sycl_rnn_bias_conf_t_, src_mem_arg,
586593
bias_mem_arg, dst_mem_arg);

src/gpu/generic/sycl/rnn/rnn_kernels.hpp

+54-14
Original file line numberDiff line numberDiff line change
@@ -49,35 +49,83 @@ struct ref_rnn_copy_t {
4949
: src_ {src}, dst_ {dst}, conf_ {conf} {}
5050

5151
void operator()(::sycl::nd_item<3> item) const {
52-
const dim_t tl = item.get_global_id(0) / conf_.n_dir; // timestep/layer
53-
const dim_t dir = item.get_global_id(0) % conf_.n_dir; // direction
52+
const dim_t tl = item.get_global_id(0) // timestep/layer
53+
/ (conf_.layer ? 1 : conf_.n_dir);
54+
dim_t dir = conf_.layer
55+
? 0
56+
: item.get_global_id(0) % conf_.n_dir; // direction
5457
const dim_t n = item.get_global_id(1); // batch
5558
const dim_t c = item.get_global_id(2); // channel
5659

5760
if (dir >= conf_.n_dir || n >= conf_.batch || c >= conf_.range) return;
5861

5962
dim_t src_offset = 0;
6063
dim_t dst_offset = 0;
64+
6165
if (conf_.layer) { // layer
6266
if (tl >= conf_.n_iter) return;
6367
if (conf_.to_state) { // init
64-
src_offset = conf_.src_md.off(tl, n, c);
65-
dst_offset = conf_.dst_md.off(0, dir, tl, n, c);
68+
if (conf_.l2r) { // l2r
69+
src_offset = conf_.src_md.off(tl, n, c);
70+
dst_offset = conf_.dst_md.off(0, dir, tl, n, c);
71+
do_copy(src_offset, dst_offset, src_ptr(), dst_ptr());
72+
dir = 1;
73+
}
74+
if (conf_.r2l) { // r2l
75+
src_offset = conf_.src_md.off(tl, n, c);
76+
dst_offset = conf_.dst_md.off(
77+
0, conf_.n_dir - 1, conf_.n_iter - tl - 1, n, c);
78+
do_copy(src_offset, dst_offset, src_ptr(), dst_ptr());
79+
}
6680
} else { // res
67-
src_offset = conf_.src_md.off(conf_.n_layer, dir, tl, n, c);
68-
dst_offset = conf_.dst_md.off(tl, n, dir * conf_.range + c);
81+
if (conf_.l2r) {
82+
dst_offset = conf_.dst_md.off(tl, n, dir * conf_.range + c);
83+
src_offset = conf_.src_md.off(conf_.n_layer, dir, tl, n, c);
84+
do_copy(src_offset, dst_offset, src_ptr(), dst_ptr());
85+
dir = 1;
86+
}
87+
if (conf_.r2l) {
88+
dst_offset = conf_.dst_md.off(tl, n, dir * conf_.range + c);
89+
src_offset = conf_.src_md.off(
90+
conf_.n_layer, dir, conf_.n_iter - tl - 1, n, c);
91+
if (conf_.sum) {
92+
dst_offset = conf_.dst_md.off(tl, n, c);
93+
auto src = load_float_value(
94+
src_md().data_type(), src_ptr(), src_offset);
95+
auto dst = load_float_value(conf_.dst_md.data_type(),
96+
dst_ptr(), dst_offset);
97+
store_float_value(src_md().data_type(), src + dst,
98+
dst_ptr(), dst_offset);
99+
} else {
100+
do_copy(src_offset, dst_offset, src_ptr(), dst_ptr());
101+
}
102+
}
69103
}
70104
} else { // iter
71105
if (tl >= conf_.n_layer) return;
72106
if (conf_.to_state) { // init
73107
src_offset = conf_.src_md.off(tl, dir, n, c);
74108
dst_offset = conf_.dst_md.off(tl, dir, conf_.n_iter, n, c);
109+
do_copy(src_offset, dst_offset, src_ptr(), dst_ptr());
75110
} else { // res
76111
src_offset
77112
= conf_.src_md.off(tl + 1, dir, conf_.n_iter - 1, n, c);
78113
dst_offset = conf_.dst_md.off(tl, dir, n, c);
114+
do_copy(src_offset, dst_offset, src_ptr(), dst_ptr());
79115
}
80116
}
117+
}
118+
119+
xpu::sycl::in_memory_arg_t src_;
120+
xpu::sycl::out_memory_arg_t dst_;
121+
sycl_rnn_copy_conf_t conf_;
122+
123+
const xpu::sycl::md_t &src_md() const { return conf_.src_md; }
124+
void *src_ptr() const { return src_.get_pointer(); }
125+
void *dst_ptr() const { return dst_.get_pointer(); }
126+
127+
void do_copy(
128+
dim_t src_offset, dim_t dst_offset, void *from, void *to) const {
81129
if (src_ptr()) {
82130
auto src = load_float_value(
83131
src_md().data_type(), src_ptr(), src_offset);
@@ -92,14 +140,6 @@ struct ref_rnn_copy_t {
92140
}
93141
}
94142
}
95-
96-
xpu::sycl::in_memory_arg_t src_;
97-
xpu::sycl::out_memory_arg_t dst_;
98-
sycl_rnn_copy_conf_t conf_;
99-
100-
const xpu::sycl::md_t &src_md() const { return conf_.src_md; }
101-
void *src_ptr() const { return src_.get_pointer(); }
102-
void *dst_ptr() const { return dst_.get_pointer(); }
103143
};
104144

105145
struct ref_rnn_bias {

src/gpu/generic/sycl/rnn/rnn_utils.hpp

+3-11
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,9 @@ struct workspace_t : public data_helper_t {
205205
}
206206

207207
dim_t calc_off_ws_state(
208-
dim_t i0_, dim_t i1, dim_t i2_, dim_t i3, dim_t i4) const {
209-
//lay,dir,time
210-
// Logical index into workspace grid
211-
auto i0 = i0_ + 1;
212-
auto i2 = i2_ + 1;
213-
208+
dim_t i0, dim_t i1, dim_t i2, dim_t i3, dim_t i4) const {
214209
assert(i0 >= 0);
215-
210+
//lay,dir,time
216211
return calc_4d_off(i0, i1, conf_.n_dir, i2, conf_.n_iter + 1, i3,
217212
conf_.mb, i4, conf_.states_ws_ld);
218213
}
@@ -241,10 +236,7 @@ struct workspace_t : public data_helper_t {
241236

242237
std::unique_ptr<mst> states(dim_t layer, dim_t dir, dim_t time) const {
243238
if (!states_) return {};
244-
245-
auto i0 = layer + 1;
246-
auto i2 = time + 1;
247-
auto off_ = get_offset(states_strides(), {i0, dir, i2, 0})
239+
auto off_ = get_offset(states_strides(), {layer, dir, time, 0})
248240
* conf_.ws_states_elsz;
249241
return states().clone_ptr_off(off_);
250242
}

src/gpu/generic/sycl/sycl_primitive_conf.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,9 @@ struct sycl_rnn_copy_conf_t {
471471
dim_t states_ws_ld;
472472
bool layer;
473473
bool to_state;
474+
bool l2r;
475+
bool r2l;
476+
bool sum;
474477
};
475478

476479
struct sycl_rnn_bias_conf_t {

0 commit comments

Comments (0)