Skip to content

Commit 6c2091f

Browse files
committed
sycl: simplify bin_bcast_kernel
1 parent de4c07f commit 6c2091f

File tree

1 file changed

+121
-232
lines changed

1 file changed

+121
-232
lines changed

ggml/src/ggml-sycl/binbcast.cpp

+121-232
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,74 @@
11
#include "binbcast.hpp"
22

3+
#include <array>
34
#include <cstddef>
45
#include <cstdint>
56
#include <sycl/sycl.hpp>
67

8+
#include "dpct/helper.hpp"
79
#include "ggml.h"
810

9-
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
10-
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
11-
int ne0, int ne1, int ne2, int ne3,
12-
int ne10, int ne11, int ne12, int ne13,
13-
/*int s0, */ int s1, int s2, int s3,
14-
/*int s00,*/ int s01, int s02, int s03,
15-
/*int s10,*/ int s11, int s12, int s13,
16-
const sycl::nd_item<3> &item_ct1) {
17-
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
18-
item_ct1.get_local_id(2);
19-
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
20-
item_ct1.get_local_id(1));
21-
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
22-
item_ct1.get_local_id(0)) /
23-
ne3;
24-
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
25-
item_ct1.get_local_id(0)) %
26-
ne3;
27-
28-
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
29-
return;
30-
}
31-
32-
const int i11 = i1 % ne11;
33-
const int i12 = i2 % ne12;
34-
const int i13 = i3 % ne13;
35-
36-
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
37-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
38-
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
39-
40-
const src0_t * src0_row = src0 + i_src0;
41-
const src1_t * src1_row = src1 + i_src1;
42-
dst_t * dst_row = dst + i_dst;
43-
44-
for (int i0 = i0s; i0 < ne0;
45-
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
46-
const int i10 = i0 % ne10;
47-
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
11+
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
12+
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
13+
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
14+
auto element_id = it.get_global_id(0);
15+
auto global_range = it.get_global_range(0);
16+
for (; element_id < num_elements; element_id += global_range) {
17+
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
18+
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
19+
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
20+
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
21+
dst[element_id] = val_to_store;
4822
}
4923
}
5024

51-
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
52-
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
53-
int ne0, int ne1, int ne2, int ne3,
54-
int ne10, int ne11, int ne12, int ne13,
55-
/*int s0, */ int s1, int s2, int s3,
56-
/*int s00,*/ int s01, int s02, int s03,
57-
/*int s10,*/ int s11, int s12, int s13,
58-
const sycl::nd_item<3> &item_ct1) {
59-
60-
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
61-
item_ct1.get_local_id(2);
62-
63-
const int i3 = i/(ne2*ne1*ne0);
64-
const int i2 = (i/(ne1*ne0)) % ne2;
65-
const int i1 = (i/ne0) % ne1;
66-
const int i0 = i % ne0;
67-
68-
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
69-
return;
25+
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
26+
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
27+
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
28+
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
29+
int s11, int s12, int s13, std::size_t num_dst_elements,
30+
const sycl::nd_item<1> & item_ct1) {
31+
auto calculate_logical_index =
32+
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
33+
std::array<int, 4> logical_index;
34+
#pragma unroll(4)
35+
for (int i = 3; i >= 0; i--) {
36+
logical_index[i] = element_id % dims[i];
37+
element_id /= dims[i];
38+
}
39+
return logical_index;
40+
};
41+
42+
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
43+
const std::array<int, 4> & indices) __attribute__((always_inline))
44+
->std::size_t {
45+
std::size_t index = 0;
46+
#pragma unroll(4)
47+
for (int i = 0; i < 4; i++) {
48+
auto index_i = indices[i];
49+
if (indices[i] >= dims[i]) {
50+
index_i = indices[i] % dims[i];
51+
}
52+
index += strides[i] * index_i;
53+
}
54+
return index;
55+
};
56+
57+
auto element_id = item_ct1.get_global_id(0);
58+
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
59+
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
60+
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
61+
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
62+
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
63+
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
64+
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
65+
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
66+
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
67+
dst[dst_index] = val_to_store;
7068
}
71-
72-
const int i11 = i1 % ne11;
73-
const int i12 = i2 % ne12;
74-
const int i13 = i3 % ne13;
75-
76-
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
77-
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
78-
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
79-
80-
const src0_t * src0_row = src0 + i_src0;
81-
const src1_t * src1_row = src1 + i_src1;
82-
dst_t * dst_row = dst + i_dst;
83-
84-
const int i10 = i0 % ne10;
85-
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
8669
}
8770

