Skip to content

Commit c8cee60

Browse files
committed
[CPU][ARM] Weights compression f32->f16 is moved to CPU Plug-in side
1 parent 7cb3bf5 commit c8cee60

File tree

11 files changed

+83
-10
lines changed

11 files changed

+83
-10
lines changed

samples/cpp/benchmark_app/main.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ void fuse_mean_scale(ov::preprocess::PrePostProcessor& preproc, const benchmark_
233233
* @brief The entry point of the benchmark application
234234
*/
235235
int main(int argc, char* argv[]) {
236-
std::shared_ptr<StatisticsReport> statistics;
236+
std::shared_ptr<StatisticsReport> statistics;
237237
try {
238238
ov::CompiledModel compiledModel;
239239

src/common/transformations/include/transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace pass {
1414
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
1515
class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding;
1616
class TRANSFORMATIONS_API KeepConstAndDecompression;
17+
class TRANSFORMATIONS_API KeepConstFP32Unfolded;
1718
class TRANSFORMATIONS_API KeepConstantsPrecisionAndAddConverts;
1819

1920
} // namespace pass
@@ -49,6 +50,12 @@ class ov::pass::KeepConstAndDecompression : public MatcherPass {
4950
KeepConstAndDecompression();
5051
};
5152

53+
class ov::pass::KeepConstFP32Unfolded : public MatcherPass {
54+
public:
55+
OPENVINO_RTTI("KeepConstFP32Unfolded", "0");
56+
KeepConstFP32Unfolded();
57+
};
58+
5259
/**
5360
* @ingroup ie_transformation_common_api
5461
* @brief Prevents Consts precision conversion and adds Convert with disabled ConstantFolding

src/common/transformations/include/transformations/rt_info/decompression.hpp

+21
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ TRANSFORMATIONS_API void unmark_as_decompression(const std::shared_ptr<Node>& no
2323

2424
TRANSFORMATIONS_API bool is_decompression(const std::shared_ptr<Node>& node);
2525

26+
TRANSFORMATIONS_API void mark_as_compression(const std::shared_ptr<Node>& node);
27+
28+
TRANSFORMATIONS_API void unmark_as_compression(const std::shared_ptr<Node>& node);
29+
30+
TRANSFORMATIONS_API bool is_compression(const std::shared_ptr<Node>& node);
31+
2632
/**
2733
* @ingroup ie_runtime_attr_api
2834
* @brief Decompression class represents runtime info attribute that marks operation
@@ -43,4 +49,19 @@ class TRANSFORMATIONS_API Decompression : public RuntimeAttribute {
4349
}
4450
};
4551

52+
class TRANSFORMATIONS_API Compression : public RuntimeAttribute {
53+
public:
54+
OPENVINO_RTTI("Compression", "0");
55+
56+
Compression() = default;
57+
58+
bool visit_attributes(AttributeVisitor& visitor) override {
59+
return true;
60+
}
61+
62+
bool is_copyable() const override {
63+
return false;
64+
}
65+
};
66+
4667
} // namespace ov

src/common/transformations/src/transformations/fp16_compression/align_mixed_fp32_fp16_types.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "openvino/op/util/precision_sensitive_attribute.hpp"
1212
#include "openvino/pass/constant_folding.hpp"
1313
#include "transformations/rt_info/disable_fp16_compression.hpp"
14+
#include "transformations/rt_info/decompression.hpp"
1415

1516
using namespace ov;
1617

@@ -48,6 +49,7 @@ bool ov::pass::AlignMixedFP32FP16Types::run_on_model(const std::shared_ptr<ov::M
4849
copy_runtime_info(incoming_node, convert);
4950
input.replace_source_output(convert);
5051
disable_fp16_compression(convert);
52+
mark_as_compression(convert);
5153
pass::disable_constant_folding(convert);
5254
is_changed = true;
5355
}
@@ -76,6 +78,7 @@ bool ov::pass::AlignMixedFP32FP16Types::run_on_model(const std::shared_ptr<ov::M
7678
auto init_name = node->get_friendly_name() + "_compressed_to_f16";
7779
convert->set_friendly_name(generate_uniq_name(init_name));
7880
out_inputs.replace_source_output(convert);
81+
mark_as_compression(convert);
7982
pass::disable_constant_folding(convert);
8083
is_changed = true;
8184
}

src/common/transformations/src/transformations/fp16_compression/mark_decompression_convert_constant_folding.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,32 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() {
7777
register_matcher(m, callback);
7878
}
7979

80+
pass::KeepConstFP32Unfolded::KeepConstFP32Unfolded() {
81+
MATCHER_SCOPE(KeepConstFP32Unfolded);
82+
83+
auto node_pattern = pattern::wrap_type<ov::op::v0::MatMul>();
84+
85+
matcher_pass_callback callback = [=](pattern::Matcher& m) {
86+
auto node = m.get_match_root();
87+
88+
if (transformation_callback(node)) {
89+
return false;
90+
}
91+
92+
auto constNode = node->get_input_node_shared_ptr(1);
93+
if (!is_type<ov::op::v0::Constant>(constNode) || constNode->get_output_element_type(0) != element::f32)
94+
return false;
95+
96+
disable_constant_folding(constNode);
97+
enable_keep_const_precision(constNode);
98+
disable_fp16_compression(constNode);
99+
100+
return false;
101+
};
102+
auto m = std::make_shared<pattern::Matcher>(node_pattern, matcher_name);
103+
register_matcher(m, callback);
104+
}
105+
80106
pass::KeepConstantsPrecisionAndAddConverts::KeepConstantsPrecisionAndAddConverts() {
81107
MATCHER_SCOPE(KeepConstantsPrecisionAndAddConverts);
82108
auto const_pattern = pattern::wrap_type<ov::op::v0::Constant>();

src/common/transformations/src/transformations/fp16_compression/mark_subgraphs_to_keep_in_mixed_precision.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -432,16 +432,16 @@ bool MarkSugraphsToKeepInMixedPrecision::run_on_model(const shared_ptr<ov::Model
432432
Manager manager(get_pass_config());
433433
// Mark root of Division with eps pattern to keep in FP32
434434
REGISTER_PASS(manager, MarkDivWithEps)
435-
REGISTER_PASS(manager, MarkExpInReduceOpPath)
436-
REGISTER_PASS(manager, PropagateDownDisableSensitivityForQuantized)
437-
435+
REGISTER_PASS(manager, MarkExpInReduceOpPath)
436+
REGISTER_PASS(manager, PropagateDownDisableSensitivityForQuantized)
437+
438438
// both Up and Down propagations are needed.
439439
// Why both of them are needed is explained in comments in passes declarations.
440440
REGISTER_PASS(manager, PropagateDownMarkToKeepInMixedPrecision)
441-
441+
442442
auto propagate_up = manager.register_pass<BackwardGraphRewrite>();
443443
ADD_MATCHER(propagate_up, PropagateUpMarkToKeepInMixedPrecision)
444-
444+
445445
// Mark nodes in ShapeOf subgraphs to keep in FP32
446446
REGISTER_PASS(manager, MarkPrecisionSensitiveShapeOfSubgraphs)
447447
manager.run_passes(m);

src/common/transformations/src/transformations/rt_info/decompression.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,18 @@ bool ov::is_decompression(const std::shared_ptr<Node>& node) {
1818
const auto& rt_info = node->get_rt_info();
1919
return rt_info.count(Decompression::get_type_info_static());
2020
}
21+
22+
void ov::mark_as_compression(const std::shared_ptr<Node>& node) {
23+
auto& rt_info = node->get_rt_info();
24+
rt_info[Compression::get_type_info_static()] = Compression();
25+
}
26+
27+
void ov::unmark_as_compression(const std::shared_ptr<Node>& node) {
28+
auto& rt_info = node->get_rt_info();
29+
rt_info.erase(Compression::get_type_info_static());
30+
}
31+
32+
bool ov::is_compression(const std::shared_ptr<Node>& node) {
33+
const auto& rt_info = node->get_rt_info();
34+
return rt_info.count(Compression::get_type_info_static());
35+
}

src/plugins/intel_cpu/src/graph_optimizer.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -921,8 +921,8 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
921921
&& parent->getChildEdges().size() == 1
922922
&& parent->getChildEdgeAt(0)->getOutputNum() == 1
923923
&& parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected
924-
&& one_of(parent->getOriginalInputPrecisionAtPort(0), Precision::FP16)
925-
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16)
924+
&& one_of(parent->getOriginalInputPrecisionAtPort(0), Precision::FP32, Precision::BF16, Precision::FP16)
925+
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16, Precision::FP16)
926926
&& parent->isConstant();
927927
return res;
928928
};

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/convert_matmul_to_fc.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
3737
auto fc_input_b = pattern_map.at(weights_m);
3838
bool is_convert = false;
3939
if (auto convert_node = std::dynamic_pointer_cast<ov::op::v0::Convert>(fc_input_b.get_node_shared_ptr())) {
40-
if (is_decompression(convert_node)) {
40+
if (is_decompression(convert_node) || fp16_compression_is_disabled(convert_node) || is_compression(convert_node)) {
4141
is_convert = true;
4242
fc_input_b = convert_node->get_input_node_shared_ptr(0);
4343
} else {

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
291291
// It cannot be static data, because it may be different for different inferencePrecision
292292
const auto precisions = get_convert_precisions();
293293
if (inferencePrecision == ov::element::f16) {
294+
CPU_REGISTER_PASS_ARM(manager, ov::pass::KeepConstFP32Unfolded);
294295
precisions_map fp_convert_precision_map = {{ov::element::f32, ov::element::f16}};
295296
type_to_fuse_map empty_fuse_map = {};
296297
const bool keep_precision_sensitive_in_fp32 = true;

0 commit comments

Comments
 (0)