@@ -49,35 +49,83 @@ struct ref_rnn_copy_t {
49
49
: src_ {src}, dst_ {dst}, conf_ {conf} {}
50
50
51
51
void operator ()(::sycl::nd_item<3 > item) const {
52
- const dim_t tl = item.get_global_id (0 ) / conf_.n_dir ; // timestep/layer
53
- const dim_t dir = item.get_global_id (0 ) % conf_.n_dir ; // direction
52
+ const dim_t tl = item.get_global_id (0 ) // timestep/layer
53
+ / (conf_.layer ? 1 : conf_.n_dir );
54
+ dim_t dir = conf_.layer
55
+ ? 0
56
+ : item.get_global_id (0 ) % conf_.n_dir ; // direction
54
57
const dim_t n = item.get_global_id (1 ); // batch
55
58
const dim_t c = item.get_global_id (2 ); // channel
56
59
57
60
if (dir >= conf_.n_dir || n >= conf_.batch || c >= conf_.range ) return ;
58
61
59
62
dim_t src_offset = 0 ;
60
63
dim_t dst_offset = 0 ;
64
+
61
65
if (conf_.layer ) { // layer
62
66
if (tl >= conf_.n_iter ) return ;
63
67
if (conf_.to_state ) { // init
64
- src_offset = conf_.src_md .off (tl, n, c);
65
- dst_offset = conf_.dst_md .off (0 , dir, tl, n, c);
68
+ if (conf_.l2r ) { // l2r
69
+ src_offset = conf_.src_md .off (tl, n, c);
70
+ dst_offset = conf_.dst_md .off (0 , dir, tl, n, c);
71
+ do_copy (src_offset, dst_offset, src_ptr (), dst_ptr ());
72
+ dir = 1 ;
73
+ }
74
+ if (conf_.r2l ) { // r2l
75
+ src_offset = conf_.src_md .off (tl, n, c);
76
+ dst_offset = conf_.dst_md .off (
77
+ 0 , conf_.n_dir - 1 , conf_.n_iter - tl - 1 , n, c);
78
+ do_copy (src_offset, dst_offset, src_ptr (), dst_ptr ());
79
+ }
66
80
} else { // res
67
- src_offset = conf_.src_md .off (conf_.n_layer , dir, tl, n, c);
68
- dst_offset = conf_.dst_md .off (tl, n, dir * conf_.range + c);
81
+ if (conf_.l2r ) {
82
+ dst_offset = conf_.dst_md .off (tl, n, dir * conf_.range + c);
83
+ src_offset = conf_.src_md .off (conf_.n_layer , dir, tl, n, c);
84
+ do_copy (src_offset, dst_offset, src_ptr (), dst_ptr ());
85
+ dir = 1 ;
86
+ }
87
+ if (conf_.r2l ) {
88
+ dst_offset = conf_.dst_md .off (tl, n, dir * conf_.range + c);
89
+ src_offset = conf_.src_md .off (
90
+ conf_.n_layer , dir, conf_.n_iter - tl - 1 , n, c);
91
+ if (conf_.sum ) {
92
+ dst_offset = conf_.dst_md .off (tl, n, c);
93
+ auto src = load_float_value (
94
+ src_md ().data_type (), src_ptr (), src_offset);
95
+ auto dst = load_float_value (conf_.dst_md .data_type (),
96
+ dst_ptr (), dst_offset);
97
+ store_float_value (src_md ().data_type (), src + dst,
98
+ dst_ptr (), dst_offset);
99
+ } else {
100
+ do_copy (src_offset, dst_offset, src_ptr (), dst_ptr ());
101
+ }
102
+ }
69
103
}
70
104
} else { // iter
71
105
if (tl >= conf_.n_layer ) return ;
72
106
if (conf_.to_state ) { // init
73
107
src_offset = conf_.src_md .off (tl, dir, n, c);
74
108
dst_offset = conf_.dst_md .off (tl, dir, conf_.n_iter , n, c);
109
+ do_copy (src_offset, dst_offset, src_ptr (), dst_ptr ());
75
110
} else { // res
76
111
src_offset
77
112
= conf_.src_md .off (tl + 1 , dir, conf_.n_iter - 1 , n, c);
78
113
dst_offset = conf_.dst_md .off (tl, dir, n, c);
114
+ do_copy (src_offset, dst_offset, src_ptr (), dst_ptr ());
79
115
}
80
116
}
117
+ }
118
+
119
+ xpu::sycl::in_memory_arg_t src_;
120
+ xpu::sycl::out_memory_arg_t dst_;
121
+ sycl_rnn_copy_conf_t conf_;
122
+
123
+ const xpu::sycl::md_t &src_md () const { return conf_.src_md ; }
124
+ void *src_ptr () const { return src_.get_pointer (); }
125
+ void *dst_ptr () const { return dst_.get_pointer (); }
126
+
127
+ void do_copy (
128
+ dim_t src_offset, dim_t dst_offset, void *from, void *to) const {
81
129
if (src_ptr ()) {
82
130
auto src = load_float_value (
83
131
src_md ().data_type (), src_ptr (), src_offset);
@@ -92,14 +140,6 @@ struct ref_rnn_copy_t {
92
140
}
93
141
}
94
142
}
95
-
96
- xpu::sycl::in_memory_arg_t src_;
97
- xpu::sycl::out_memory_arg_t dst_;
98
- sycl_rnn_copy_conf_t conf_;
99
-
100
- const xpu::sycl::md_t &src_md () const { return conf_.src_md ; }
101
- void *src_ptr () const { return src_.get_pointer (); }
102
- void *dst_ptr () const { return dst_.get_pointer (); }
103
143
};
104
144
105
145
struct ref_rnn_bias {
0 commit comments