Skip to content

Commit ff9205a

Browse files
author
dmitrygo
committed
[FORK][FEATURE] InnerProduct primitive: 4bit weights decompression support
1 parent 36c2060 commit ff9205a

30 files changed

+1051
-121
lines changed

include/oneapi/dnnl/dnnl.h

+4
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
334334
/// otherwise.
335335
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
336336
dnnl_primitive_attr_t attr, int arg, int mask);
337+
/// Sets explicit dimensions for the scaling factors of a given memory
/// argument.
///
/// @param attr Primitive attributes.
/// @param arg Memory argument the scaling factors belong to.
/// @param dims Array holding @p ndims dimension values.
/// @param ndims Number of entries in @p dims.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// NOTE(review): the meaning of the individual entries of @p dims is not
/// visible from this declaration -- confirm against the implementation.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_dims(
338+
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims);
337339

338340
/// Sets primitive attributes zero points for primitive operations for a given
339341
/// memory argument. The zero points must be passed at execution time
@@ -354,6 +356,8 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
354356
/// otherwise.
355357
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
356358
dnnl_primitive_attr_t attr, int arg, int mask);
359+
/// Sets explicit dimensions for the zero points of a given memory argument.
///
/// @param attr Primitive attributes.
/// @param arg Memory argument the zero points belong to.
/// @param dims Array holding @p ndims dimension values.
/// @param ndims Number of entries in @p dims.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// NOTE(review): the meaning of the individual entries of @p dims is not
/// visible from this declaration -- confirm against the implementation.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_dims(
360+
dnnl_primitive_attr_t attr, int arg, const dnnl_dims_t dims, int ndims);
357361

358362
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_compensations(
359363
dnnl_primitive_attr_t attr, int count, int mask);

include/oneapi/dnnl/dnnl.hpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,13 @@ struct memory : public handle<dnnl_memory_t> {
869869
/// 8-bit unsigned integer.
870870
u8 = dnnl_u8,
871871
/// 1-bit integer
872-
bin = dnnl_bin
872+
bin = dnnl_bin,
873+
/// 4-bit normalized float.
874+
nf4 = dnnl_nf4,
875+
/// 4-bit signed integer.
876+
s4 = dnnl_s4,
877+
/// 4-bit unsigned integer.
878+
u4 = dnnl_u4
873879
};
874880

875881
/// Returns size of data type in bytes.
@@ -3874,6 +3880,10 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
38743880
error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
38753881
"could not set scales primitive attribute");
38763882
}
3883+
/// Sets explicit scaling-factor dimensions for a given memory argument.
///
/// Forwards to dnnl_primitive_attr_set_scales_dims() and hands the returned
/// status to error::wrap_c_api() together with the message below.
///
/// @param arg Memory argument the scaling factors belong to.
/// @param dims Dimensions to record for the scaling factors of @p arg.
void set_scales_dims(int arg, const memory::dims &dims) {
    const auto status = dnnl_primitive_attr_set_scales_dims(
            get(), arg, dims.data(), dims.size());
    error::wrap_c_api(status, "could not set scales primitive attribute");
}
38773887

38783888
/// Sets zero points for primitive operations for a given memory argument.
38793889
/// The zero points must be passed at execution time as an argument with
@@ -3893,6 +3903,11 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
38933903
dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask),
38943904
"could not set zero points primitive attribute");
38953905
}
3906+
/// Sets explicit zero-point dimensions for a given memory argument.
///
/// Forwards to dnnl_primitive_attr_set_zero_points_dims() and hands the
/// returned status to error::wrap_c_api() together with the message below.
///
/// @param arg Memory argument the zero points belong to.
/// @param dims Dimensions to record for the zero points of @p arg.
void set_zero_points_dims(int arg, const memory::dims &dims) {
    const auto status = dnnl_primitive_attr_set_zero_points_dims(
            get(), arg, dims.data(), dims.size());
    error::wrap_c_api(
            status, "could not set zero points primitive attribute");
}
38963911

