Skip to content

Commit 6f98a27

Browse files
NPUW: Adding a new dcoff pattern (#25938)
### Details: - Implemented a new pattern in continuation of the PR: [PR:2587](#25827). ### Tickets: - *121052* Co-authored-by: Dmitry Matveev <dmitry.matveev@intel.com>
1 parent cf6cb43 commit 6f98a27

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -1624,6 +1624,9 @@ void Partitioner::decompressionCutOff(const std::string& func_name) {
16241624
// LLaMaGPTQ
16251625
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape2>(dcoff_mode, dcoff_type, std::ref(params_to));
16261626

1627+
// Phi-3 4SymW16A/GPTQ
1628+
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassCWAI3>(dcoff_mode, dcoff_type, std::ref(params_to));
1629+
16271630
rewr.run_on_model(f._model);
16281631

16291632
ov::pass::Validate val;

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.cpp

+83
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,89 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco
512512
register_matcher(std::make_shared<opp::Matcher>(reshpe, "TagDCOFFReshape2"), std::move(callback));
513513
}
514514

515+
// Pattern: Phi-3 4SymW16A/GPTQ
516+
//
517+
//
518+
// "tensor" "scale" > "tensor"
519+
// Param:A Param:C > Param:A
520+
// i4 f16|f32 > f16
521+
// : : > :
522+
// V : > V
523+
// Convert : > Convert
524+
// f16|f32 : > f32
525+
// : : >
526+
// V V >
527+
// Multiply >
528+
// f16|f32 >
529+
// : >
530+
// : >
531+
// V >
532+
// Convert
533+
534+
DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
535+
auto paramA = opp::wrap_type<ov::op::v0::Parameter>();
536+
auto paramC = opp::wrap_type<ov::op::v0::Parameter>();
537+
auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA});
538+
auto mulply = opp::wrap_type<ov::op::v1::Multiply>({cvtA, paramC});
539+
auto cvt = opp::wrap_type<ov::op::v0::Convert>({mulply});
540+
541+
auto callback = [=](ov::pass::pattern::Matcher& m) {
542+
auto& node_to_output = m.get_pattern_value_map();
543+
auto matched_nodeA = node_to_output.at(paramA).get_node_shared_ptr();
544+
auto matched_nodeC = node_to_output.at(paramC).get_node_shared_ptr();
545+
546+
NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeA));
547+
NPUW_ASSERT(ov::op::util::is_parameter(matched_nodeC));
548+
549+
auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
550+
auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);
551+
552+
if (ov::element::i4 == matched_paramA->get_element_type() &&
553+
(ov::element::f16 == matched_paramC->get_element_type() ||
554+
ov::element::f32 == matched_paramC->get_element_type())) {
555+
LOG_DEBUG("Matched: " << matched_paramA << ", set element type to " << dcoff_type);
556+
matched_paramA->set_element_type(dcoff_type);
557+
558+
if (dcoff_mode == DCOffMode::CAST_SCALE) {
559+
NPUW_ASSERT(dcoff_type == ov::element::f16);
560+
561+
LOG_DEBUG("Matched: " << matched_paramC << " - parameter to remove...");
562+
LOG_BLOCK();
563+
564+
// Extra transformation here:
565+
// - remove Multiply + Intermediate Convert
566+
// - mark paramC for removal.
567+
// Convert will be reconnected to paramA directly.
568+
569+
// Record mapping from the Scale coeff parameter to the Real weight parameter
570+
pref.get().scales[matched_paramC] = matched_paramA;
571+
572+
// Disconnect Multiply and Convert from their outputs
573+
auto matched_mulply = node_to_output.at(mulply).get_node_shared_ptr();
574+
auto matched_convrt = node_to_output.at(cvtA).get_node_shared_ptr();
575+
auto drop_outputs = [](std::shared_ptr<ov::Node> node) {
576+
for (auto&& node_outputs : node->outputs()) {
577+
for (auto&& node_reader_port : node_outputs.get_target_inputs()) {
578+
node_outputs.remove_target_input(node_reader_port);
579+
}
580+
}
581+
};
582+
LOG_DEBUG("Dropping the connections...");
583+
drop_outputs(matched_mulply);
584+
drop_outputs(matched_convrt);
585+
586+
LOG_DEBUG("Reconnecting the Root...");
587+
auto matched_cvt = node_to_output.at(cvt).get_node_shared_ptr();
588+
matched_cvt->input(0).replace_source_output(matched_paramA);
589+
}
590+
LOG_DEBUG("Done");
591+
}
592+
return false; // root node hasn't changed
593+
};
594+
595+
register_matcher(std::make_shared<opp::Matcher>(cvt, "TagDCOFFPassCWAI3"), std::move(callback));
596+
}
597+
515598
//------------------------------------------------------------------------------
516599
// Pattern: 4SymW16A for CWAI
517600
//

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/dcoff.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ class DCOFFPassReshape2 : public ov::pass::MatcherPass {
129129
DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
130130
};
131131

132+
class DCOFFPassCWAI3 : public ov::pass::MatcherPass {
133+
public:
134+
DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
135+
};
136+
132137
class CWAI1 : public ov::pass::MatcherPass {
133138
public:
134139
using CPtr = std::shared_ptr<ov::op::v0::Constant>;

0 commit comments

Comments
 (0)