@@ -40,18 +40,6 @@ uint32_t read_u4_data(const void* array, size_t index) {
    return val;
};

-void write_u4_data(void* array, size_t index, uint32_t data) {
-    auto arr_u32 = reinterpret_cast<uint32_t*>(array);
-    size_t idx_u32 = index / 8;
-    size_t offset_u32 = index % 8;
-    uint32_t old_val = arr_u32[idx_u32];
-    data = data << (offset_u32 * 4);
-    uint32_t mask = 15;
-    mask = ~(mask << (offset_u32 * 4));
-    uint32_t new_val = (old_val & mask) | data;
-    arr_u32[idx_u32] = new_val;
-};
-
GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
    const auto& const_1 = wrap_type<v0::Constant>();
    const auto& const_2 = wrap_type<v0::Constant>();
@@ -73,61 +61,157 @@ GPTQDecompressionReplacer::GPTQDecompressionReplacer() {
    const auto& convert_2 = wrap_type<v0::Convert>({const_6});
    const auto& bitwise_and = wrap_type<ov::op::v13::BitwiseAnd>({add_or_convert, convert_2});

-    ov::matcher_pass_callback callback = [unsqueeze_1](Matcher& m) {
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
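+        // Note: the callback now captures [=] because it needs every pattern node
+        // (unsqueeze_1/2, abs, broadcast, the converts, add), not just unsqueeze_1.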
        auto bitwise_and = m.get_match_root();
        if (!bitwise_and) {
            return false;
        }
        const auto& pattern_map = m.get_pattern_value_map();
-        const auto& input_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
-        auto weights_u32 = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(0));
-        auto axis = std::dynamic_pointer_cast<v0::Constant>(input_node->get_input_node_shared_ptr(1));
-        auto axis_data = axis->get_data_ptr<uint32_t>();
-
-        auto u8_shape = weights_u32->get_shape();
-        auto src = weights_u32->get_data_ptr<uint32_t>();
-
-        ov::Shape u4_shape;
-        bool dim_added = false;
-        size_t stride = 1;
-        size_t size_y = 1;
-        for (size_t i = 0; i < u8_shape.size(); i++) {
-            if (axis_data[0] == i) {
-                u4_shape.push_back(8);
-                dim_added = true;
-            }
-            if (axis_data[0] <= i) {
-                stride *= u8_shape[i];
-            } else {
-                size_y *= u8_shape[i];
-            }
-            u4_shape.push_back(u8_shape[i]);
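+        // Fetch the matched nodes and their Constant inputs from the pattern map.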
+        auto unsqueeze_1_node = pattern_map.at(unsqueeze_1).get_node_shared_ptr();
+        auto unsqueeze_1_in0_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(0));
+        auto unsqueeze_1_in1_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_1_node->get_input_node_shared_ptr(1));
+        auto abs_node = pattern_map.at(abs).get_node_shared_ptr();
+        auto abs_in_const = std::dynamic_pointer_cast<v0::Constant>(abs_node->get_input_node_shared_ptr(0));
+        auto broadcast_node = pattern_map.at(broadcast).get_node_shared_ptr();
+        auto unsqueeze_2_node = pattern_map.at(unsqueeze_2).get_node_shared_ptr();
+        auto unsqueeze_2_in0_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(0));
+        auto unsqueeze_2_in1_const =
+            std::dynamic_pointer_cast<v0::Constant>(unsqueeze_2_node->get_input_node_shared_ptr(1));
+
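+        // Constant-fold the first Unsqueeze over its two Constant inputs.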
+        OutputVector outputs_1(unsqueeze_1_node->get_output_size());
+        OutputVector unsqueeze_1_inputs(2);
+        unsqueeze_1_inputs[0] = unsqueeze_1_in0_const->outputs()[0];
+        unsqueeze_1_inputs[1] = unsqueeze_1_in1_const->outputs()[0];
+        if (!unsqueeze_1_node->constant_fold(outputs_1, unsqueeze_1_inputs)) {
+            return false;
        }
-        if (!dim_added) {
-            u4_shape.push_back(8);
+
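+        // Constant-fold the Abs node over its Constant input.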
+        OutputVector outputs_2(abs_node->get_output_size());
+        if (!abs_node->constant_fold(outputs_2, abs_in_const->outputs())) {
+            return false;
        }

-        auto new_const = std::make_shared<v0::Constant>(element::u4, u4_shape);
-        auto dst = const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(new_const->get_data_ptr()));
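+        // Broadcast the unsqueezed weights to the shape produced by Abs.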
+        OutputVector outputs_3(broadcast_node->get_output_size());
+        OutputVector broadcast_inputs(2);
+        broadcast_inputs[0] = outputs_1[0];
+        broadcast_inputs[1] = outputs_2[0];
+        if (!broadcast_node->constant_fold(outputs_3, broadcast_inputs)) {
+            return false;
+        }
+
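+        // Constant-fold the second Unsqueeze, which yields the shift amounts
+        // applied to the packed weights below.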
+        OutputVector outputs_4(unsqueeze_2_node->get_output_size());
+        OutputVector unsqueeze_2_inputs(2);
+        unsqueeze_2_inputs[0] = unsqueeze_2_in0_const->outputs()[0];
+        unsqueeze_2_inputs[1] = unsqueeze_2_in1_const->outputs()[0];
+        if (!unsqueeze_2_node->constant_fold(outputs_4, unsqueeze_2_inputs)) {
+            return false;
+        }
+        const int32_t* rs_in0 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_3[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
+        const int32_t* rs_in1 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_4[0].get_node_shared_ptr())->get_data_ptr<int32_t>();
+        auto shifted_const = std::make_shared<v0::Constant>(element::i32, outputs_3[0].get_shape());
+        auto dst = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(shifted_const->get_data_ptr()));
        if (!dst)
            return false;

-        size_t in_idx = 0;
-        for (size_t y = 0; y < size_y; y++) {
-            size_t offset = y * stride * 8;
-            for (size_t x = 0; x < stride; x++) {
-                for (size_t z = 0; z < 8; z++) {
-                    uint32_t val = read_u4_data(src, in_idx);
-                    write_u4_data(dst, (offset + x + stride * z), val);
-                    in_idx++;
-                }
+        // TODO: Bitwise right shift operation below might need to be
+        // optimized to reduce FIL.
+        size_t rs_in0_shape_size = shape_size(outputs_3[0].get_shape());
+        const auto& rs_in0_shape = outputs_3[0].get_shape();
+        const auto& rs_in1_shape = outputs_4[0].get_shape();
+        int shift_dim = -1;
+        size_t shift_offset = 1;
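+        // Locate the axis along which the shift values vary (every dimension of
+        // the shift tensor must be 1 or match the data shape), and the stride
+        // between consecutive elements that share a shift value.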
+        for (size_t i = 0; i < rs_in1_shape.size(); ++i) {
+            size_t dim = rs_in1_shape[i];
+            if (dim != 1 && dim != rs_in0_shape[i]) {
+                return false;
+            }
+            if (shift_dim != -1) {
+                shift_offset *= rs_in0_shape[i];
+            }
+            if (dim == rs_in0_shape[i]) {
+                shift_dim = static_cast<int>(i);
+            }
+        }
+        if (shift_dim == -1)
+            return false;
+        for (size_t k = 0; k < rs_in0_shape_size; ++k) {
+            size_t shift_idx = (k / shift_offset) % rs_in1_shape[shift_dim];
+            int32_t shift_val = rs_in1[shift_idx];
+            dst[k] = (rs_in0[k] >> shift_val);
+        }
+
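+        // The matched graph either applies a single Convert to the shifted
+        // values (convert_1 branch) or a Convert plus an Add with another
+        // folded Constant; fold whichever alternative was matched.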
+        std::shared_ptr<ov::Node> convert_1_node = nullptr;
+        OutputVector outputs_7;
+        if (pattern_map.find(convert_1) != pattern_map.end()) {
+            convert_1_node = pattern_map.at(convert_1).get_node_shared_ptr();
+            outputs_7.resize(convert_1_node->get_output_size());
+            if (!convert_1_node->constant_fold(outputs_7, shifted_const->outputs())) {
+                return false;
+            }
+        } else {
+            auto convert_3_node = pattern_map.at(convert_3).get_node_shared_ptr();
+            auto convert_4_node = pattern_map.at(convert_4).get_node_shared_ptr();
+            auto convert_4_in_const =
+                std::dynamic_pointer_cast<v0::Constant>(convert_4_node->get_input_node_shared_ptr(0));
+            auto add_node = pattern_map.at(add).get_node_shared_ptr();
+            OutputVector outputs_5(convert_3_node->get_output_size());
+            if (!convert_3_node->constant_fold(outputs_5, shifted_const->outputs())) {
+                return false;
+            }
+            OutputVector outputs_6(convert_4_node->get_output_size());
+            if (!convert_4_node->constant_fold(outputs_6, convert_4_in_const->outputs())) {
+                return false;
+            }
+            outputs_7.resize(add_node->get_output_size());
+            OutputVector add_inputs(2);
+            add_inputs[0] = outputs_5[0];
+            add_inputs[1] = outputs_6[0];
+            if (!add_node->constant_fold(outputs_7, add_inputs)) {
+                return false;
            }
        }

-        copy_runtime_info_and_name(weights_u32, {new_const}, {weights_u32, bitwise_and});
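+        // Fold the mask Constant feeding convert_2, then apply it with a
+        // scalar bitwise AND over the shifted values.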
+        auto convert_2_node = pattern_map.at(convert_2).get_node_shared_ptr();
+        auto convert_2_in_const = std::dynamic_pointer_cast<v0::Constant>(convert_2_node->get_input_node_shared_ptr(0));
+
+        OutputVector outputs_8(convert_2_node->get_output_size());
+        if (!convert_2_node->constant_fold(outputs_8, convert_2_in_const->outputs())) {
+            return false;
+        }
+
+        OutputVector outputs_9(bitwise_and->get_output_size());
+
+        const int8_t* and_in0 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_7[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
+        const int8_t* and_in1 =
+            std::dynamic_pointer_cast<v0::Constant>(outputs_8[0].get_node_shared_ptr())->get_data_ptr<int8_t>();
+        auto masked_const = std::make_shared<v0::Constant>(element::i8, outputs_7[0].get_shape());
+        auto masked_dst = const_cast<int8_t*>(reinterpret_cast<const int8_t*>(masked_const->get_data_ptr()));
+        if (!masked_dst)
+            return false;
+
+        size_t and_in0_shape_size = shape_size(outputs_7[0].get_shape());
+        // TODO: Bitwise and operation below might need to be
+        // optimized to reduce FIL.
+        int8_t mask = and_in1[0];
+        for (size_t k = 0; k < and_in0_shape_size; ++k) {
+            masked_dst[k] = (and_in0[k] & mask);
+        }
+
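+        // Convert the masked i8 values to a compact u4 constant.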
+        auto convert_to_u4 = std::make_shared<v0::Convert>(masked_const, element::u4);
+        OutputVector outputs_10(convert_to_u4->get_output_size());
+        if (!convert_to_u4->constant_fold(outputs_10, masked_const->outputs())) {
+            return false;
+        }

-        auto new_convert = std::make_shared<v0::Convert>(new_const, bitwise_and->get_output_element_type(0));
-        copy_runtime_info_and_name(bitwise_and, {new_convert}, {input_node});
+        auto new_convert =
+            std::make_shared<v0::Convert>(outputs_10[0].get_node_shared_ptr(), bitwise_and->get_output_element_type(0));
+        copy_runtime_info_and_name(bitwise_and, {new_convert}, {unsqueeze_1_node});
        replace_node(bitwise_and, new_convert);
        return true;
    };