Skip to content

Commit 5e0a8a9

Browse files
kawakami-kmgouicem
authored and committed
aarch64: shuffle: fix segv for bf16 cases
1 parent e41e332 commit 5e0a8a9

File tree

2 files changed

+36
-20
lines changed

2 files changed

+36
-20
lines changed

src/cpu/aarch64/shuffle/jit_uni_shuffle.cpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
2-
* Copyright 2020-2022 Intel Corporation
3-
* Copyright 2022 FUJITSU LIMITED
2+
* Copyright 2020-2024 Intel Corporation
3+
* Copyright 2022-2024 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ template <cpu_isa_t isa>
3434
status_t jit_uni_shuffle_t<isa>::pd_t::init(engine_t *engine) {
3535
using namespace format_tag;
3636
using namespace data_type;
37+
using namespace types;
3738

3839
const memory_desc_wrapper src_d(is_fwd() ? src_md() : diff_src_md());
3940
const memory_desc_wrapper dst_d(is_fwd() ? dst_md() : diff_dst_md());
@@ -58,7 +59,10 @@ status_t jit_uni_shuffle_t<isa>::pd_t::init(engine_t *engine) {
5859
if (blocked_format == format_tag::undef) return status::unimplemented;
5960

6061
conf_.blk_size = src_d.blocking_desc().strides[ndims() - 1];
61-
conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
62+
/* Because "LD1H { <Zt>.S }, <Pg>/Z, [<Xn|SP>, <Zm>.S, UXTW #1]" is used
63+
to gather data for bf16, simd_w must be calculated
64+
with sizeof(uint32_t). */
65+
conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(uint32_t);
6266

6367
const bool has_spatial = utils::one_of(ndims(), 3, 4, 5);
6468
const dim_t HW = H() * W();

src/cpu/aarch64/shuffle/jit_uni_shuffle_kernel.cpp

+29-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
2-
* Copyright 2021-2022 Intel Corporation
3-
* Copyright 2022 FUJITSU LIMITED
2+
* Copyright 2021-2024 Intel Corporation
3+
* Copyright 2022-2024 FUJITSU LIMITED
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -47,9 +47,12 @@ jit_uni_shuffle_kernel_t<isa>::jit_uni_shuffle_kernel_t(
4747
template <cpu_isa_t isa>
4848
void jit_uni_shuffle_kernel_t<isa>::prepare_mask() {
4949
using namespace data_type;
50+
using namespace types;
5051
if (conf_.simd_tail > 0) {
51-
assert(utils::one_of(conf_.data_type, f32, s32));
52-
assert(conf_.simd_tail < isa_sveLen / sizeof(float));
52+
/* Because "LD1H { <Zt>.S }, <Pg>/Z, [<Xn|SP>, <Zm>.S, UXTW #1]" is used
53+
to gather data for bf16, simd_tail must be evaluated
54+
with sizeof(uint32_t). */
55+
assert(conf_.simd_tail < isa_sveLen / sizeof(uint32_t));
5356
index(vmm_tmp_.s, 0, 1);
5457
cmplt(k_tail_mask_.s, P_ALL_ONE / T_z, vmm_tmp_.s, conf_.simd_tail);
5558
}
@@ -68,13 +71,17 @@ void jit_uni_shuffle_kernel_t<asimd>::prepare_mask() {}
6871
template <cpu_isa_t isa>
6972
void jit_uni_shuffle_kernel_t<isa>::gather_data(const XReg &reg_src_addr,
7073
const int indices_idx, const int data_idx, const bool is_tail) {
71-
if (conf_.dt_size == sizeof(float)) {
72-
const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;
74+
using namespace data_type;
75+
const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_;
76+
77+
if (utils::one_of(conf_.data_type, f32, s32)) {
7378
lsr(TRegS(indices_idx), TRegS(indices_idx), 2);
7479
ld1w(TRegS(data_idx), mask / T_z,
7580
ptr(reg_src_addr, TRegS(indices_idx), UXTW, 2));
76-
} else {
77-
assert(!"unsupported emu_gather_data");
81+
} else if (conf_.data_type == bf16) {
82+
lsr(TRegS(indices_idx), TRegS(indices_idx), 1);
83+
ld1h(TRegS(data_idx), mask / T_z,
84+
ptr(reg_src_addr, TRegS(indices_idx), UXTW, 1));
7885
}
7986
}
8087

@@ -97,21 +104,26 @@ void jit_uni_shuffle_kernel_t<asimd>::gather_data(const XReg &addr,
97104
template <cpu_isa_t isa>
98105
void jit_uni_shuffle_kernel_t<isa>::store_data(const int data_idx,
99106
const XReg &reg_dst_addr, const int offset, const bool is_tail) {
107+
using namespace data_type;
100108
const auto extend_for_padding
101109
= is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w;
110+
const PReg &mask = is_tail ? k_tail_mask_ : P_ALL_ONE;
111+
112+
add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
113+
102114
if (extend_for_padding) {
103115
sel(vmm_tmp_.s, k_tail_mask_, TRegS(data_idx), vmm_zero_.s);
104-
add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
105-
st1w(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
116+
if (utils::one_of(conf_.data_type, f32, s32))
117+
st1w(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
118+
else // bf16
119+
st1h(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
106120
} else {
107-
if (is_tail) {
108-
add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
109-
st1w(TRegS(data_idx), k_tail_mask_, ptr(X_DEFAULT_ADDR));
110-
} else {
111-
add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0);
112-
st1w(TRegS(data_idx), P_ALL_ONE, ptr(X_DEFAULT_ADDR));
113-
}
121+
if (utils::one_of(conf_.data_type, f32, s32))
122+
st1w(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR));
123+
else // bf16
124+
st1h(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR));
114125
}
126+
115127
append_zero_padding(
116128
reg_dst_, isa_sveLen > 128 ? extend_for_padding : false);
117129
}

0 commit comments

Comments
 (0)