|
| 1 | +/******************************************************************************* |
| 2 | + * Copyright 2023-2024 Intel Corporation |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + *******************************************************************************/ |
| 16 | +/// @example int4_weights_decompression.cpp |
| 17 | +/// > Annotated version: @ref int4_weights_decompression.cpp |
| 18 | +/// |
| 19 | +/// @page int4_weights_decompression |
| 20 | +/// C++ API example demonstrating how one can use |
| 21 | +/// [MatMul](@ref dev_guide_matmul) with int4 compressed weights. |
| 22 | +/// |
| 23 | +/// Concepts: |
| 24 | +/// - AWQ (activation-aware quantization) |
| 25 | +/// - Scales: dnnl::primitive_attr::set_scales() |
| 26 | +/// - Zero points: dnnl::primitive_attr::set_zero_points() |
| 27 | +/// - [Operation fusion](@ref dev_guide_attributes_post_ops) |
| 28 | +/// - Create primitive once, use multiple times |
| 29 | +/// - Weights pre-packing: use #dnnl::memory::format_tag::any |
| 30 | +/// |
| 31 | +/// @page int4_weights_decompression_matmul_cpp MatMul Tutorial: weights |
| 32 | +/// decompression |
| 33 | +/// @copydetails int4_weights_decompression_matmul_cpp |
| 34 | +/// |
| 35 | +/// Assumptions: |
| 36 | +/// 1. The shape of the weights (matrix \f$B(K, N)\f$) is known in advance, the |
| 37 | +/// data type is `int4` and shifted from 0 (i.e. the zero point is not 0). |
| 38 | +/// 2. The source matrix \f$A\f$ and destination matrix \f$C\f$ have floating |
| 39 | +/// point data type. |
| 40 | +/// 3. Scaling (re-quantization) factor specified at run-time only. |
| 41 | +/// |
| 42 | +/// Since the shape of weights is known in advance, the MatMul weights can be |
| 43 | +/// created with format tag #dnnl::memory::format_tag::any to enable the library |
| 44 | +/// to choose the most appropriate layout for best performance. |
| 45 | +/// |
| 46 | +/// @warning |
| 47 | +/// The format tag #dnnl::memory::format_tag::any doesn't work for memory |
| 48 | +/// descriptors that have one or more unknown dimensions and/or strides. |
| 49 | +/// |
/// @include int4_weights_decompression.cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"
| 63 | + |
| 64 | +using namespace dnnl; |
| 65 | + |
namespace {

/// Fills @p v with pseudo-random values in [0, 1).
/// The mt19937 engine is default-seeded on purpose so every run of the
/// example produces the same data (reproducible accuracy check).
void init_vector(std::vector<float> &v) {
    std::mt19937 gen;
    std::uniform_real_distribution<float> u(0, 1);
    for (auto &e : v)
        e = u(gen);
}

/// Compares two vectors by computing their L2 norms and the L2 norm of
/// their difference, then checking the difference against a threshold
/// derived from machine epsilon and log(K).
///
/// @param v1 Reference values.
/// @param v2 Values under test; must have the same size as @p v1.
/// @param K Reduction dimension of the matmul that produced the data;
///        used to scale the acceptance threshold.
/// @param message Label printed with the report.
/// @returns 0 if the vectors are considered similar, 1 otherwise.
int compare_vectors(const std::vector<float> &v1, const std::vector<float> &v2,
        int64_t K, const char *message) {
    // Mismatched sizes can never compare equal. This guard also prevents
    // an out-of-bounds read of v2 in the loop below.
    if (v1.size() != v2.size()) {
        printf("%s\n\tAccuracy check: FAILED (size mismatch)\n", message);
        return 1;
    }

    double v1_l2 = 0, diff_l2 = 0;
    for (size_t n = 0; n < v1.size(); ++n) {
        // Accumulate in double to limit cancellation error in the norms.
        const double diff = static_cast<double>(v1[n]) - v2[n];
        v1_l2 += static_cast<double>(v1[n]) * v1[n];
        diff_l2 += diff * diff;
    }

    v1_l2 = std::sqrt(v1_l2);
    diff_l2 = std::sqrt(diff_l2);

    // Finding a reasonable (tight and accurate) threshold is a difficult
    // problem in general. Implementation testing might use special data
    // filling to alleviate finite-precision issues. However, in simple
    // cases the machine epsilon multiplied by log(K) works reasonably well.
    const double threshold = std::numeric_limits<float>::epsilon()
            * std::log(std::max(2., (double)K));
    const bool ok = diff_l2 <= threshold * v1_l2;

    printf("%s\n\tL2 Norms"
           "\n\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative_error:%g\n"
           "\tAccuracy check: %s\n",
            message, v1_l2, diff_l2, diff_l2 / v1_l2, ok ? "OK" : "FAILED");

    return ok ? 0 : 1;
}

} // namespace
| 109 | + |
// Reference floating point MatMul with groupwise-dequantized weights:
//   C(M,N) += A(M,K) * dequant(B)(K,N)
// where dequant(B)[k][n] = (B[k][n] - zp_B[k][n]) * sc_B[k / G][n].
//
// Inputs:
//  - Shapes M, N, K and the quantization group size G along K
//  - A_f32: M x K activations (row-major)
//  - B_f32: K x N weights (row-major), with per-element zero points
//    zp_B_f32 (K x N) and groupwise scales sc_B ((K/G) x N)
// Output:
//  - C_f32: M x N result. Note: values are ACCUMULATED into C_f32, so the
//    caller provides the initial contents (typically zeros).
void ref_compute_matmul_f32(int64_t M, int64_t N, int64_t K, int64_t G,
        const std::vector<float> &A_f32, const std::vector<float> &B_f32,
        const std::vector<float> &zp_B_f32, const std::vector<float> &sc_B,
        std::vector<float> &C_f32) {
    // Loop order (k, n, m) dequantizes each weight exactly once instead of
    // once per output row. For each C element the contributions are still
    // added in ascending k order, so results match the (m, n, k) ordering
    // bit-for-bit.
    for (int64_t k = 0; k < K; ++k) {
        for (int64_t n = 0; n < N; ++n) {
            const int64_t w_idx = k * N + n;
            const int64_t sc_idx = (k / G) * N + n;
            const float decompressed_B
                    = (B_f32[w_idx] - zp_B_f32[w_idx]) * sc_B[sc_idx];
            for (int64_t m = 0; m < M; ++m)
                C_f32[m * N + n] += A_f32[m * K + k] * decompressed_B;
        }
    }
}
| 135 | + |
// Create a MatMul primitive descriptor for the following op:
//   C_f32 = A_f32 * (B_s4 - zp_B) * sc_B[:]
// The weights descriptor uses format_tag::any so the library picks its
// preferred (pre-packed) int4 layout; scales and zero points are supplied
// at execution time as runtime arguments.
matmul::primitive_desc matmul_pd_create(
        int64_t M, int64_t N, int64_t K, int64_t G, const engine &eng) {

    memory::desc a_md({M, K}, memory::data_type::f32, {K, 1}); // M x K layout
    memory::desc b_md({K, N}, memory::data_type::s4,
            memory::format_tag::any); // K x N layout
    memory::desc c_md({M, N}, memory::data_type::f32, {N, 1}); // M x N layout

    // Create attributes; scales and zero points are runtime parameters
    // passed via DNNL_ARG_ATTR_* at execute() time.
    primitive_attr attr;
    // Set scales with multiple scales along the K and N dimensions
    // (mask bits 0 and 1) and with groups of size G along K: one f32 scale
    // per {G, 1} block of the weights.
    attr.set_scales(DNNL_ARG_WEIGHTS,
            /* mask */ (1 << 0) + (1 << 1), {G, 1}, memory::data_type::f32);

    // Set zero points with s4 data type.
    // The mask determines which dimensions the zero points are applied to.
    // The current mask value (1 << 0) + (1 << 1) means zero points vary
    // along both the K and N dimensions; with empty groups this is one
    // zero point per weight element.
    // Changing the mask value would alter the dimensions along which the
    // zero points are applied. For example:
    // - mask = (1 << 0) would apply zero points only along the K dimension.
    // - mask = (1 << 1) would apply zero points only along the N dimension.
    int mask = (1 << 0) + (1 << 1); // zero points both along K and N dimensions
    memory::dims groups = {};
    attr.set_zero_points(DNNL_ARG_WEIGHTS, mask, groups, memory::data_type::s4);

    // Set fpmath mode with `apply_to_int=true` to apply fpmath-mode behavior
    // to integral primitives (here, matmul): the int4 weights may be
    // computed with f16 math.
    attr.set_fpmath_mode(fpmath_mode::f16, true);

    // Create a MatMul primitive descriptor
    return matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
}
| 173 | + |
| 174 | +// Function to perform matrix multiplication with int4 weights decompression |
| 175 | +// using oneDNN |
| 176 | +void weights_decompression_matmul(int64_t M, int64_t N, int64_t K, int64_t G, |
| 177 | + std::vector<float> &A_f32, std::vector<float> &B_f32, |
| 178 | + std::vector<float> &zp_B_f32, std::vector<float> &sc_B, |
| 179 | + std::vector<float> &C_f32, const engine &eng) { |
| 180 | + auto matmul_pd = matmul_pd_create(M, N, K, G, eng); |
| 181 | + stream s(eng); |
| 182 | + |
| 183 | + // Pre-packed weights stored as int4 |
| 184 | + memory B_s4_mem(matmul_pd.weights_desc(), eng); |
| 185 | + { |
| 186 | + memory B_f32_mem( |
| 187 | + {{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng); |
| 188 | + write_to_dnnl_memory(B_f32.data(), B_f32_mem); |
| 189 | + reorder(B_f32_mem, B_s4_mem).execute(s, B_f32_mem, B_s4_mem); |
| 190 | + s.wait(); |
| 191 | + } |
| 192 | + matmul matmul_p(matmul_pd); |
| 193 | + |
| 194 | + // input of the current layer / operation |
| 195 | + memory A_f32_mem({{M, K}, memory::data_type::f32, {K, 1}}, eng); |
| 196 | + // De-quantization parameters (eg. Scale and Shift) |
| 197 | + const int64_t n_groups = K / G; |
| 198 | + memory sc_B_mem({{N, n_groups}, memory::data_type::f32, {1, N}}, eng); |
| 199 | + |
| 200 | + // Pre-packed zp stored as int4 |
| 201 | + // A unique zero point is used for each weight in this example |
| 202 | + // Allocates memory for zp_B_s4_mem with specified dimensions and data type. |
| 203 | + memory zp_B_s4_mem({{K, N}, memory::data_type::s4, {1, K}}, eng); |
| 204 | + { |
| 205 | + memory zp_B_f32_mem({{K, N}, memory::data_type::f32, {1, K}}, eng); |
| 206 | + write_to_dnnl_memory(zp_B_f32.data(), zp_B_f32_mem); |
| 207 | + reorder(zp_B_f32_mem, zp_B_s4_mem) |
| 208 | + .execute(s, zp_B_f32_mem, zp_B_s4_mem); |
| 209 | + s.wait(); |
| 210 | + } |
| 211 | + |
| 212 | + write_to_dnnl_memory(A_f32.data(), A_f32_mem); |
| 213 | + write_to_dnnl_memory(sc_B.data(), sc_B_mem); |
| 214 | + |
| 215 | + // output - no initialization required |
| 216 | + memory C_f32_mem({{M, N}, memory::data_type::f32, {N, 1}}, eng); |
| 217 | + |
| 218 | + matmul_p.execute(s, |
| 219 | + {{DNNL_ARG_SRC, A_f32_mem}, {DNNL_ARG_WEIGHTS, B_s4_mem}, |
| 220 | + {DNNL_ARG_DST, C_f32_mem}, |
| 221 | + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem}, |
| 222 | + {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, |
| 223 | + zp_B_s4_mem}}); |
| 224 | + s.wait(); |
| 225 | +} |
| 226 | + |
| 227 | +// Compares the results of reference matrix multiplication and oneDNN weights |
| 228 | +// decompression. |
| 229 | +void compare_ref_and_weights_decompression(engine::kind engine_kind) { |
| 230 | + engine eng(engine_kind, 0); |
| 231 | + |
| 232 | + // MatMul parameters |
| 233 | + const int64_t M = 1, N = 4096, K = 1024; |
| 234 | + // Quantization Group size for scales |
| 235 | + const int64_t G = 64; |
| 236 | + |
| 237 | + // Prepare matrices |
| 238 | + std::vector<float> A_f32(M * K), C_ref(M * N), sc_B(K * N / G); |
| 239 | + std::vector<float> B_f32(K * N); |
| 240 | + std::vector<float> zp_B_f32(K * N); |
| 241 | + init_vector(A_f32); |
| 242 | + init_vector(B_f32); |
| 243 | + init_vector(sc_B); |
| 244 | + init_vector(zp_B_f32); |
| 245 | + init_vector(C_ref); |
| 246 | + std::vector<float> C_onednn = C_ref; |
| 247 | + |
| 248 | + // Compute _true_ C_ref result |
| 249 | + ref_compute_matmul_f32(M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_ref); |
| 250 | + |
| 251 | + // Compute _true_ C_onednn result |
| 252 | + weights_decompression_matmul( |
| 253 | + M, N, K, G, A_f32, B_f32, zp_B_f32, sc_B, C_onednn, eng); |
| 254 | + |
| 255 | + int rc = 0; |
| 256 | + rc |= compare_vectors( |
| 257 | + C_ref, C_onednn, K, "Compare ref vs oneDNN weights decompression"); |
| 258 | + if (rc) throw std::logic_error("The resulting matrices diverged too much."); |
| 259 | +} |
| 260 | + |
| 261 | +int main(int argc, char **argv) { |
| 262 | + engine::kind engine_kind = parse_engine_kind(argc, argv); |
| 263 | + return handle_example_errors( |
| 264 | + compare_ref_and_weights_decompression, engine_kind); |
| 265 | +} |
0 commit comments