@@ -376,14 +376,16 @@ class gen_convolution_t {
376
376
auto &cfg = data.pd_cfg ;
377
377
const bool needs_zp_precalc = cfg.zp_cfg ().needs_src_precalc ;
378
378
379
- auto scratchpad = pd->scratchpad_registry ().registrar ();
380
379
auto &conv_info = create_kernel_info (pd, kernel_id_t ::convolution);
381
380
auto &zp_precalc_info = (needs_zp_precalc)
382
381
? create_kernel_info (pd, kernel_id_t ::zp_precalc)
383
382
: conv_info;
384
383
384
+ static_assert (DNNL_ARG_UNDEF == memory_tracking::names::key_none,
385
+ " Undefined argument and empty scratchpad key are out of sync!" );
386
+
385
387
// Initialize kernel arguments.
386
- int scratchpad_key = 0 ;
388
+ int scratchpad_key = memory_tracking::names::key_none ;
387
389
for (auto &t : data.tensor_cfg .tensors ()) {
388
390
const bool src_zp_precalc
389
391
= needs_zp_precalc && (t.name == " src_zero_points" );
@@ -392,6 +394,13 @@ class gen_convolution_t {
392
394
size_t compute_size = t.compute_layout .size ();
393
395
int compute_arg_key = t.arg_key ;
394
396
397
+ if (compute_arg_key == DNNL_ARG_UNDEF) {
398
+ ir_assert (!t.needs_reorder );
399
+ ir_assert (!t.needs_zero_out );
400
+ ir_error_not_expected ();
401
+ continue ;
402
+ }
403
+
395
404
auto add_compute_arg = [&](kernel_info_t &ki, const expr_t &buf,
396
405
bool is_input) {
397
406
if (t.needs_reorder || src_zp_precalc)
@@ -400,13 +409,21 @@ class gen_convolution_t {
400
409
else
401
410
ki.register_user_arg (buf, compute_arg_key, is_input);
402
411
};
403
-
404
- if (compute_arg_key == -1 ) {
405
- ir_assert (!t.needs_reorder );
406
- ir_assert (!t.needs_zero_out );
407
- ir_error_not_expected ();
408
- continue ;
409
- }
412
// Local helper: books one scratchpad entry for the tensor currently being
// processed. `key` is converted to uint32_t through gpu_utils::into, the
// entry size is the captured compute_size, and ocl::OCL_BUFFER_ALIGNMENT is
// passed as the last argument. A registrar is fetched from pd on every call
// instead of being cached in a local variable.
// NOTE(review): the literal 1 is presumably the per-element data size and
// gpu_utils::into is presumably a checked narrowing cast - confirm both.
+ auto scratchpad_book = [&](int key) {
413
+ pd->scratchpad_registry ().registrar ().book (
414
+ gpu_utils::into<uint32_t >(key), compute_size, 1 ,
415
+ ocl::OCL_BUFFER_ALIGNMENT);
416
+ };
417
// Local helper: creates the kernel info for kernel_id_t::zero_out and applies
// the setup shared by both zero-out call sites: an internal u32 size argument
// holding compute_size, and an nd-range computed by
// zero_out_kernel_t<>::nd_range from cfg.simd() and compute_size.
// Returns the new kernel info by reference; the caller is still responsible
// for registering the buffer argument that is to be zeroed.
// NOTE(review): compute_size is a size_t converted to uint32_t / int via
// gpu_utils::into - presumably range-checked; confirm behavior on overflow.
+ auto create_zero_out_info = [&]() -> kernel_info_t & {
418
+ auto &zero_out_info
419
+ = create_kernel_info (pd, kernel_id_t ::zero_out);
420
+ auto size_var = var_t::make (type_t::u32 (), " size" );
421
+ zero_out_info.register_internal_arg (
422
+ size_var, gpu_utils::into<uint32_t >(compute_size));
423
+ zero_out_info.set_nd_range (zero_out_kernel_t <>::nd_range (
424
+ cfg.simd (), gpu_utils::into<int >(compute_size)));
425
+ return zero_out_info;
426
+ };
410
427
411
428
if (t.needs_reorder || src_zp_precalc) {
412
429
int user_arg_key = compute_arg_key;
@@ -432,19 +449,9 @@ class gen_convolution_t {
432
449
cfg.exec_cfg (), t.compute_layout , t.user_layout ));
433
450
}
434
451
if (src_zp_precalc) {
435
- ++scratchpad_key;
436
- scratchpad.book (uint32_t (scratchpad_key), compute_size, 1 ,
437
- ocl::OCL_BUFFER_ALIGNMENT);
438
-
439
- auto &zero_out_info
440
- = create_kernel_info (pd, kernel_id_t ::zero_out);
441
- zero_out_info.register_scratchpad_arg (compute_buf,
452
+ scratchpad_book (++scratchpad_key);
453
+ create_zero_out_info ().register_scratchpad_arg (compute_buf,
442
454
scratchpad_key, /* is_input=*/ false , compute_size);
443
- auto size_var = var_t::make (type_t::u32 (), " size" );
444
- zero_out_info.register_internal_arg (
445
- size_var, uint32_t (compute_size));
446
- zero_out_info.set_nd_range (zero_out_kernel_t <>::nd_range (
447
- cfg.simd (), int (compute_size)));
448
455
449
456
zp_precalc_info.register_scratchpad_arg (compute_buf,
450
457
scratchpad_key, /* is_input=*/ true , compute_size);
@@ -462,18 +469,10 @@ class gen_convolution_t {
462
469
/ std::min (KDW, prb.iw ) / (prb.g * prb.ic );
463
470
add_compute_arg (zp_precalc_info, make_buffer (" dst" ), false );
464
471
}
465
- scratchpad.book (uint32_t (compute_arg_key), compute_size, 1 ,
466
- ocl::OCL_BUFFER_ALIGNMENT);
472
+ scratchpad_book (compute_arg_key);
467
473
}
468
474
if (t.needs_zero_out ) {
469
- auto &zero_out_info
470
- = create_kernel_info (pd, kernel_id_t ::zero_out);
471
- add_compute_arg (zero_out_info, compute_buf, false );
472
- auto size_var = var_t::make (type_t::u32 (), " size" );
473
- zero_out_info.register_internal_arg (
474
- size_var, uint32_t (compute_size));
475
- zero_out_info.set_nd_range (zero_out_kernel_t <>::nd_range (
476
- cfg.simd (), int (compute_size)));
475
+ add_compute_arg (create_zero_out_info (), compute_buf, false );
477
476
}
478
477
add_compute_arg (conv_info, compute_buf, t.is_input && !t.is_output );
479
478
}
0 commit comments