38973912
void set_output_compensations(dnnl_dim_t count, int mask)
38983913
{

include/oneapi/dnnl/dnnl_common_types.h

+6
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ typedef enum {
9494
dnnl_boolean = 8,
9595
/// 1-bit integer.
9696
dnnl_bin = 9,
97+
/// 4-bit normalized float.
98+
dnnl_nf4 = 10,
99+
/// 4-bit signed integer.
100+
dnnl_s4 = 11,
101+
/// 4-bit unsigned integer.
102+
dnnl_u4 = 12,
97103
/// Parameter to allow internal only data_types without undefined behavior.
98104
/// This parameter is chosen to be valid for so long as sizeof(int) >= 2.
99105
dnnl_data_type_max = 0x7fff,

src/common/c_types_map.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ const data_type_t boolean = dnnl_boolean;
168168
const data_type_t tf32 = static_cast<data_type_t>(1 << 8);
169169

170170
const data_type_t bin = dnnl_bin;
171+
const data_type_t nf4 = dnnl_nf4;
172+
const data_type_t s4 = dnnl_s4;
173+
const data_type_t u4 = dnnl_u4;
171174
} // namespace data_type
172175

173176
using fpmath_mode_t = dnnl_fpmath_mode_t;

src/common/dnnl_debug_autogenerated.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ const char *dnnl_dt2str(dnnl_data_type_t v) {
5151
if (v == dnnl_f64) return "f64";
5252
if (v == dnnl_boolean) return "boolean";
5353
if (v == dnnl_bin) return "bin";
54+
if (v == dnnl_nf4) return "nf4";
55+
if (v == dnnl_s4) return "s4";
56+
if (v == dnnl_u4) return "u4";
5457
if (v == dnnl_data_type_max) return "data_type_max";
5558
assert(!"unknown dt");
5659
return "unknown dt";

src/common/dnnl_traits.hpp

+12
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ template <> struct prec_traits<data_type::bin> {
7979
typedef uint8_t type;
8080
};
8181

82+
// nf4 values are carried in uint8_t storage.
template <>
struct prec_traits<data_type::nf4> {
    using type = uint8_t;
};
85+
86+
// s4 values are carried in uint8_t storage.
template <>
struct prec_traits<data_type::s4> {
    using type = uint8_t;
};
89+
90+
// u4 values are carried in uint8_t storage.
template <>
struct prec_traits<data_type::u4> {
    using type = uint8_t;
};
93+
8294
template <>
8395
struct data_traits<float16_t> {
8496
static constexpr data_type_t data_type = data_type::f16;

src/common/memory_zero_pad.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
289289
case s8: return typed_zero_pad<s8>(memory, ctx);
290290
case u8: return typed_zero_pad<u8>(memory, ctx);
291291
case bin: return typed_zero_pad<u8>(memory, ctx);
292+
case nf4: return typed_zero_pad<u8>(memory, ctx);
293+
case s4: return typed_zero_pad<u8>(memory, ctx);
294+
case u4: return typed_zero_pad<u8>(memory, ctx);
292295
default: assert(!"memory is undefined"); return unimplemented;
293296
}
294297
return unimplemented;

src/common/primitive_attr.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "primitive_attr.hpp"
2121
#include "type_helpers.hpp"
2222
#include "utils.hpp"
23+
#include <iostream>
2324

2425
using namespace dnnl::impl;
2526
using namespace dnnl::impl::status;
@@ -119,6 +120,22 @@ status_t zero_points_t::set(int arg, int mask) {
119120
return status::success;
120121
}
121122

123+
// Records explicit zero-point dimensions for `arg`.
//
// Only DNNL_ARG_WEIGHTS is supported; any other argument yields
// status::unimplemented. On success the weights zero points are marked as
// set, `ndims` entries of `dims` are copied into dims_wei, and the weights
// mask is forced to 1.
//
// Returns status::invalid_arguments when `ndims` is negative or exceeds
// DNNL_MAX_NDIMS, or when `dims` is null while `ndims` is positive.
status_t zero_points_t::set(int arg, const dims_t dims, int ndims) {
    if (arg != DNNL_ARG_WEIGHTS) return status::unimplemented;

    // dims_wei is a fixed-size dims_t: never copy more entries than it can
    // hold, and never read through a null source pointer.
    const bool dims_ok = 0 <= ndims && ndims <= DNNL_MAX_NDIMS
            && (ndims == 0 || dims != nullptr);
    if (!dims_ok) return status::invalid_arguments;

    is_set_wei = true;
    ndims_wei = ndims;
    // NOTE(review): the mask is hard-coded to 1 whenever dims are supplied;
    // confirm this is the intended mask for every dims layout.
    mask_wei = 1;
    utils::array_copy(dims_wei, dims, ndims);
    return status::success;
}
138+
122139
} // namespace impl
123140
} // namespace dnnl
124141

