Skip to content

Commit 871ab4a

Browse files
authored
[CPU]Gather node executed with f16/bf16 precision (#29239)
### Details: - *The Gather node is now executed with f16/bf16 precision: it accepts f16/bf16 input directly, without weight decompression. Otherwise, the weight decompression (Convert f16/bf16 to f32) is constant-folded and consumes a large amount of memory* - [x] - *adding test* ### Tickets: - *CVS-161854*
1 parent e2eff09 commit 871ab4a

File tree

7 files changed

+283
-31
lines changed

7 files changed

+283
-31
lines changed

src/common/transformations/include/transformations/op_conversions/convert_gather_to_compressed.hpp

+46
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,65 @@
44

55
#pragma once
66

7+
#include "openvino/pass/graph_rewrite.hpp"
78
#include "openvino/pass/matcher_pass.hpp"
89
#include "transformations_visibility.hpp"
910

1011
namespace ov {
1112
namespace pass {
1213

14+
class TRANSFORMATIONS_API CompressedGatherTransformation;
1315
class TRANSFORMATIONS_API ConvertGatherToGatherCompressed;
16+
class TRANSFORMATIONS_API MoveDecompressionAfterGather;
1417

1518
} // namespace pass
1619
} // namespace ov
1720

21+
/*
22+
* ConvertGatherToGatherCompressed transforms a Gather node with a constant weight decompression pattern (U8/NF4/U4/I4 +
22+
* Subtract + Multiply) into a GatherCompressed node, which handles decompression internally.
24+
*
25+
* Subtract_const(U8/NF4/U4/I4)
26+
* /
27+
* Weights(U8/NF4/U4/I4) Convert(F32) Weights Subtract_const Multiply_const Indices
28+
* | / (U8/NF4/U4/I4) (U8/NF4/U4/I4) (F32) (I32)
29+
* Convert(F32) Reshape(optional) \ \ / /
30+
* \ / Multiply_const(F32) ------> \ \ / /
31+
* Subtract(optional) / \ \ / /
32+
* \ Reshape(optional) \ \ / /
33+
* \ / GatherCompressed
34+
* Indices(I32) Multiply
35+
* \ /
36+
* Gather
37+
*/
1838
class ov::pass::ConvertGatherToGatherCompressed : public ov::pass::MatcherPass {
1939
public:
2040
OPENVINO_MATCHER_PASS_RTTI("ConvertGatherToGatherCompressed");
2141
ConvertGatherToGatherCompressed();
2242
};
43+
44+
/*
45+
* MoveDecompressionAfterGather transforms a Gather node with a constant weight decompression pattern (FP16/BF16 +
46+
* Convert(FP32)) into a Gather node with compressed (FP16/BF16) weights, and moves the decompression after the Gather node.
47+
*
48+
*        Weights(FP16/BF16)                       Weights(FP16/BF16)   Indices(I32)
49+
*             |                                              \           /
50+
*        Convert(F32)   Indices(I32)    ------>            Gather(FP16/BF16)
51+
*              \         /                                        |
52+
*             Gather(F32)                                    Convert(F32)
53+
*
54+
*/
55+
class ov::pass::MoveDecompressionAfterGather : public ov::pass::MatcherPass {
56+
public:
57+
OPENVINO_MATCHER_PASS_RTTI("MoveDecompressionAfterGather");
58+
MoveDecompressionAfterGather();
59+
};
60+
61+
class ov::pass::CompressedGatherTransformation : public ov::pass::GraphRewrite {
62+
public:
63+
OPENVINO_GRAPH_REWRITE_RTTI("CompressedGatherTransformation");
64+
CompressedGatherTransformation() {
65+
add_matcher<ov::pass::ConvertGatherToGatherCompressed>();
66+
add_matcher<ov::pass::MoveDecompressionAfterGather>();
67+
}
68+
};

src/common/transformations/src/transformations/op_conversions/convert_gather_to_compressed.cpp

+38
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "openvino/pass/pattern/op/pattern.hpp"
1717
#include "openvino/pass/pattern/op/wrap_type.hpp"
1818
#include "ov_ops/gather_compressed.hpp"
19+
#include "transformations/rt_info/keep_const_precision.hpp"
1920
#include "transformations/utils/utils.hpp"
2021

2122
ov::pass::ConvertGatherToGatherCompressed::ConvertGatherToGatherCompressed() {
@@ -146,3 +147,40 @@ ov::pass::ConvertGatherToGatherCompressed::ConvertGatherToGatherCompressed() {
146147
auto m = std::make_shared<ov::pass::pattern::Matcher>(gather_m, "ConvertGatherToGatherCompressed");
147148
this->register_matcher(m, callback);
148149
}
150+
151+
ov::pass::MoveDecompressionAfterGather::MoveDecompressionAfterGather() {
152+
using namespace ov::pass::pattern;
153+
154+
auto dicts = wrap_type<ov::op::v0::Constant>(pattern::type_matches_any({element::f16, element::bf16}));
155+
auto convert_predicate = [](ov::Output<ov::Node> output) -> bool {
156+
return pattern::consumers_count(1)(output) && pattern::type_matches(ov::element::f32)(output);
157+
};
158+
auto convert = wrap_type<ov::op::v0::Convert>({dicts}, convert_predicate);
159+
auto gather = wrap_type<ov::op::v8::Gather>({convert, any_input(), wrap_type<ov::op::v0::Constant>()});
160+
161+
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
162+
const auto& pattern_map = m.get_pattern_value_map();
163+
auto constant_node = pattern_map.at(dicts).get_node_shared_ptr();
164+
auto convert_node = pattern_map.at(convert).get_node_shared_ptr();
165+
auto gather_node = pattern_map.at(gather).get_node_shared_ptr();
166+
if (transformation_callback(gather_node)) {
167+
return false;
168+
}
169+
170+
auto new_gather = gather_node->clone_with_new_inputs(
171+
{constant_node, gather_node->get_input_source_output(1), gather_node->get_input_source_output(2)});
172+
auto new_convert = convert_node->clone_with_new_inputs({new_gather});
173+
register_new_node(new_gather);
174+
register_new_node(new_convert);
175+
176+
ov::enable_keep_const_precision(constant_node);
177+
178+
new_convert->set_friendly_name(gather_node->get_friendly_name());
179+
ov::copy_runtime_info({convert_node, gather_node}, {new_gather, new_convert});
180+
replace_node(gather_node, new_convert);
181+
return true;
182+
};
183+
184+
auto m = std::make_shared<ov::pass::pattern::Matcher>(gather, "MoveDecompressionAfterGather");
185+
this->register_matcher(m, callback);
186+
}

src/common/transformations/tests/op_conversions/convert_gather_to_compressed_test.cpp

+47
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,50 @@ TEST_F(TransformationTestsF, ConvertGatherToCompressedMultiOutput) {
225225
model_ref = std::make_shared<ov::Model>(ov::NodeVector{gather_compressed}, ov::ParameterVector{input1, input2});
226226
}
227227
}
228+
229+
// In the compressed FP16/BF16 weight case, a Gather node with a constant weight decompression pattern (FP16/BF16 +
230+
// Convert(FP32)) is transformed into a Gather node with compressed (FP16/BF16) weights, and the decompression Convert
231+
// is moved after the Gather node, so a GatherCompressed node should not be generated.
232+
TEST_F(TransformationTestsF, MoveDecompressionAfterGatherFP16Weight) {
233+
{
234+
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1, 16});
235+
auto axis_const = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1});
236+
auto weights_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{32, 16}, {1});
237+
auto convert = std::make_shared<ov::op::v0::Convert>(weights_const, ov::element::f32);
238+
auto gather = std::make_shared<ov::op::v8::Gather>(convert, input1, axis_const);
239+
240+
model = std::make_shared<ov::Model>(ov::NodeVector{gather}, ov::ParameterVector{input1});
241+
manager.register_pass<MoveDecompressionAfterGather>();
242+
}
243+
{
244+
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1, 16});
245+
auto axis_const = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1});
246+
auto weights_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{32, 16}, {1});
247+
auto gather = std::make_shared<ov::op::v8::Gather>(weights_const, input1, axis_const);
248+
auto convert = std::make_shared<ov::op::v0::Convert>(gather, ov::element::f32);
249+
250+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{convert}, ov::ParameterVector{input1});
251+
}
252+
}
253+
254+
TEST_F(TransformationTestsF, MoveDecompressionAfterGatherBF16Weight) {
255+
{
256+
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1, 16});
257+
auto axis_const = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1});
258+
auto weights_const = ov::op::v0::Constant::create(ov::element::bf16, ov::Shape{32, 16}, {1});
259+
auto convert = std::make_shared<ov::op::v0::Convert>(weights_const, ov::element::f32);
260+
auto gather = std::make_shared<ov::op::v8::Gather>(convert, input1, axis_const);
261+
262+
model = std::make_shared<ov::Model>(ov::NodeVector{gather}, ov::ParameterVector{input1});
263+
manager.register_pass<MoveDecompressionAfterGather>();
264+
}
265+
{
266+
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1, 16});
267+
auto axis_const = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1});
268+
auto weights_const = ov::op::v0::Constant::create(ov::element::bf16, ov::Shape{32, 16}, {1});
269+
auto gather = std::make_shared<ov::op::v8::Gather>(weights_const, input1, axis_const);
270+
auto convert = std::make_shared<ov::op::v0::Convert>(gather, ov::element::f32);
271+
272+
model_ref = std::make_shared<ov::Model>(ov::NodeVector{convert}, ov::ParameterVector{input1});
273+
}
274+
}

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
359359
decompression_handling_manager.set_per_pass_validation(false);
360360
CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::InitNodeInfo);
361361
const bool useLpt = !defaultPrecisions.empty();
362-
CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::ConvertGatherToGatherCompressed);
362+
CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::CompressedGatherTransformation);
363363
CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::MarkShapeOfSubgraphs);
364364
// We need to fuse Transpose to MatMul to have a simpler callback for the next transformation
365365
CPU_REGISTER_PASS_X64(decompression_handling_manager, ov::pass::TransposeMatMul);