88-
89-
template<float (*bin_op)(const float, const float)>
90-
struct bin_bcast_sycl {
71+
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
9172
template <typename src0_t, typename src1_t, typename dst_t>
9273
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
9374
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@@ -96,165 +77,73 @@ struct bin_bcast_sycl {
9677
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
9778
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
9879
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
99-
int nr0 = ne10 / ne0;
100-
int nr1 = ne11/ne1;
101-
int nr2 = ne12/ne2;
102-
int nr3 = ne13/ne3;
103-
104-
int nr[4] = { nr0, nr1, nr2, nr3 };
105-
106-
// collapse dimensions until first broadcast dimension
107-
int64_t cne[] = {ne0, ne1, ne2, ne3};
108-
int64_t cne0[] = {ne00, ne01, ne02, ne03};
109-
int64_t cne1[] = {ne10, ne11, ne12, ne13};
110-
size_t cnb[] = {nb0, nb1, nb2, nb3};
111-
size_t cnb0[] = {nb00, nb01, nb02, nb03};
112-
size_t cnb1[] = {nb10, nb11, nb12, nb13};
113-
auto collapse = [](int64_t cne[]) {
114-
cne[0] *= cne[1];
115-
cne[1] = cne[2];
116-
cne[2] = cne[3];
117-
cne[3] = 1;
118-
};
119-
120-
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
121-
cnb[1] *= cne[1];
122-
cnb[2] *= cne[2];
123-
cnb[3] *= cne[3];
124-
};
125-
126-
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
80+
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
81+
const std::array<int64_t, 4> & dst_dims) -> bool {
12782
for (int i = 0; i < 4; i++) {
128-
if (nr[i] != 1) {
129-
break;
130-
}
131-
if (i > 0) {
132-
collapse_nb(cnb, cne);
133-
collapse_nb(cnb0, cne0);
134-
collapse_nb(cnb1, cne1);
135-
collapse(cne);
136-
collapse(cne0);
137-
collapse(cne1);
83+
if (dst_dims[i] > src_dims[i]) {
84+
return true;
13885
}
13986
}
140-
}
141-
{
142-
int64_t ne0 = cne[0];
143-
int64_t ne1 = cne[1];
144-
int64_t ne2 = cne[2];
145-
int64_t ne3 = cne[3];
146-
147-
int64_t ne10 = cne1[0];
148-
int64_t ne11 = cne1[1];
149-
int64_t ne12 = cne1[2];
150-
int64_t ne13 = cne1[3];
151-
152-
size_t nb0 = cnb[0];
153-
size_t nb1 = cnb[1];
154-
size_t nb2 = cnb[2];
155-
size_t nb3 = cnb[3];
156-
157-
size_t nb00 = cnb0[0];
158-
size_t nb01 = cnb0[1];
159-
size_t nb02 = cnb0[2];
160-
size_t nb03 = cnb0[3];
161-
162-
size_t nb10 = cnb1[0];
163-
size_t nb11 = cnb1[1];
164-
size_t nb12 = cnb1[2];
165-
size_t nb13 = cnb1[3];
166-
167-
size_t s0 = nb0 / sizeof(dst_t);
168-
size_t s1 = nb1 / sizeof(dst_t);
169-
size_t s2 = nb2 / sizeof(dst_t);
170-
size_t s3 = nb3 / sizeof(dst_t);
171-
172-
size_t s10 = nb10 / sizeof(src1_t);
173-
size_t s11 = nb11 / sizeof(src1_t);
174-
size_t s12 = nb12 / sizeof(src1_t);
175-
size_t s13 = nb13 / sizeof(src1_t);
176-
177-
size_t s00 = nb00 / sizeof(src0_t);
178-
size_t s01 = nb01 / sizeof(src0_t);
179-
size_t s02 = nb02 / sizeof(src0_t);
180-
size_t s03 = nb03 / sizeof(src0_t);
181-
182-
GGML_UNUSED(s00);
183-
184-
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
185-
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
186-
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
187-
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
188-
189-
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
190-
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
191-
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
192-
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
193-
194-
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
195-
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
196-
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
197-
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
198-
199-
GGML_ASSERT(s0 == 1);
200-
GGML_ASSERT(s10 == 1);
201-
202-
const int block_size = 128;
203-
204-
int64_t hne0 = std::max(ne0/2LL, 1LL);
205-
206-
sycl::range<3> block_dims(1, 1, 1);
207-
block_dims[2] = std::min<unsigned int>(hne0, block_size);
208-
block_dims[1] = std::min<unsigned int>(
209-
ne1, block_size / (unsigned int)block_dims[2]);
210-
block_dims[0] = std::min(
211-
std::min<unsigned int>(
212-
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
213-
(unsigned int)block_dims[1]),
214-
64U);
215-
216-
sycl::range<3> block_nums(
217-
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
218-
(ne1 + block_dims[1] - 1) / block_dims[1],
219-
(hne0 + block_dims[2] - 1) / block_dims[2]);
220-
221-
if (block_nums[0] > 65535) {
222-
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
223-
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
224-
{
225-
dpct::has_capability_or_fail(stream->get_device(),
226-
{sycl::aspect::fp16});
227-
228-
stream->parallel_for(
229-
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230-
sycl::range<3>(1, 1, block_size),
231-
sycl::range<3>(1, 1, block_size)),
232-
[=](sycl::nd_item<3> item_ct1) {
233-
k_bin_bcast_unravel<bin_op>(
234-
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
235-
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
236-
s03, s11, s12, s13, item_ct1);
237-
});
238-
}
239-
} else {
240-
/*
241-
DPCT1049:16: The work-group size passed to the SYCL kernel may
242-
exceed the limit. To get the device limit, query
243-
info::device::max_work_group_size. Adjust the work-group size if
244-
needed.
245-
*/
246-
dpct::has_capability_or_fail(stream->get_device(),
247-
{sycl::aspect::fp16});
248-
249-
stream->parallel_for(
250-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251-
[=](sycl::nd_item<3> item_ct1) {
252-
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253-
ne2, ne3, ne10, ne11, ne12, ne13,
254-
s1, s2, s3, s01, s02, s03, s11, s12, s13,
255-
item_ct1);
256-
});
257-
}
87+
return false;
88+
};
89+
90+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
91+
92+
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
93+
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
94+
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
95+
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
96+
97+
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
98+
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
99+
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
100+
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
101+
102+
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
103+
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
104+
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
105+
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
106+
107+
// dst strides in number of elements
108+
size_t s0 = nb0 / sizeof(dst_t);
109+
size_t s1 = nb1 / sizeof(dst_t);
110+
size_t s2 = nb2 / sizeof(dst_t);
111+
size_t s3 = nb3 / sizeof(dst_t);
112+
113+
// src1 strides in number of elements
114+
size_t s10 = nb10 / sizeof(src0_t);
115+
size_t s11 = nb11 / sizeof(src1_t);
116+
size_t s12 = nb12 / sizeof(src1_t);
117+
size_t s13 = nb13 / sizeof(src1_t);
118+
119+
// src0 strides in number of elements
120+
size_t s00 = nb00 / sizeof(src0_t);
121+
size_t s01 = nb01 / sizeof(src0_t);
122+
size_t s02 = nb02 / sizeof(src0_t);
123+
size_t s03 = nb03 / sizeof(src0_t);
124+
125+
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
126+
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
127+
std::size_t local_range = 256;
128+
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
129+
130+
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
131+
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
132+
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
133+
134+
if (! needs_broadcasting && all_contiguous) {
135+
stream->submit([&](sycl::handler & cgh) {
136+
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
137+
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
138+
});
139+
});
140+
} else {
141+
stream->submit([&](sycl::handler & cgh) {
142+
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
143+
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
144+
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
145+
});
146+
});
258147
}
259148
}
260149
};

0 commit comments

Comments
 (0)