Skip to content

Commit 0f7ed03

Browse files
committed
src: gpu: jit: unify argument key types across generator
1 parent 1e41a03 commit 0f7ed03

File tree

7 files changed

+47
-55
lines changed

7 files changed

+47
-55
lines changed

src/gpu/jit/conv/gen_convolution.cpp

+30-31
Original file line numberDiff line numberDiff line change
@@ -376,14 +376,16 @@ class gen_convolution_t {
376376
auto &cfg = data.pd_cfg;
377377
const bool needs_zp_precalc = cfg.zp_cfg().needs_src_precalc;
378378

379-
auto scratchpad = pd->scratchpad_registry().registrar();
380379
auto &conv_info = create_kernel_info(pd, kernel_id_t::convolution);
381380
auto &zp_precalc_info = (needs_zp_precalc)
382381
? create_kernel_info(pd, kernel_id_t::zp_precalc)
383382
: conv_info;
384383

384+
static_assert(DNNL_ARG_UNDEF == memory_tracking::names::key_none,
385+
"Undefined argument and empty scratchpad key are out of sync!");
386+
385387
// Initialize kernel arguments.
386-
int scratchpad_key = 0;
388+
int scratchpad_key = memory_tracking::names::key_none;
387389
for (auto &t : data.tensor_cfg.tensors()) {
388390
const bool src_zp_precalc
389391
= needs_zp_precalc && (t.name == "src_zero_points");
@@ -392,6 +394,13 @@ class gen_convolution_t {
392394
size_t compute_size = t.compute_layout.size();
393395
int compute_arg_key = t.arg_key;
394396

397+
if (compute_arg_key == DNNL_ARG_UNDEF) {
398+
ir_assert(!t.needs_reorder);
399+
ir_assert(!t.needs_zero_out);
400+
ir_error_not_expected();
401+
continue;
402+
}
403+
395404
auto add_compute_arg = [&](kernel_info_t &ki, const expr_t &buf,
396405
bool is_input) {
397406
if (t.needs_reorder || src_zp_precalc)
@@ -400,13 +409,21 @@ class gen_convolution_t {
400409
else
401410
ki.register_user_arg(buf, compute_arg_key, is_input);
402411
};
403-
404-
if (compute_arg_key == -1) {
405-
ir_assert(!t.needs_reorder);
406-
ir_assert(!t.needs_zero_out);
407-
ir_error_not_expected();
408-
continue;
409-
}
412+
auto scratchpad_book = [&](int key) {
413+
pd->scratchpad_registry().registrar().book(
414+
gpu_utils::into<uint32_t>(key), compute_size, 1,
415+
ocl::OCL_BUFFER_ALIGNMENT);
416+
};
417+
auto create_zero_out_info = [&]() -> kernel_info_t & {
418+
auto &zero_out_info
419+
= create_kernel_info(pd, kernel_id_t::zero_out);
420+
auto size_var = var_t::make(type_t::u32(), "size");
421+
zero_out_info.register_internal_arg(
422+
size_var, gpu_utils::into<uint32_t>(compute_size));
423+
zero_out_info.set_nd_range(zero_out_kernel_t<>::nd_range(
424+
cfg.simd(), gpu_utils::into<int>(compute_size)));
425+
return zero_out_info;
426+
};
410427

411428
if (t.needs_reorder || src_zp_precalc) {
412429
int user_arg_key = compute_arg_key;
@@ -432,19 +449,9 @@ class gen_convolution_t {
432449
cfg.exec_cfg(), t.compute_layout, t.user_layout));
433450
}
434451
if (src_zp_precalc) {
435-
++scratchpad_key;
436-
scratchpad.book(uint32_t(scratchpad_key), compute_size, 1,
437-
ocl::OCL_BUFFER_ALIGNMENT);
438-
439-
auto &zero_out_info
440-
= create_kernel_info(pd, kernel_id_t::zero_out);
441-
zero_out_info.register_scratchpad_arg(compute_buf,
452+
scratchpad_book(++scratchpad_key);
453+
create_zero_out_info().register_scratchpad_arg(compute_buf,
442454
scratchpad_key, /*is_input=*/false, compute_size);
443-
auto size_var = var_t::make(type_t::u32(), "size");
444-
zero_out_info.register_internal_arg(
445-
size_var, uint32_t(compute_size));
446-
zero_out_info.set_nd_range(zero_out_kernel_t<>::nd_range(
447-
cfg.simd(), int(compute_size)));
448455

449456
zp_precalc_info.register_scratchpad_arg(compute_buf,
450457
scratchpad_key, /*is_input=*/true, compute_size);
@@ -462,18 +469,10 @@ class gen_convolution_t {
462469
/ std::min(KDW, prb.iw) / (prb.g * prb.ic);
463470
add_compute_arg(zp_precalc_info, make_buffer("dst"), false);
464471
}
465-
scratchpad.book(uint32_t(compute_arg_key), compute_size, 1,
466-
ocl::OCL_BUFFER_ALIGNMENT);
472+
scratchpad_book(compute_arg_key);
467473
}
468474
if (t.needs_zero_out) {
469-
auto &zero_out_info
470-
= create_kernel_info(pd, kernel_id_t::zero_out);
471-
add_compute_arg(zero_out_info, compute_buf, false);
472-
auto size_var = var_t::make(type_t::u32(), "size");
473-
zero_out_info.register_internal_arg(
474-
size_var, uint32_t(compute_size));
475-
zero_out_info.set_nd_range(zero_out_kernel_t<>::nd_range(
476-
cfg.simd(), int(compute_size)));
475+
add_compute_arg(create_zero_out_info(), compute_buf, false);
477476
}
478477
add_compute_arg(conv_info, compute_buf, t.is_input && !t.is_output);
479478
}

src/gpu/jit/conv/problem.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class conv_arg_helper_t {
208208
if (prb_.is_bwd_d) return DNNL_ARG_DIFF_SRC;
209209
if (prb_.is_bwd_w) return DNNL_ARG_SRC;
210210
ir_error_not_expected();
211-
return -1;
211+
return DNNL_ARG_UNDEF;
212212
}
213213

214214
bool is_src_input() const { return prb_.is_fwd || prb_.is_bwd_w; }
@@ -219,7 +219,7 @@ class conv_arg_helper_t {
219219
if (prb_.is_bwd_d) return DNNL_ARG_WEIGHTS;
220220
if (prb_.is_bwd_w) return DNNL_ARG_DIFF_WEIGHTS;
221221
ir_error_not_expected();
222-
return -1;
222+
return DNNL_ARG_UNDEF;
223223
}
224224

225225
bool is_wei_input() const { return prb_.is_fwd || prb_.is_bwd_d; }
@@ -230,7 +230,7 @@ class conv_arg_helper_t {
230230
if (prb_.is_bwd_d) return DNNL_ARG_BIAS;
231231
if (prb_.is_bwd_w) return DNNL_ARG_DIFF_BIAS;
232232
ir_error_not_expected();
233-
return -1;
233+
return DNNL_ARG_UNDEF;
234234
}
235235

236236
bool is_bia_input() const { return prb_.is_fwd || prb_.is_bwd_d; }
@@ -241,7 +241,7 @@ class conv_arg_helper_t {
241241
if (prb_.is_bwd_d) return DNNL_ARG_DIFF_DST;
242242
if (prb_.is_bwd_w) return DNNL_ARG_DIFF_DST;
243243
ir_error_not_expected();
244-
return -1;
244+
return DNNL_ARG_UNDEF;
245245
}
246246

247247
bool is_dst_input() const { return prb_.is_bwd_d || prb_.is_bwd_w; }

src/gpu/jit/conv/zero_out.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ class zero_out_kernel_t : public ir_kernel_t<hw> {
5050
// XXX: Stateful messages don't work on XeHPC.
5151
bool use_a64 = (hw >= ngen::HW::XeHPC);
5252

53-
auto ptr = getArgument(arg_names[0]);
54-
auto surf = Surface(getArgumentSurfaceIfExists(arg_names[0]));
55-
auto size = getArgument(arg_names[1]);
53+
auto size = getArgument(arg_names[0]);
54+
auto ptr = getArgument(arg_names[1]);
55+
auto surf = Surface(getArgumentSurfaceIfExists(arg_names[1]));
5656
auto global_id = ra_.template alloc_sub<uint32_t>();
5757
auto off0 = ra_.template alloc_sub<uint32_t>();
5858

src/gpu/jit/ir/kernel_info.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ class kernel_info_t {
153153

154154
void register_internal_arg(
155155
const expr_t &var, const expr_t &value = expr_t()) {
156-
register_arg(var, arg_kind_t::internal, -1, /*is_input=*/true);
156+
register_arg(
157+
var, arg_kind_t::internal, DNNL_ARG_UNDEF, /*is_input=*/true);
157158
set_internal_arg(var.as<var_t>().name, value);
158159
}
159160

src/gpu/jit/pooling/gen_pooling.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,13 @@ status_t gen_pooling_fwd_t::init(engine_t *engine) {
138138
ir_assert(!t.needs_reorder);
139139
ir_assert(!t.needs_zero_out);
140140

141-
int user_arg_key = t.arg_key;
142-
auto user_buf = make_buffer(t.name);
143-
144-
if (user_arg_key == -1) {
141+
if (t.arg_key == DNNL_ARG_UNDEF) {
145142
ir_assert(!t.needs_reorder);
146143
ir_assert(!t.needs_zero_out);
147144
ir_error_not_expected();
148145
continue;
149146
}
150-
151-
kernel_info_.register_user_arg(user_buf, user_arg_key,
147+
kernel_info_.register_user_arg(make_buffer(t.name), t.arg_key,
152148
/*is_input=*/t.is_input && !t.is_output);
153149
}
154150

src/gpu/jit/reorder/gen_reorder.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,13 @@ status_t gen_reorder_t::pd_t::init_kernel_info() {
185185
ir_assert(!t.needs_reorder);
186186
ir_assert(!t.needs_zero_out);
187187

188-
int user_arg_key = t.arg_key;
189-
auto user_buf = make_buffer(t.name);
190-
191-
if (user_arg_key == -1) {
188+
if (t.arg_key == DNNL_ARG_UNDEF) {
192189
ir_assert(!t.needs_reorder);
193190
ir_assert(!t.needs_zero_out);
194191
ir_error_not_expected();
195192
continue;
196193
}
197-
198-
kernel_info->register_user_arg(user_buf, user_arg_key,
194+
kernel_info->register_user_arg(make_buffer(t.name), t.arg_key,
199195
/*is_input=*/t.is_input && !t.is_output);
200196
}
201197
return status::success;

src/gpu/jit/v2/conv/problem.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ class arg_helper_t {
199199
if (is_bwd_d()) return DNNL_ARG_DIFF_SRC;
200200
if (is_bwd_w()) return DNNL_ARG_SRC;
201201
ir_error_not_expected();
202-
return -1;
202+
return DNNL_ARG_UNDEF;
203203
}
204204

205205
bool is_src_input() const { return is_fwd() || is_bwd_w(); }
@@ -210,7 +210,7 @@ class arg_helper_t {
210210
if (is_bwd_d()) return DNNL_ARG_WEIGHTS;
211211
if (is_bwd_w()) return DNNL_ARG_DIFF_WEIGHTS;
212212
ir_error_not_expected();
213-
return -1;
213+
return DNNL_ARG_UNDEF;
214214
}
215215

216216
bool is_wei_input() const { return is_fwd() || is_bwd_d(); }
@@ -221,7 +221,7 @@ class arg_helper_t {
221221
if (is_bwd_d()) return DNNL_ARG_BIAS;
222222
if (is_bwd_w()) return DNNL_ARG_DIFF_BIAS;
223223
ir_error_not_expected();
224-
return -1;
224+
return DNNL_ARG_UNDEF;
225225
}
226226

227227
bool is_bia_input() const { return is_fwd() || is_bwd_d(); }
@@ -232,7 +232,7 @@ class arg_helper_t {
232232
if (is_bwd_d()) return DNNL_ARG_DIFF_DST;
233233
if (is_bwd_w()) return DNNL_ARG_DIFF_DST;
234234
ir_error_not_expected();
235-
return -1;
235+
return DNNL_ARG_UNDEF;
236236
}
237237

238238
bool is_dst_input() const { return is_bwd_d() || is_bwd_w(); }

0 commit comments

Comments
 (0)