src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/gather_weights_decompression.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_GatherCompressedWeights_basic,
4545
::testing::ValuesIn(per_tensor_scale)),
4646
GatherWeightsDecompression::get_test_case_name);
4747

48+
// FP16/BF16 constant + Convert (16-bit to F32) + Gather case
49+
TEST_P(GatherWeightsDecompressionWithoutScale, CompareWithRefs) {
50+
SKIP_IF_CURRENT_TEST_IS_DISABLED()
51+
run();
52+
check_results();
53+
}
54+
const std::vector<ov::element::Type> weights_precisions_wo_scale = {ov::element::f16, ov::element::bf16};
55+
const std::vector<ov::element::Type> output_precisions_wo_scale = {ov::element::f32, ov::element::f16, ov::element::bf16};
56+
57+
INSTANTIATE_TEST_SUITE_P(smoke_GatherCompressedWeightsWithoutScale_basic,
58+
GatherWeightsDecompressionWithoutScale,
59+
::testing::Combine(::testing::Values(ov::test::utils::DEVICE_CPU),
60+
::testing::ValuesIn(input_shapes_basic),
61+
::testing::ValuesIn(weights_precisions_wo_scale),
62+
::testing::ValuesIn(output_precisions_wo_scale)),
63+
GatherWeightsDecompressionWithoutScale::get_test_case_name);
64+
4865
} // namespace

