Skip to content

Commit 912851f

Browse files
dzarukinpetercad
authored andcommitted
cpu: reorder: remove f32 -> int4 format limitation
1 parent 57e9f8d commit 912851f

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

src/cpu/reorder/simple_reorder.hpp

+34-6
Original file line numberDiff line numberDiff line change
@@ -1977,12 +1977,13 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
19771977
static bool is_applicable(const memory_desc_wrapper &input_d,
19781978
const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
19791979
return !input_d.has_runtime_dims_or_strides() && input_d.is_dense()
1980-
&& output_d.is_dense()
1981-
&& output_d.strides()[output_d.ndims() - 1] == 1
1982-
&& simple_attr_check(attr, false, true);
1980+
&& output_d.is_dense() && simple_attr_check(attr, false, true);
19831981
}
19841982

1985-
GET_SCRATCHPAD_SIZE_ZERO();
1983+
static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
1984+
const memory_desc_wrapper &output_d) {
1985+
return input_d.size();
1986+
}
19861987

19871988
static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
19881989
DECLARE_COMMON_PARAMS();
@@ -1991,6 +1992,31 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
19911992
input += input_d.blk_off(0);
19921993
output += output_d.blk_off(0);
19931994

1995+
data_t<type_i> *wspace = scratchpad.template get<data_t<type_i>>(
1996+
memory_tracking::names::key_reorder_space);
1997+
1998+
// When formats of the input and the output are not identical, the idea
1999+
// is to reorder the data from the input format to the output format
2000+
// but within the same data type, and after the format reorder apply
2001+
// the compression into int4 as on `abx` format.
2002+
const bool need_transform
2003+
= output_d.strides()[output_d.ndims() - 1] != 1;
2004+
if (need_transform) {
2005+
const dim_t work_amount = input_d.nelems();
2006+
parallel(0, [&](const int ithr, const int nthr) {
2007+
dim_t start {0}, end {0};
2008+
balance211(work_amount, nthr, ithr, start, end);
2009+
PRAGMA_OMP_SIMD()
2010+
for (dim_t idx = start; idx < end; idx++) {
2011+
const auto i_off = input_d.off_l(idx);
2012+
const auto o_off = output_d.off_l(idx);
2013+
wspace[o_off] = input[i_off];
2014+
}
2015+
});
2016+
2017+
input = wspace;
2018+
}
2019+
19942020
// To avoid clashes between threads each byte (or 2 elements)
19952021
// is handled by a single thread
19962022
const dim_t work_amount = input_d.nelems() / 2;
@@ -2001,8 +2027,10 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
20012027
PRAGMA_OMP_SIMD()
20022028
for_(dim_t idx = start; idx < end; idx++)
20032029
for (int i = 0; i < 2; ++i) {
2004-
const auto i_off = input_d.off_l(2 * idx + i);
2005-
const auto o_off = output_d.off_l(2 * idx + i);
2030+
const auto i_off = need_transform ? 2 * idx + i
2031+
: input_d.off_l(2 * idx + i);
2032+
const auto o_off = need_transform ? 2 * idx + i
2033+
: output_d.off_l(2 * idx + i);
20062034
const auto shift = i % 2 ? int4_extract_t::high_half
20072035
: int4_extract_t::low_half;
20082036
auto src_val = _qz_a1b0<data_type::f32, type_o>()(input[i_off]);

0 commit comments

Comments
 (0)