Skip to content

Commit 2ead5d4

Browse files
committed
[FORK][FIX] changed comp_tile_len data type from int16_t to int
[FORK][FEATURE] cpu: add inner product with sparse packed weights
1 parent ff9205a commit 2ead5d4

4 files changed

+14
-13
lines changed

src/common/memory_desc_wrapper.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ struct memory_desc_wrapper : public c_compatible {
235235
// assert(matches_tag(format_tag::OI16i64o4i)); - TODO: enable for sparse packed.
236236
const size_t metadata = padded_dims()[0] * padded_dims()[1] / 64
237237
* sizeof(uint64_t);
238-
size_t comp_tile_data_size = ceil(static_cast<float>(padded_dims()[0] * padded_dims()[1]) / (64 * 64 * 32)) * 64;
238+
using comp_tile_len_type = int;
239+
size_t comp_tile_data_size = ceil(static_cast<float>(padded_dims()[0] * padded_dims()[1])
240+
/ (64 * 64 * (64 / sizeof(comp_tile_len_type)))) * 64;
239241
return comp_tile_data_size + (padded_dims()[0] * padded_dims()[1] * data_type_size())
240242
+ metadata + 1000;
241243
// todo: [av] why 1000?

src/cpu/reorder/simple_sparse_reorder.hpp

+6-9
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,14 @@ struct simple_sparse_reorder_impl<SIMPLE_SPARSE_REORDER_TEMPL_CALL,
134134
size_t offset = padded_dims[0] * padded_dims[1];
135135

136136
int total_blocks = offset / 4096;
137-
int16_t *comp_tile_len_ptr = reinterpret_cast<int16_t *>(output);
137+
using comp_tile_len_type = int;
138+
comp_tile_len_type *comp_tile_len_ptr = reinterpret_cast<comp_tile_len_type *>(output);
138139
int comp_tile_len_index = 0;
139140
int cl_length = 0;
140-
// TODO: why 2 / 64?
141141
// Wasting memory space due to allocating a buffer for the whole tensor?
142-
int output_offset = ceil((float)total_blocks * 2 / 64.0);
143-
144-
size_t offset_2 = static_cast<size_t>(ceil((float)total_blocks * 2 / 64.0)) * 64;
145-
uint64_t *bitmask_ptr = reinterpret_cast<uint64_t *>(output + offset + offset_2);
146-
147-
auto outp = &output[output_d.blk_off(0, 0, 0, 0) + output_offset * 64];
142+
int output_offset = ceil((float)total_blocks * sizeof(comp_tile_len_type) / 64.0) * 64;
143+
uint64_t *bitmask_ptr = reinterpret_cast<uint64_t *>(output + output_offset + offset);
144+
auto outp = &output[output_d.blk_off(0, 0, 0, 0) + output_offset];
148145

149146
// TODO: add threading.
150147
for (int O = 0; O < NB_OC; O++) {
@@ -184,7 +181,7 @@ struct simple_sparse_reorder_impl<SIMPLE_SPARSE_REORDER_TEMPL_CALL,
184181
if (count % 64 == 0) { bitmask_idx++; }
185182
}
186183
}
187-
int16_t cl = (int16_t)ceil(non_zeros / 64.0);
184+
comp_tile_len_type cl = (comp_tile_len_type)ceil(non_zeros / 64.0);
188185
comp_tile_len_index++;
189186
cl_length = comp_tile_len_ptr[comp_tile_len_index - 1] + cl;
190187
int unsed_bytes_in_cl = 64 - (non_zeros % 64);

src/cpu/x64/jit_brgemm_inner_product.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,9 @@ status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
292292
const dim_t wei_offset = (wei_cur_ocb
293293
+ wei_ic_stride * (icb + b * ic_blocks_per_batch)) / typesize_scale;
294294
if (jbgp.weights_compressed) {
295-
const int16_t *compressed_tile_lengths_ptr
296-
= reinterpret_cast<const int16_t *>(weights);
295+
using comp_tile_len_type = int;
296+
const comp_tile_len_type *compressed_tile_lengths_ptr
297+
= reinterpret_cast<const comp_tile_len_type *>(weights);
297298
int compressed_weights_offset = wei_offset / 4096;
298299

299300
auto dcomp_params = brgemm_decomp_kernel_params_t();

src/cpu/x64/jit_brgemm_inner_product_utils.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -1382,8 +1382,9 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa,
13821382
if (jbgp.weights_compressed) {
13831383
jbgp.weights_compressed = true;
13841384
int total_blocks = (jbgp.oc * jbgp.ic) / 4096;
1385+
using comp_tile_len_type = int;
13851386
jbgp.weights_starting_offset
1386-
= ceil((float)total_blocks * 2 / 64.0) * 64;
1387+
= ceil((float)total_blocks * sizeof(comp_tile_len_type) / 64.0) * 64;
13871388
jbgp.weight_comp_bitmask_off = jbgp.weights_starting_offset + jbgp.ic * jbgp.oc;
13881389
}
13891390
} else if (is_bf16) {

0 commit comments

Comments
 (0)