@@ -548,6 +565,13 @@ status_t dnnl_primitive_attr_set_scales_mask(
548565
if (!ok) return invalid_arguments;
549566
return attr->scales_.set(arg, mask);
550567
}
568+
// C API entry point: records explicit scales dimensions for `arg` on `attr`.
//
// Returns invalid_arguments when `attr` or `dims` is null, `arg` is negative,
// `ndims` is outside [1, DNNL_MAX_NDIMS], or output scales have already been
// set to non-default values. Otherwise forwards to arg_scales_t::set().
status_t dnnl_primitive_attr_set_scales_dims(
        primitive_attr_t *attr, int arg, const dims_t dims, int ndims) {
    // The upper bound on ndims matters: the callee copies `ndims` entries
    // into a fixed-size dims_t, so a larger value would write out of bounds.
    const bool ok = attr != nullptr && dims != nullptr && arg >= 0
            && 0 < ndims && ndims <= DNNL_MAX_NDIMS
            && attr->output_scales_.has_default_values();
    if (!ok) return invalid_arguments;
    return attr->scales_.set(arg, dims, ndims);
}
551575

552576
status_t dnnl_primitive_attr_set_zero_points_mask(
553577
primitive_attr_t *attr, int arg, int mask) {
@@ -556,6 +580,13 @@ status_t dnnl_primitive_attr_set_zero_points_mask(
556580

557581
return attr->zero_points_.set(arg, mask);
558582
}
583+
// C API entry point: records explicit zero-point dimensions for `arg` on
// `attr`.
//
// Returns invalid_arguments when `attr` or `dims` is null or `ndims` is
// outside [1, DNNL_MAX_NDIMS]. Otherwise forwards to zero_points_t::set(),
// which decides whether `arg` is supported.
status_t dnnl_primitive_attr_set_zero_points_dims(
        primitive_attr_t *attr, int arg, const dims_t dims, int ndims) {
    // The upper bound on ndims matters: the callee copies `ndims` entries
    // into a fixed-size dims_t, so a larger value would write out of bounds.
    const bool ok = attr != nullptr && dims != nullptr && 0 < ndims
            && ndims <= DNNL_MAX_NDIMS;
    if (!ok) return invalid_arguments;

    return attr->zero_points_.set(arg, dims, ndims);
}
559590

560591
status_t dnnl_primitive_attr_set_output_compensations(primitive_attr_t *attr,
561592
int count, int mask) {

src/common/primitive_attr.hpp

+33-2
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,17 @@ struct runtime_scales_t : public c_compatible {
242242
return status::success;
243243
}
244244

245+
// Records explicit scales dimensions: marks the scales as set, forces the
// mask to 1 and copies `ndims` entries of `dims` into dims_.
//
// Returns status::invalid_arguments when `ndims` is negative or exceeds
// DNNL_MAX_NDIMS, or when `dims` is null while `ndims` is positive.
status_t set(const dims_t dims, int ndims) {
    // dims_ is a fixed-size dims_t: never copy more entries than it can
    // hold, and never read through a null source pointer.
    if (ndims < 0 || ndims > DNNL_MAX_NDIMS) return status::invalid_arguments;
    if (ndims > 0 && dims == nullptr) return status::invalid_arguments;

    is_set_ = true;
    ndims_ = ndims;
    // NOTE(review): the mask is hard-coded to 1 whenever dims are supplied;
    // confirm this is the intended mask for every dims layout.
    mask_ = 1;
    utils::array_copy(dims_, dims, ndims_);
    return status::success;
}
252+
245253
// Equality of two runtime scales descriptors. In the post-change form (the
// lines marked `+`) two objects are equal when mask_, is_set_ and ndims_
// match and the first ndims_ entries of dims_ agree; the line marked `-` is
// the pre-change form that compared mask_ and is_set_ only.
bool operator==(const runtime_scales_t &rhs) const {
246-
return mask_ == rhs.mask_ && is_set_ == rhs.is_set_;
254+
return mask_ == rhs.mask_ && is_set_ == rhs.is_set_ &&
255+
ndims_ == rhs.ndims_ && utils::array_cmp(dims_, rhs.dims_, ndims_);
247256
}
248257

249258
bool has_default_values() const { return !is_set_; }
@@ -259,6 +268,9 @@ struct runtime_scales_t : public c_compatible {
259268
// Hide `mask_` under `private:` to force interface usage.
260269
int mask_ = 0;
261270
bool is_set_ = false;
271+
272+
int ndims_ = 0;
273+
dnnl::impl::dims_t dims_;
262274
};
263275

264276
struct arg_scales_t : public c_compatible {
@@ -295,6 +307,10 @@ struct arg_scales_t : public c_compatible {
295307
if (!check_arg(arg)) return status::invalid_arguments;
296308
return scales_[arg].set(mask);
297309
}
310+
// Stores explicit scales dimensions for `arg`. Arguments rejected by
// check_arg() produce status::invalid_arguments; otherwise the result of the
// per-argument runtime_scales_t::set() is returned.
status_t set(int arg, const dims_t dims, int ndims) {
    const bool arg_is_valid = check_arg(arg);
    if (arg_is_valid) return scales_[arg].set(dims, ndims);
    return status::invalid_arguments;
}
298314

299315
status_t get(int arg, int *mask, bool *is_set) const {
300316
if (!check_arg(arg)) return status::invalid_arguments;
@@ -354,7 +370,8 @@ struct zero_points_t : public c_compatible {
354370
// Equality over every field that distinguishes two zero-point
// configurations: the per-argument masks, the per-argument "is set" flags,
// and the explicit weights dimensions.
bool operator==(const zero_points_t &rhs) const {
    const bool same_masks = mask_src == rhs.mask_src
            && mask_wei == rhs.mask_wei && mask_dst == rhs.mask_dst;
    const bool same_flags = is_set_src == rhs.is_set_src
            && is_set_wei == rhs.is_set_wei && is_set_dst == rhs.is_set_dst;
    // Compare the rank unconditionally so that `a == b` and `b == a` always
    // agree. Guarding the comparison on `ndims_wei > 0` alone let an object
    // without dims compare equal to one that has dims set.
    const bool same_wei_dims = ndims_wei == rhs.ndims_wei
            && utils::array_cmp(dims_wei, rhs.dims_wei, ndims_wei);
    return same_masks && same_flags && same_wei_dims;
}
359376

360377
// arg-specific checks
@@ -373,12 +390,26 @@ struct zero_points_t : public c_compatible {
373390
int get(int arg) const; // Returns 0 if dimension is unset
374391

375392
status_t set(int arg, int mask);
393+
status_t set(int arg, const dims_t dims, int ndims);
376394
status_t set(int arg) { return set(arg, 0); }
377395

396+
// Returns the stored weights zero-point dimensions.
// NOTE(review): the argument is ignored, so the weights dims are returned for
// every `arg`, whereas get_ndims() reports 0 for non-weights arguments.
const dims_t &get_dims(int /*arg*/) const { return dims_wei; }
399+
int get_ndims(int arg) const {
400+
switch (arg) {
401+
case DNNL_ARG_WEIGHTS: return ndims_wei; break;
402+
default: return 0;
403+
}
404+
}
405+
378406
private:
379407
bool is_set_src = false, is_set_wei = false, is_set_dst = false;
380408
int mask_src = 0, mask_wei = 0, mask_dst = 0;
381409

410+
int ndims_wei = 0;
411+
dnnl::impl::dims_t dims_wei;
412+
382413
int get_mask(int arg) const {
383414
int mask = 0;
384415
switch (arg) {

src/common/type_helpers.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ inline size_t data_type_size(data_type_t data_type) {
9393
case u8: return sizeof(prec_traits<u8>::type);
9494
case boolean: return sizeof(prec_traits<boolean>::type);
9595
case bin: return sizeof(prec_traits<u8>::type);
96+
case nf4: return sizeof(prec_traits<u8>::type);
97+
case u4: return sizeof(prec_traits<u8>::type);
98+
case s4: return sizeof(prec_traits<u8>::type);
9699
case data_type::undef:
97100
default: assert(!"unknown data_type");
98101
}
@@ -284,7 +287,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
284287

285288
/* prop_kind doesn't matter */
286289
if (everyone_is(f32, src_dt, wei_dt)) return f32;
287-
if (one_of(src_dt, f32, bf16) && wei_dt == u8) return f32;
290+
if (one_of(src_dt, f32, bf16) && one_of(wei_dt, u8, nf4, s4, u4)) return f32;
288291
if (everyone_is(f64, src_dt, wei_dt)) return f64;
289292

290293
if (one_of(prop_kind, forward_training, forward_inference)) {
@@ -949,7 +952,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
949952
if (ndims == 0) return true;
950953

951954
bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
952-
&& utils::one_of(data_type, f16, bf16, f32, f64, s32, s8, u8, bin);
955+
&& utils::one_of(data_type, f16, bf16, f32, f64, s32, s8, u8, bin, nf4, s4, u4);
953956
if (!ok) return false;
954957

955958
bool has_runtime_dims = false;

src/cpu/cpu_inner_product_list.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,21 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
5757
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
5858
nullptr,
5959
}},
60+
{{forward, f32, nf4, f32}, {
61+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
62+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
63+
nullptr,
64+
}},
65+
{{forward, f32, s4, f32}, {
66+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
67+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
68+
nullptr,
69+
}},
70+
{{forward, f32, u4, f32}, {
71+
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core)
72+
CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t, avx2)
73+
nullptr,
74+
}},
6075
{{forward, bf16, bf16, f32}, {
6176
CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t, avx512_core_amx)
6277
CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t, avx512_core_bf16)

src/cpu/cpu_primitive.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
if (scales == nullptr) return status::invalid_arguments; \
6868
const auto scales_d = ctx.memory_mdw(DNNL_ARG_ATTR_SCALES | arg); \
6969
bool ok = scales_d.data_type() == data_type::f32 \
70-
&& scales_d.ndims() == 1; \
70+
&& (scales_d.ndims() == 1 || scales_d.ndims() == 2); \
7171
if (!ok) return status::invalid_arguments; \
7272
if (scales_d.dims()[0] == 1) { \
7373
if (utils::one_of(arg, DNNL_ARG_DST, \

src/cpu/reorder/cpu_reorder.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ regular_impl_list_map() {
3838
{{s8, data_type::undef, 0}, &regular_s8_impl_list_map()},
3939
{{u8, data_type::undef, 0}, &regular_u8_impl_list_map()},
4040
{{bin, data_type::undef, 0}, &regular_bin_impl_list_map()},
41+
{{nf4, data_type::undef, 0}, &regular_nf4_impl_list_map()},
42+
{{s4, data_type::undef, 0}, &regular_s4_impl_list_map()},
43+
{{u4, data_type::undef, 0}, &regular_u4_impl_list_map()},
4144
};
4245
return the_map;
4346
}

src/cpu/reorder/cpu_reorder.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct reorder_impl_key_t {
5656
}
5757

5858
private:
59-
enum { MAX_DT_NUM = 10 };
59+
// Radix used by value() to pack (ndims, src_dt, dst_dt) into one key. It must
// be strictly greater than the largest data type value that can appear in a
// key; the largest is now dnnl_u4 == 12, so a radix of 12 would let distinct
// (src_dt, dst_dt) pairs collapse onto the same key.
enum { MAX_DT_NUM = 13 };
6060
size_t value() const {
6161
return ((size_t)ndims * MAX_DT_NUM + (size_t)src_dt) * MAX_DT_NUM
6262
+ (size_t)dst_dt;
@@ -80,6 +80,9 @@ extern const impl_list_map_t &regular_s32_impl_list_map();
8080
extern const impl_list_map_t &regular_s8_impl_list_map();
8181
extern const impl_list_map_t &regular_u8_impl_list_map();
8282
extern const impl_list_map_t &regular_bin_impl_list_map();
83+
extern const impl_list_map_t &regular_nf4_impl_list_map();
84+
extern const impl_list_map_t &regular_s4_impl_list_map();
85+
extern const impl_list_map_t &regular_u4_impl_list_map();
8386

8487
/* conv reorders w/ compensation */
8588
extern const impl_list_map_t &comp_f32_s8_impl_list_map();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*******************************************************************************
2+
* Copyright 2021 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "cpu/reorder/cpu_reorder.hpp"
18+
19+
namespace dnnl {
20+
namespace impl {
21+
namespace cpu {
22+
23+
// clang-format off
24+
25+
// Returns the table of regular reorder implementations registered for nf4
// data. The table is a function-local static built through REG_REORDER_P and
// holds a single key, {nf4, data_type::undef, 0}
// (presumably {src type, dst type, ndims} -- confirm against
// reorder_impl_key_t).
// Each REG_SR entry registers an nf4 -> nf4 reorder into one of the
// OI8i{8,16,24,32,64}o2i layouts with fmt_order_keep
// (NOTE(review): the REG_SR argument order is assumed to be
// src type, src format, dst type, dst format, order -- confirm against the
// macro definition). The implementation list is nullptr-terminated.
const impl_list_map_t &regular_nf4_impl_list_map() {
26+
static const impl_list_map_t the_map = REG_REORDER_P({
27+
// nf4 ->
28+
{{nf4, data_type::undef, 0}, {
29+
REG_SR(nf4, any, nf4, OI8i8o2i, fmt_order_keep)
30+
REG_SR(nf4, any, nf4, OI8i16o2i, fmt_order_keep)
31+
REG_SR(nf4, any, nf4, OI8i24o2i, fmt_order_keep)
32+
REG_SR(nf4, any, nf4, OI8i32o2i, fmt_order_keep)
33+
REG_SR(nf4, any, nf4, OI8i64o2i, fmt_order_keep)
34+
nullptr,
35+
}},
36+
});
37+
return the_map;
38+
}
39+
40+
// clang-format on
41+
42+
} // namespace cpu
43+
} // namespace impl
44+
} // namespace dnnl

0 commit comments

Comments
 (0)