
Commit 43df0b6

cavusmustafa and mvafin authored

TorchFX: GPTQ accuracy fix (#26294)

### Details:
- Fix for the accuracy issues discovered in Llama2 GPTQ with aot_autograd

### Tickets:
- [CVS-149032](https://jira.devtools.intel.com/browse/CVS-149032)

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

1 parent 62183ab commit 43df0b6
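For context, the regression surfaces when a GPTQ-quantized model is compiled with the OpenVINO backend through aot_autograd, which lowers the model via TorchFX and exercises the `GPTQDecompressionReplacer` pass fixed below. A minimal sketch of that path, assuming the model id and compile options from the precommit test added in this commit (a full run may also need the CUDA-availability patching the test applies):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model id taken from the precommit model list added in this commit.
model_id = "atorsvn/TinyLlama-1.1B-Chat-v0.3-gptq-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", torch_dtype=torch.float32)

# aot_autograd routes the graph through TorchFX, where the GPTQ
# decompression pattern is matched and constant-folded by the pass below.
model.model.forward = torch.compile(model.model.forward,
                                    backend="openvino",
                                    dynamic=True,
                                    fullgraph=True,
                                    options={"aot_autograd": True})

inputs = tokenizer("Tell me about AI", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=4)[0]))
```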

File tree

4 files changed: +250 -52

.github/workflows/job_pytorch_models_tests.yml (+11)

@@ -160,6 +160,17 @@ jobs:
           TEST_DEVICE: CPU
           USE_SYSTEM_CACHE: False
 
+      - name: TorchFX GPTQ Pattern Test
+        if: ${{ inputs.model_scope == 'precommit' }}
+        # install torch 2.3.1 as newer is not yet supported by openvino backend
+        run: |
+          export PYTHONPATH=${MODEL_HUB_TESTS_INSTALL_DIR}:$PYTHONPATH
+          python3 -m pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --upgrade --index-url https://download.pytorch.org/whl/cpu
+          python3 -m pytest ${MODEL_HUB_TESTS_INSTALL_DIR}/transformation_tests/test_gptq_torchfx_transformations.py -m precommit --html=${INSTALL_TEST_DIR}/TEST-torch_gptqpattern_tests.html --self-contained-html -v --tb=short
+        env:
+          TEST_DEVICE: CPU
+          USE_SYSTEM_CACHE: False
+
       - name: Reformat unsupported ops file
         if: ${{ inputs.model_scope != 'precommit' && !cancelled()}}
         run: |

src/frontends/pytorch/src/transforms/torchfx_gptq_pattern_replacer.cpp (+136 -52)

@@ -40,18 +40,6 @@ uint32_t read_u4_data(const void* array, size_t index) {
     return val;
 };
 
-void write_u4_data(void* array, size_t index, uint32_t data) {
-    auto arr_u32 = reinterpret_cast<uint32_t*>(array);
-    size_t idx_u32 = index / 8;
-    size_t offset_u32 = index % 8;
-    uint32_t old_val = arr_u32[idx_u32];
-    data = data << (offset_u32 * 4);
-    uint32_t mask = 15;
-    mask = ~(mask << (offset_u32 * 4));
-    uint32_t new_val = (old_val & mask) | data;
-    arr_u32[idx_u32] = new_val;
-};
-
 GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
     const auto& const_1 = wrap_type<v0::Constant>();
     const auto& const_2 = wrap_type<v0::Constant>();
@@ -73,61 +61,157 @@ GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
     const auto& convert_2 = wrap_type<v0::Convert>({const_6});
     const auto& bitwise_and = wrap_type<ov::op::v13::BitwiseAnd>({add_or_convert, convert_2});
 
-    ov::matcher_pass_callback callback = [unsqueeze_1](Matcher& m) {
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
         auto bitwise_and = m.get_match_root();
         if (!bitwise_and) {
             return false;
         }
         const auto& pattern_map = m.get_pattern_value_map();
-        const auto& input_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
-        auto weights_u32 = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(0));
-        auto axis = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(1));
-        auto axis_data = axis->get_data_ptr<uint32_t>();
-
-        auto u8_shape = weights_u32->get_shape();
-        auto src = weights_u32->get_data_ptr<uint32_t>();
-
-        ov::Shape u4_shape;
-        bool dim_added = false;
-        size_t stride = 1;
-        size_t size_y = 1;
-        for (size_t i = 0; i < u8_shape.size(); i++) {
-            if (axis_data[0] == i) {
-                u4_shape.push_back(8);
-                dim_added = true;
-            }
-            if (axis_data[0] <= i) {
-                stride *= u8_shape[i];
-            } else {
-                size_y *= u8_shape[i];
-            }
-            u4_shape.push_back(u8_shape[i]);
+        auto unsqueeze_1_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
+        auto unsqueeze_1_in0_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(0));
+        auto unsqueeze_1_in1_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(1));
+        auto abs_node = pattern_map.at(abs).get_node_shared_ptr();
+        auto abs_in_const = std::dynamic_pointer_cast<v0::Constant>(abs_node->get_input_node_shared_ptr(0));
+        auto broadcast_node = pattern_map.at(broadcast).get_node_shared_ptr();
+        auto unsqueeze_2_node = pattern_map.at(unsqueeze_2).get_node_shared_ptr();
+        auto unsqueeze_2_in0_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(0));
+        auto unsqueeze_2_in1_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(1));
+
+        OutputVector outputs_1(unsqueeze_1_node->get_output_size());
+        OutputVector unsqueeze_1_inputs(2);
+        unsqueeze_1_inputs[0] = unsqueeze_1_in0_const->outputs()[0];
+        unsqueeze_1_inputs[1] = unsqueeze_1_in1_const->outputs()[0];
+        if (!unsqueeze_1_node->constant_fold(outputs_1, unsqueeze_1_inputs)) {
+            return false;
         }
-        if (!dim_added) {
-            u4_shape.push_back(8);
+
+        OutputVector outputs_2(abs_node->get_output_size());
+        if (!abs_node->constant_fold(outputs_2, abs_in_const->outputs())) {
+            return false;
         }
 
-        auto new_const = std::make_shared<v0::Constant>(element::u4, u4_shape);
-        auto dst = const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(new_const->get_data_ptr()));
+        OutputVector outputs_3(broadcast_node->get_output_size());
+        OutputVector broadcast_inputs(2);
+        broadcast_inputs[0] = outputs_1[0];
+        broadcast_inputs[1] = outputs_2[0];
+        if (!broadcast_node->constant_fold(outputs_3, broadcast_inputs)) {
+            return false;
+        }
+
+        OutputVector outputs_4(unsqueeze_2_node->get_output_size());
+        OutputVector unsqueeze_2_inputs(2);
+        unsqueeze_2_inputs[0] = unsqueeze_2_in0_const->outputs()[0];
+        unsqueeze_2_inputs[1] = unsqueeze_2_in1_const->outputs()[0];
+        if (!unsqueeze_2_node->constant_fold(outputs_4, unsqueeze_2_inputs)) {
+            return false;
+        }
+        const int32_t* rs_in0 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_3[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
+        const int32_t* rs_in1 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_4[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
+        auto shifted_const = std::make_shared<v0::Constant>(element::i32, outputs_3[0].get_shape());
+        auto dst = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(shifted_const->get_data_ptr()));
         if (!dst)
             return false;
 
-        size_t in_idx = 0;
-        for (size_t y = 0; y < size_y; y++) {
-            size_t offset = y * stride * 8;
-            for (size_t x = 0; x < stride; x++) {
-                for (size_t z = 0; z < 8; z++) {
-                    uint32_t val = read_u4_data(src, in_idx);
-                    write_u4_data(dst, (offset + x + stride * z), val);
-                    in_idx++;
-                }
+        // TODO: Bitwise right shift operation below might need to be
+        // optimized to reduce FIL.
+        size_t rs_in0_shape_size = shape_size(outputs_3[0].get_shape());
+        const auto& rs_in0_shape = outputs_3[0].get_shape();
+        const auto& rs_in1_shape = outputs_4[0].get_shape();
+        int shift_dim = -1;
+        size_t shift_offset = 1;
+        for (size_t i = 0; i < rs_in1_shape.size(); ++i) {
+            size_t dim = rs_in1_shape[i];
+            if (dim != 1 && dim != rs_in0_shape[i]) {
+                return false;
+            }
+            if (shift_dim != -1) {
+                shift_offset *= rs_in0_shape[i];
+            }
+            if (dim == rs_in0_shape[i]) {
+                shift_dim = static_cast<int>(i);
+            }
+        }
+        if (shift_dim == -1)
+            return false;
+        for (size_t k = 0; k < rs_in0_shape_size; ++k) {
+            size_t shift_idx = (k / shift_offset) % rs_in1_shape[shift_dim];
+            int32_t shift_val = rs_in1[shift_idx];
+            dst[k] = (rs_in0[k] >> shift_val);
+        }
+
+        std::shared_ptr<ov::Node> convert_1_node = nullptr;
+        OutputVector outputs_7;
+        if (pattern_map.find(convert_1) != pattern_map.end()) {
+            convert_1_node = pattern_map.at(convert_1).get_node_shared_ptr();
+            outputs_7.resize(convert_1_node->get_output_size());
+            if (!convert_1_node->constant_fold(outputs_7, shifted_const->outputs())) {
+                return false;
+            }
+        } else {
+            auto convert_3_node = pattern_map.at(convert_3).get_node_shared_ptr();
+            auto convert_4_node = pattern_map.at(convert_4).get_node_shared_ptr();
+            auto convert_4_in_const =
+                std::dynamic_pointer_cast<v0::Constant>(convert_4_node->get_input_node_shared_ptr(0));
+            auto add_node = pattern_map.at(add).get_node_shared_ptr();
+            OutputVector outputs_5(convert_3_node->get_output_size());
+            if (!convert_3_node->constant_fold(outputs_5, shifted_const->outputs())) {
+                return false;
+            }
+            OutputVector outputs_6(convert_4_node->get_output_size());
+            if (!convert_4_node->constant_fold(outputs_6, convert_4_in_const->outputs())) {
+                return false;
+            }
+            outputs_7.resize(add_node->get_output_size());
+            OutputVector add_inputs(2);
+            add_inputs[0] = outputs_5[0];
+            add_inputs[1] = outputs_6[0];
+            if (!add_node->constant_fold(outputs_7, add_inputs)) {
+                return false;
             }
         }
 
-        copy_runtime_info_and_name(weights_u32, {new_const}, {weights_u32, bitwise_and});
+        auto convert_2_node = pattern_map.at(convert_2).get_node_shared_ptr();
+        auto convert_2_in_const = std::dynamic_pointer_cast<v0::Constant>(convert_2_node->get_input_node_shared_ptr(0));
+
+        OutputVector outputs_8(convert_2_node->get_output_size());
+        if (!convert_2_node->constant_fold(outputs_8, convert_2_in_const->outputs())) {
+            return false;
+        }
+
+        OutputVector outputs_9(bitwise_and->get_output_size());
+
+        const int8_t* and_in0 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_7[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
+        const int8_t* and_in1 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_8[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
+        auto masked_const = std::make_shared<v0::Constant>(element::i8, outputs_7[0].get_shape());
+        auto masked_dst = const_cast<int8_t*>(reinterpret_cast<const int8_t*>(masked_const->get_data_ptr()));
+        if (!masked_dst)
+            return false;
+
+        size_t and_in0_shape_size = shape_size(outputs_7[0].get_shape());
+        // TODO: Bitwise and operation below might need to be
+        // optimized to reduce FIL.
+        int8_t mask = and_in1[0];
+        for (size_t k = 0; k < and_in0_shape_size; ++k) {
+            masked_dst[k] = (and_in0[k] & mask);
+        }
+
+        auto convert_to_u4 = std::make_shared<v0::Convert>(masked_const, element::u4);
+        OutputVector outputs_10(convert_to_u4->get_output_size());
+        if (!convert_to_u4->constant_fold(outputs_10, masked_const->outputs())) {
+            return false;
+        }
 
-        auto new_convert = std::make_shared<v0::Convert>(new_const, bitwise_and->get_output_element_type(0));
-        copy_runtime_info_and_name(bitwise_and, {new_convert}, {input_node});
+        auto new_convert =
+            std::make_shared<v0::Convert>(outputs_10[0].get_node_shared_ptr(), bitwise_and->get_output_element_type(0));
+        copy_runtime_info_and_name(bitwise_and, {new_convert}, {unsqueeze_1_node});
         replace_node(bitwise_and, new_convert);
         return true;
     };
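In short, the rewritten callback no longer repacks nibbles by hand through `write_u4_data`. It constant-folds the matched Unsqueeze/Abs/Broadcast branches to recover the packed int32 words and the per-element shift amounts, applies the right shift, optionally folds the Add branch, masks the result with the BitwiseAnd constant, and converts to u4. A rough numpy model of that arithmetic, with hypothetical values and shapes (this sketches the math only, not the OpenVINO API):

```python
import numpy as np

# Each int32 word packs eight 4-bit values, lowest nibble first.
packed = np.array([0x76543210, 0xFEDCBA98], dtype=np.int64)
shifts = np.arange(0, 32, 4)  # per-nibble shift amounts: 0, 4, ..., 28

# Broadcast the words against the shifts, shift right, then mask to
# 4 bits, mirroring dst[k] = rs_in0[k] >> shift_val followed by
# masked_dst[k] = and_in0[k] & mask in the pass above.
unpacked = (packed[:, None] >> shifts) & 0xF
print(unpacked)
# [[ 0  1  2  3  4  5  6  7]
#  [ 8  9 10 11 12 13 14 15]]
```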
tests/model_hub_tests/transformation_tests/models/gptq-torchfx-models-precommit (+1)

@@ -0,0 +1 @@
+atorsvn/TinyLlama-1.1B-Chat-v0.3-gptq-4bit,https://huggingface.co/atorsvn/TinyLlama-1.1B-Chat-v0.3-gptq-4bit
tests/model_hub_tests/transformation_tests/test_gptq_torchfx_transformations.py (+102)

@@ -0,0 +1,102 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
+import hashlib
+from openvino.frontend.pytorch.torchdynamo.execute import compiled_cache
+import models_hub_common.utils as utils
+import pytest
+import os
+
+def patch_gptq(config):
+    do_gptq_patching = False
+    config_dict = config.to_dict()
+    quantization_config = config_dict.get("quantization_config", None)
+    do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+    orig_cuda_check = torch.cuda.is_available
+    orig_post_init_model = None
+    if do_gptq_patching:
+        torch.set_default_dtype(torch.float32)
+        torch.cuda.is_available = lambda: False
+
+        from optimum.gptq import GPTQQuantizer
+
+        orig_post_init_model = GPTQQuantizer.post_init_model
+
+        def post_init_model(self, model):
+            from auto_gptq import exllama_set_max_input_length
+
+            class StoreAttr(object):
+                pass
+
+            model.quantize_config = StoreAttr()
+            model.quantize_config.desc_act = self.desc_act
+            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                model = exllama_set_max_input_length(model, self.max_input_length)
+            return model
+
+        GPTQQuantizer.post_init_model = post_init_model
+    return orig_cuda_check, orig_post_init_model
+
+def run_gptq_torchfx(tmp_path, model_id, model_link, prompt_result_pair):
+    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32)
+    cuda, post_init = patch_gptq(config)
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.float32)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        trust_remote_code=True,
+        config=config,
+        device_map='cpu',
+        torch_dtype=torch.float32
+    )
+
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=4,
+        do_sample=True,
+        temperature=0.01,
+        top_p=0.01,
+        top_k=1,
+        repetition_penalty=1.1,
+        num_beams=1,
+    )
+
+    prompt = prompt_result_pair["prompt"]
+    expected_md5 = prompt_result_pair["result_md5"]
+
+    model.model.forward = torch.compile(model.model.forward, backend="openvino", dynamic=True, fullgraph=True, options={'aot_autograd': True})
+
+    result_ov = pipe(prompt)
+    md5_ov = hashlib.new("md5", result_ov[0]['generated_text'].encode(), usedforsecurity=False).hexdigest()
+
+    u4_ops = ["FullyConnected",]
+    num_u4_ops = 0
+    num_u4_ops_supported = 0
+    for pid in compiled_cache:
+        for op in compiled_cache[pid].get_runtime_model().get_ordered_ops():
+            if (str(op.get_rt_info()["layerType"].get()) in u4_ops):
+                u4_exec = (str(op.get_rt_info()["runtimePrecision"].get()) == "u4")
+                if u4_exec:
+                    num_u4_ops_supported += 1
+                num_u4_ops += 1
+
+    assert(expected_md5 == md5_ov), "Output does not match with the expected output"
+    assert((num_u4_ops > 0) and (num_u4_ops == num_u4_ops_supported)), "Runtime precision is not u4"
+
+@pytest.mark.precommit
+@pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "gptq-torchfx-models-precommit")))
+@pytest.mark.parametrize('prompt_result_pair', ([
+    {"prompt" : "Tell me about AI", "result_md5" : "4385ccbce14627ae91f846b4c8a3f145"},
+]))
+def test_gptq_torchfx_precommit(tmp_path, model_name, model_link, mark, reason, prompt_result_pair, ie_device):
+    assert mark is None or mark == 'skip' or mark == 'xfail', \
+        "Incorrect test case: {}, {}".format(model_name, model_link)
+    if mark == 'skip':
+        pytest.skip(reason)
+    elif mark == 'xfail':
+        pytest.xfail(reason)
+    run_gptq_torchfx(tmp_path, model_name, model_link, prompt_result_pair)
+
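For a local run outside CI, something like the following should work from the model hub tests directory, assuming the torch 2.3.1 pin from the workflow step above is installed; `pytest.main` stands in for the shell command used in CI:

```python
import pytest

# Hypothetical local invocation mirroring the CI step; the path assumes
# the transformation_tests layout referenced by the workflow.
pytest.main([
    "transformation_tests/test_gptq_torchfx_transformations.py",
    "-m", "precommit",
    "-v", "--tb=short",
])
```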
