TorchFX: GPTQ accuracy fix #26294

Merged (19 commits) on Oct 18, 2024
11 changes: 11 additions & 0 deletions .github/workflows/job_pytorch_models_tests.yml
@@ -160,6 +160,17 @@ jobs:
TEST_DEVICE: CPU
USE_SYSTEM_CACHE: False

- name: TorchFX GPTQ Pattern Test
if: ${{ inputs.model_scope == 'precommit' }}
# install torch 2.3.1 as newer versions are not yet supported by the openvino backend
run: |
export PYTHONPATH=${MODEL_HUB_TESTS_INSTALL_DIR}:$PYTHONPATH
python3 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --upgrade --index-url https://download.pytorch.org/whl/cpu
python3 -m pytest ${MODEL_HUB_TESTS_INSTALL_DIR}/transformation_tests/test_gptq_torchfx_transformations.py -m precommit --html=${INSTALL_TEST_DIR}/TEST-torch_gptqpattern_tests.html --self-contained-html -v --tb=short
env:
TEST_DEVICE: CPU
USE_SYSTEM_CACHE: False

- name: Reformat unsupported ops file
if: ${{ inputs.model_scope != 'precommit' && !cancelled()}}
run: |
@@ -40,18 +40,6 @@ uint32_t read_u4_data(const void* array, size_t index) {
return val;
};

void write_u4_data(void* array, size_t index, uint32_t data) {
auto arr_u32 = reinterpret_cast<uint32_t*>(array);
size_t idx_u32 = index / 8;
size_t offset_u32 = index % 8;
uint32_t old_val = arr_u32[idx_u32];
data = data << (offset_u32 * 4);
uint32_t mask = 15;
mask = ~(mask << (offset_u32 * 4));
uint32_t new_val = (old_val & mask) | data;
arr_u32[idx_u32] = new_val;
};

GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
const auto& const_1 = wrap_type<v0::Constant>();
const auto& const_2 = wrap_type<v0::Constant>();
@@ -73,61 +61,157 @@ GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
const auto& convert_2 = wrap_type<v0::Convert>({const_6});
const auto& bitwise_and = wrap_type<ov::op::v13::BitwiseAnd>({add_or_convert, convert_2});

ov::matcher_pass_callback callback = [unsqueeze_1](Matcher& m) {
ov::matcher_pass_callback callback = [=](Matcher& m) {
auto bitwise_and = m.get_match_root();
if (!bitwise_and) {
return false;
}
const auto& pattern_map = m.get_pattern_value_map();
const auto& input_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
auto weights_u32 = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(0));
auto axis = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(1));
auto axis_data = axis->get_data_ptr<uint32_t>();

auto u8_shape = weights_u32->get_shape();
auto src = weights_u32->get_data_ptr<uint32_t>();

ov::Shape u4_shape;
bool dim_added = false;
size_t stride = 1;
size_t size_y = 1;
for (size_t i = 0; i < u8_shape.size(); i++) {
if (axis_data[0] == i) {
u4_shape.push_back(8);
dim_added = true;
}
if (axis_data[0] <= i) {
stride *= u8_shape[i];
} else {
size_y *= u8_shape[i];
}
u4_shape.push_back(u8_shape[i]);
auto unsqueeze_1_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
auto unsqueeze_1_in0_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(0));
auto unsqueeze_1_in1_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(1));
auto abs_node = pattern_map.at(abs).get_node_shared_ptr();
auto abs_in_const = std::dynamic_pointer_cast<v0::Constant>(abs_node->get_input_node_shared_ptr(0));
auto broadcast_node = pattern_map.at(broadcast).get_node_shared_ptr();
auto unsqueeze_2_node = pattern_map.at(unsqueeze_2).get_node_shared_ptr();
auto unsqueeze_2_in0_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(0));
auto unsqueeze_2_in1_const =
std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(1));

OutputVector outputs_1(unsqueeze_1_node->get_output_size());
OutputVector unsqueeze_1_inputs(2);
unsqueeze_1_inputs[0] = unsqueeze_1_in0_const->outputs()[0];
unsqueeze_1_inputs[1] = unsqueeze_1_in1_const->outputs()[0];
if (!unsqueeze_1_node->constant_fold(outputs_1, unsqueeze_1_inputs)) {
return false;
}
if (!dim_added) {
u4_shape.push_back(8);

OutputVector outputs_2(abs_node->get_output_size());
if (!abs_node->constant_fold(outputs_2, abs_in_const->outputs())) {
return false;
}

auto new_const = std::make_shared<v0::Constant>(element::u4, u4_shape);
auto dst = const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(new_const->get_data_ptr()));
OutputVector outputs_3(broadcast_node->get_output_size());
OutputVector broadcast_inputs(2);
broadcast_inputs[0] = outputs_1[0];
broadcast_inputs[1] = outputs_2[0];
if (!broadcast_node->constant_fold(outputs_3, broadcast_inputs)) {
return false;
}

OutputVector outputs_4(unsqueeze_2_node->get_output_size());
OutputVector unsqueeze_2_inputs(2);
unsqueeze_2_inputs[0] = unsqueeze_2_in0_const->outputs()[0];
unsqueeze_2_inputs[1] = unsqueeze_2_in1_const->outputs()[0];
if (!unsqueeze_2_node->constant_fold(outputs_4, unsqueeze_2_inputs)) {
return false;
}
const int32_t* rs_in0 =
std::dynamic_pointer_cast<v0::Constant>(outputs_3[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
const int32_t* rs_in1 =
std::dynamic_pointer_cast<v0::Constant>(outputs_4[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
auto shifted_const = std::make_shared<v0::Constant>(element::i32, outputs_3[0].get_shape());
auto dst = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(shifted_const->get_data_ptr()));
if (!dst)
return false;

size_t in_idx = 0;
for (size_t y = 0; y < size_y; y++) {
size_t offset = y * stride * 8;
for (size_t x = 0; x < stride; x++) {
for (size_t z = 0; z < 8; z++) {
uint32_t val = read_u4_data(src, in_idx);
write_u4_data(dst, (offset + x + stride * z), val);
in_idx++;
}
// TODO: Bitwise right shift operation below might need to be
// optimized to reduce FIL.
size_t rs_in0_shape_size = shape_size(outputs_3[0].get_shape());
const auto& rs_in0_shape = outputs_3[0].get_shape();
const auto& rs_in1_shape = outputs_4[0].get_shape();
int shift_dim = -1;
size_t shift_offset = 1;
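// Locate the axis along which the shift amounts vary (shift_dim) and the
// flat stride of that axis in the weight buffer (shift_offset), so the
// loop below can pick the right shift value for each flattened index.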
for (size_t i = 0; i < rs_in1_shape.size(); ++i) {
size_t dim = rs_in1_shape[i];
if (dim != 1 && dim != rs_in0_shape[i]) {
return false;
}
if (shift_dim != -1) {
shift_offset *= rs_in0_shape[i];
}
if (dim == rs_in0_shape[i]) {
shift_dim = static_cast<int>(i);
}
}
if (shift_dim == -1)
return false;
for (size_t k = 0; k < rs_in0_shape_size; ++k) {
size_t shift_idx = (k / shift_offset) % rs_in1_shape[shift_dim];
int32_t shift_val = rs_in1[shift_idx];
dst[k] = (rs_in0[k] >> shift_val);
}

std::shared_ptr<ov::Node> convert_1_node = nullptr;
OutputVector outputs_7;
if (pattern_map.find(convert_1) != pattern_map.end()) {
convert_1_node = pattern_map.at(convert_1).get_node_shared_ptr();
outputs_7.resize(convert_1_node->get_output_size());
if (!convert_1_node->constant_fold(outputs_7, shifted_const->outputs())) {
return false;
}
} else {
auto convert_3_node = pattern_map.at(convert_3).get_node_shared_ptr();
auto convert_4_node = pattern_map.at(convert_4).get_node_shared_ptr();
auto convert_4_in_const =
std::dynamic_pointer_cast<v0::Constant>(convert_4_node->get_input_node_shared_ptr(0));
auto add_node = pattern_map.at(add).get_node_shared_ptr();
OutputVector outputs_5(convert_3_node->get_output_size());
if (!convert_3_node->constant_fold(outputs_5, shifted_const->outputs())) {
return false;
}
OutputVector outputs_6(convert_4_node->get_output_size());
if (!convert_4_node->constant_fold(outputs_6, convert_4_in_const->outputs())) {
return false;
}
outputs_7.resize(add_node->get_output_size());
OutputVector add_inputs(2);
add_inputs[0] = outputs_5[0];
add_inputs[1] = outputs_6[0];
if (!add_node->constant_fold(outputs_7, add_inputs)) {
return false;
}
}

copy_runtime_info_and_name(weights_u32, {new_const}, {weights_u32, bitwise_and});
auto convert_2_node = pattern_map.at(convert_2).get_node_shared_ptr();
auto convert_2_in_const = std::dynamic_pointer_cast<v0::Constant>(convert_2_node->get_input_node_shared_ptr(0));

OutputVector outputs_8(convert_2_node->get_output_size());
if (!convert_2_node->constant_fold(outputs_8, convert_2_in_const->outputs())) {
return false;
}

OutputVector outputs_9(bitwise_and->get_output_size());

const int8_t* and_in0 =
std::dynamic_pointer_cast<v0::Constant>(outputs_7[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
const int8_t* and_in1 =
std::dynamic_pointer_cast<v0::Constant>(outputs_8[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
auto masked_const = std::make_shared<v0::Constant>(element::i8, outputs_7[0].get_shape());
auto masked_dst = const_cast<int8_t*>(reinterpret_cast<const int8_t*>(masked_const->get_data_ptr()));
if (!masked_dst)
return false;

size_t and_in0_shape_size = shape_size(outputs_7[0].get_shape());
// TODO: Bitwise and operation below might need to be
// optimized to reduce FIL.
int8_t mask = and_in1[0];
for (size_t k = 0; k < and_in0_shape_size; ++k) {
masked_dst[k] = (and_in0[k] & mask);
}

auto convert_to_u4 = std::make_shared<v0::Convert>(masked_const, element::u4);
OutputVector outputs_10(convert_to_u4->get_output_size());
if (!convert_to_u4->constant_fold(outputs_10, masked_const->outputs())) {
return false;
}

auto new_convert = std::make_shared<v0::Convert>(new_const, bitwise_and->get_output_element_type(0));
copy_runtime_info_and_name(bitwise_and, {new_convert}, {input_node});
auto new_convert =
std::make_shared<v0::Convert>(outputs_10[0].get_node_shared_ptr(), bitwise_and->get_output_element_type(0));
copy_runtime_info_and_name(bitwise_and, {new_convert}, {unsqueeze_1_node});
replace_node(bitwise_and, new_convert);
return true;
};
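
For intuition, the constant folding above reproduces GPTQ's shift-and-mask unpacking of 4-bit weights packed into int32 words. A minimal NumPy sketch of that unpacking follows; it is illustrative only and not part of this PR, and the low-nibble-first order, the 0xF mask, and the use of NumPy are assumptions made for the example.

import numpy as np

def unpack_u4(qweight):
    # qweight: int32 array where every element packs eight 4-bit values.
    shifts = np.arange(8, dtype=np.int32) * 4   # per-nibble right-shift amounts
    expanded = qweight[..., None] >> shifts     # broadcasted right shift
    return (expanded & 0xF).astype(np.uint8)    # mask keeps the low 4 bits

print(unpack_u4(np.array([0x76543210], dtype=np.int32)))  # [[0 1 2 3 4 5 6 7]]
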
@@ -0,0 +1 @@
atorsvn/TinyLlama-1.1B-Chat-v0.3-gptq-4bit,https://huggingface.co/atorsvn/TinyLlama-1.1B-Chat-v0.3-gptq-4bit
@@ -0,0 +1,102 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import hashlib
from openvino.frontend.pytorch.torchdynamo.execute import compiled_cache
import models_hub_common.utils as utils
import pytest
import os

def patch_gptq(config):
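    # Detect a GPTQ checkpoint from its config and patch torch/optimum so the
    # quantized model can be loaded and executed on CPU (no CUDA device needed).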
do_gptq_patching = False
config_dict = config.to_dict()
quantization_config = config_dict.get("quantization_config", None)
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
orig_cuda_check = torch.cuda.is_available
orig_post_init_model = None
if do_gptq_patching:
torch.set_default_dtype(torch.float32)
torch.cuda.is_available = lambda: False

from optimum.gptq import GPTQQuantizer

orig_post_init_model = GPTQQuantizer.post_init_model

def post_init_model(self, model):
from auto_gptq import exllama_set_max_input_length

class StoreAttr(object):
pass

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
model = exllama_set_max_input_length(model, self.max_input_length)
return model

GPTQQuantizer.post_init_model = post_init_model
return orig_cuda_check, orig_post_init_model

def run_gptq_torchfx(tmp_path, model_id, model_link, prompt_result_pair):
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32)
cuda, post_init = patch_gptq(config)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32)
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
config=config,
device_map='cpu',
torch_dtype=torch.float32
)

pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=4,
do_sample=True,
temperature=0.01,
top_p=0.01,
top_k=1,
repetition_penalty=1.1,
num_beams=1,
)

prompt = prompt_result_pair["prompt"]
expected_md5 = prompt_result_pair["result_md5"]

model.model.forward = torch.compile(model.model.forward, backend="openvino", dynamic=True, fullgraph=True, options={'aot_autograd': True})

result_ov = pipe(prompt)
md5_ov = hashlib.new("md5", result_ov[0]['generated_text'].encode(), usedforsecurity=False).hexdigest()

u4_ops = ["FullyConnected",]
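# Check that every FullyConnected op compiled by the OpenVINO backend
# (tracked in compiled_cache) executed with u4 runtime precision.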
num_u4_ops = 0
num_u4_ops_supported = 0
for pid in compiled_cache:
for op in compiled_cache[pid].get_runtime_model().get_ordered_ops():
if (str(op.get_rt_info()["layerType"].get()) in u4_ops):
u4_exec = (str(op.get_rt_info()["runtimePrecision"].get()) == "u4")
if u4_exec:
num_u4_ops_supported += 1
num_u4_ops += 1

assert(expected_md5 == md5_ov), "Output does not match with the expected output"
assert((num_u4_ops > 0) and (num_u4_ops == num_u4_ops_supported)), "Runtime precision is not u4"

@pytest.mark.precommit
@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "gptq-torchfx-models-precommit")))
@pytest.mark.parametrize('prompt_result_pair', ([
{"prompt" : "Tell me about AI", "result_md5" : "4385ccbce14627ae91f846b4c8a3f145"},
]))
def test_gptq_torchfx_precommit(tmp_path, model_name, model_link, mark, reason, prompt_result_pair, ie_device):
assert mark is None or mark == 'skip' or mark == 'xfail', \
"Incorrect test case: {}, {}".format(model_name, model_link)
if mark == 'skip':
pytest.skip(reason)
elif mark == 'xfail':
pytest.xfail(reason)
run_gptq_torchfx(tmp_path, model_name, model_link, prompt_result_pair)
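
For reference, patch_gptq keys off the quantization_config section of the model config. A GPTQ checkpoint typically carries fields along these lines; this is a hypothetical example for illustration, with values not taken from the test model.

# Illustrative only: the shape of the quantization_config dict that patch_gptq inspects.
example_quantization_config = {
    "quant_method": "gptq",  # the value the patching logic keys on
    "bits": 4,               # 4-bit weights, matching the u4 runtime check above
    "group_size": 128,
    "desc_act": False,       # read by the patched post_init_model
}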
