Skip to content

Commit 6bc933a

Browse files
author
Kadian
committed
Added a new pattern in pattern matcher
1 parent 70b8346 commit 6bc933a

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed

src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1622,6 +1622,8 @@ void Partitioner::decompressionCutOff(const std::string& func_name) {
16221622
// LLaMaGPTQ
16231623
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape2>(dcoff_mode, dcoff_type, std::ref(params_to));
16241624

1625+
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassCWAI3>(dcoff_mode, dcoff_type, std::ref(params_to));
1626+
16251627
rewr.run_on_model(f._model);
16261628

16271629
ov::pass::Validate val;

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp

+82
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,88 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco
512512
register_matcher(std::make_shared<opp::Matcher>(reshpe, "TagDCOFFReshape2"), std::move(callback));
513513
}
514514

515+
// Pattern: Phi-3 4SymW16A/GPTQ for CWAI
516+
//
517+
// FIXME: Think how it can be unified with the above
518+
//
519+
// "tensor" "scale" > "tensor"
520+
// Param:A Param:C > Param:A
521+
// i4 f16|f32 > f16
522+
// : : > :
523+
// V : > V
524+
// Convert : > Convert
525+
// f16|f32 : > f32
526+
// : : >
527+
// V V >
528+
// Multiply >
529+
// f16|f32 >
530+
// : >
531+
// : >
532+
// V >
533+
// Convert
534+
535+
// Registers a matcher rooted at the final Convert of the chain
//
//     Parameter(A) -> Convert -> Multiply(.., Parameter(C)) -> Convert
//
// When A is i4 and C is f16 or f32, the callback:
//   - retypes Parameter A to `dcoff_type`;
//   - in DCOffMode::CAST_SCALE only: records the C -> A mapping in
//     `pref.get().scales`, detaches the readers of the inner Convert and of
//     the Multiply, and feeds the root Convert from Parameter A directly.
//
// \param dcoff_mode  Decompression cut-off mode; only CAST_SCALE triggers the
//                    graph rewiring.
// \param dcoff_type  Element type assigned to the matched weight parameter
//                    (asserted to be f16 in CAST_SCALE mode).
// \param pref        Reference wrapper to the structure collecting the
//                    scale-parameter -> weight-parameter mapping.
DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
    auto paramA = opp::wrap_type<ov::op::v0::Parameter>();
    auto paramC = opp::wrap_type<ov::op::v0::Parameter>();
    auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA});
    auto mulply = opp::wrap_type<ov::op::v1::Multiply>({cvtA, paramC});
    auto cvt = opp::wrap_type<ov::op::v0::Convert>({mulply});

    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& node_to_output = m.get_pattern_value_map();
        auto matched_nodeA = node_to_output.at(paramA).get_node_shared_ptr();
        auto matched_nodeC = node_to_output.at(paramC).get_node_shared_ptr();

        // The static casts below are only valid for real Parameter nodes.
        NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeA));
        NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeC));

        auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
        auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);

        if (ov::element::i4 == matched_paramA->get_element_type() &&
            (ov::element::f16 == matched_paramC->get_element_type() ||
             ov::element::f32 == matched_paramC->get_element_type())) {
            LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << dcoff_type);
            matched_paramA->set_element_type(dcoff_type);

            if (dcoff_mode == DCOffMode::CAST_SCALE) {
                NPUW_ASSERT(dcoff_type == ov::element::f16);

                LOG_DEBUG("Matched: " << matched_paramC << " - parameter to remove...");
                LOG_BLOCK();

                // Extra transformation here:
                // - remove Multiply + Intermediate Convert
                // - mark paramC for removal.
                // Convert will be reconnected to paramA directly.

                // Record the mapping from the scale parameter to the weight
                // parameter. This must be a COPY: matched_paramA is used again
                // below to reconnect the root Convert, and a moved-from
                // shared_ptr would be empty at that point.
                pref.get().scales[matched_paramC] = matched_paramA;

                // Disconnect Multiply and Convert from their outputs
                auto matched_mulply = node_to_output.at(mulply).get_node_shared_ptr();
                auto matched_convrt = node_to_output.at(cvtA).get_node_shared_ptr();
                auto drop_outputs = [](const std::shared_ptr<ov::Node>& node) {
                    for (auto&& node_outputs : node->outputs()) {
                        for (auto&& node_reader_port : node_outputs.get_target_inputs()) {
                            node_outputs.remove_target_input(node_reader_port);
                        }
                    }
                };
                LOG_DEBUG("Dropping the connections...");
                drop_outputs(matched_mulply);
                drop_outputs(matched_convrt);

                LOG_DEBUG("Reconnecting the Root...");
                auto matched_cvt = node_to_output.at(cvt).get_node_shared_ptr();
                matched_cvt->input(0).replace_source_output(matched_paramA);
            }
            LOG_DEBUG("Done");
        }
        return false;  // root node hasn't changed
    };

    register_matcher(std::make_shared<opp::Matcher>(cvt, "TagDCOFFPassCWAI3"), std::move(callback));
}
596+
515597
//------------------------------------------------------------------------------
516598
// Pattern: 4SymW16A for CWAI
517599
//

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ class DCOFFPassReshape2 : public ov::pass::MatcherPass {
129129
DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
130130
};
131131

132+
// Matcher pass for the "Phi-3 4SymW16A/GPTQ for CWAI" pattern
// (Parameter -> Convert -> Multiply by a scale Parameter -> Convert).
// The constructor, defined in dcoff.cpp, registers the matcher and its
// callback; it takes the cut-off mode, the target element type for the
// matched weight parameter, and a reference to the collected DCOFF params.
class DCOFFPassCWAI3 : public ov::pass::MatcherPass {
133+
public:
134+
DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
135+
};
136+
132137
class CWAI1 : public ov::pass::MatcherPass {
133138
public:
134139
using CPtr = std::shared_ptr<ov::op::v0::Constant>;

0 commit comments

Comments
 (0)