@@ -1977,12 +1977,13 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1977
1977
// Decides whether this reorder specialization can handle the given pair of
// memory descriptors and attributes.
//
// Returns true only when all of the following hold:
//   - the input has no runtime dims or strides,
//   - the input is dense,
//   - the output is dense,
//   - simple_attr_check(attr, false, true) accepts the attributes.
//
// The removed ('-') lines additionally required the output's last-dimension
// stride to be 1; the added ('+') line drops that requirement.
// NOTE(review): the meaning of the two boolean arguments passed to
// simple_attr_check is not visible in this chunk - confirm at its definition.
static bool is_applicable (const memory_desc_wrapper &input_d,
1978
1978
const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
1979
1979
return !input_d.has_runtime_dims_or_strides () && input_d.is_dense ()
1980
- && output_d.is_dense ()
1981
- && output_d.strides ()[output_d.ndims () - 1 ] == 1
1982
- && simple_attr_check (attr, false , true );
1980
+ && output_d.is_dense () && simple_attr_check (attr, false , true );
1983
1981
}
1984
1982
1985
- GET_SCRATCHPAD_SIZE_ZERO ();
1983
// Reports how much scratchpad this implementation asks for: the value of
// input_d.size(). It replaces the removed GET_SCRATCHPAD_SIZE_ZERO() macro.
//
// output_d is accepted but not read here.
// NOTE(review): the size is requested unconditionally, while the execute()
// code visible in this diff only writes to the workspace when the output's
// innermost stride is not 1 - confirm whether a conditional size is wanted.
// NOTE(review): presumably input_d.size() is a byte count large enough for
// nelems() elements of the input data type - verify against
// memory_desc_wrapper.
+ static size_t get_scratchpad_size (const memory_desc_wrapper &input_d,
1984
+ const memory_desc_wrapper &output_d) {
1985
+ return input_d.size ();
1986
+ }
1986
1987
1987
1988
static status_t execute (const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
1988
1989
DECLARE_COMMON_PARAMS ();
@@ -1991,6 +1992,31 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
1991
1992
input += input_d.blk_off (0 );
1992
1993
output += output_d.blk_off (0 );
1993
1994
1995
+ data_t <type_i> *wspace = scratchpad.template get <data_t <type_i>>(
1996
+ memory_tracking::names::key_reorder_space);
1997
+
1998
+ // When formats of the input and the output are not identical, the idea
1999
+ // is to reorder the data from the input format to the output format
2000
+ // but within the same data type, and after the format reorder apply
2001
+ // the compression into int4 as on `abx` format.
2002
+ const bool need_transform
2003
+ = output_d.strides ()[output_d.ndims () - 1 ] != 1 ;
2004
+ if (need_transform) {
2005
+ const dim_t work_amount = input_d.nelems ();
2006
+ parallel (0 , [&](const int ithr, const int nthr) {
2007
+ dim_t start {0 }, end {0 };
2008
+ balance211 (work_amount, nthr, ithr, start, end);
2009
+ PRAGMA_OMP_SIMD ()
2010
+ for (dim_t idx = start; idx < end; idx++) {
2011
+ const auto i_off = input_d.off_l (idx);
2012
+ const auto o_off = output_d.off_l (idx);
2013
+ wspace[o_off] = input[i_off];
2014
+ }
2015
+ });
2016
+
2017
+ input = wspace;
2018
+ }
2019
+
1994
2020
// To avoid clashes between threads each byte (or 2 elements)
1995
2021
// is handled by a single thread
1996
2022
const dim_t work_amount = input_d.nelems () / 2 ;
@@ -2001,8 +2027,10 @@ struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
2001
2027
PRAGMA_OMP_SIMD ()
2002
2028
for_ (dim_t idx = start; idx < end; idx++)
2003
2029
for (int i = 0 ; i < 2 ; ++i) {
2004
- const auto i_off = input_d.off_l (2 * idx + i);
2005
- const auto o_off = output_d.off_l (2 * idx + i);
2030
+ const auto i_off = need_transform ? 2 * idx + i
2031
+ : input_d.off_l (2 * idx + i);
2032
+ const auto o_off = need_transform ? 2 * idx + i
2033
+ : output_d.off_l (2 * idx + i);
2006
2034
const auto shift = i % 2 ? int4_extract_t ::high_half
2007
2035
: int4_extract_t ::low_half;
2008
2036
auto src_val = _qz_a1b0<data_type::f32, type_o>()(input[i_off]);
0 commit comments