src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/gather_weights_decompression.hpp

+35-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
namespace ov {
1212
namespace test {
1313

14+
class GatherWeightsDecompressionBase : virtual public ov::test::SubgraphBaseTest {
15+
protected:
16+
void generate_inputs(const std::vector<ov::Shape>& target_input_static_shapes) override;
17+
void check_results(const ov::element::Type& weights_precision, const size_t& num_exec_ops_expect);
18+
};
19+
1420
/*
1521
* Subtract_const(U8/NF4/U4/I4)
1622
* /
@@ -36,7 +42,7 @@ using GatherWeightsDecompressionParams = std::tuple<std::string, // Devic
3642
bool>; // per-tensor zero-point
3743

3844
class GatherWeightsDecompression : public testing::WithParamInterface<GatherWeightsDecompressionParams>,
39-
virtual public ov::test::SubgraphBaseTest {
45+
virtual public ov::test::GatherWeightsDecompressionBase {
4046
public:
4147
static std::string get_test_case_name(testing::TestParamInfo<GatherWeightsDecompressionParams> obj);
4248

@@ -52,10 +58,37 @@ class GatherWeightsDecompression : public testing::WithParamInterface<GatherWeig
5258
const bool reshape_on_decompression,
5359
const bool per_tensor_zp,
5460
const bool per_tensor_scale);
55-
void generate_inputs(const std::vector<ov::Shape>& target_input_static_shapes) override;
5661
void check_results();
5762
void SetUp() override;
5863
};
5964

65+
/*
66+
*        Weights(FP16/BF16)
67+
*             |
68+
*        Convert(F32)   Indices(I32)
69+
*              \         /
70+
*               Gather
71+
*/
72+
using GatherWeightsDecompressionWithoutScaleParams = std::tuple<std::string, // Device name
73+
GatherDecompressionShapeParams,
74+
ov::element::Type, // data type
75+
ov::element::Type>; // output type
76+
77+
class GatherWeightsDecompressionWithoutScale
78+
: public testing::WithParamInterface<GatherWeightsDecompressionWithoutScaleParams>,
79+
virtual public ov::test::GatherWeightsDecompressionBase {
80+
public:
81+
static std::string get_test_case_name(testing::TestParamInfo<GatherWeightsDecompressionWithoutScaleParams> obj);
82+
83+
protected:
84+
std::shared_ptr<ov::Model> init_subgraph(const ov::Shape& data_shape,
85+
const ov::PartialShape& indices_shape,
86+
const int axis,
87+
const int64_t batch_dims,
88+
const ov::element::Type data_precision,
89+
const ov::element::Type output_precision);
90+
void check_results();
91+
void SetUp() override;
92+
};
6093
} // namespace test
6194
} // namespace ov

0 commit comments

Comments
 (0)