@@ -512,6 +512,89 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco
512
512
register_matcher (std::make_shared<opp::Matcher>(reshpe, " TagDCOFFReshape2" ), std::move (callback));
513
513
}
514
514
515
+ // Pattern: Phi-3 4SymW16A/GPTQ
516
+ //
517
+ //
518
+ // "tensor" "scale" > "tensor"
519
+ // Param:A Param:C > Param:A
520
+ // i4 f16|f32 > f16
521
+ // : : > :
522
+ // V : > V
523
+ // Convert : > Convert
524
+ // f16|f32 : > f32
525
+ // : : >
526
+ // V V >
527
+ // Multiply >
528
+ // f16|f32 >
529
+ // : >
530
+ // : >
531
+ // V >
532
+ // Convert
533
+ // Registers a matcher for Convert(Multiply(Convert(Parameter A), Parameter C)), rooted at the final Convert.
534
+ DCOFFPassCWAI3::DCOFFPassCWAI3 (DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
535
+ auto paramA = opp::wrap_type<ov::op::v0::Parameter>(); // "A": Parameter feeding the inner Convert; its type is checked in the callback
536
+ auto paramC = opp::wrap_type<ov::op::v0::Parameter>(); // "C": Parameter used as the second Multiply operand
537
+ auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA}); // Convert(A)
538
+ auto mulply = opp::wrap_type<ov::op::v1::Multiply>({cvtA, paramC}); // Convert(A) * C
539
+ auto cvt = opp::wrap_type<ov::op::v0::Convert>({mulply}); // final Convert; passed to the Matcher as the pattern root
540
+ // The callback captures dcoff_mode, dcoff_type, pref and the pattern nodes above by copy ([=]).
541
+ auto callback = [=](ov::pass::pattern::Matcher& m) {
542
+ auto & node_to_output = m.get_pattern_value_map (); // looked up below by pattern node to get the matched graph nodes
543
+ auto matched_nodeA = node_to_output.at (paramA).get_node_shared_ptr ();
544
+ auto matched_nodeC = node_to_output.at (paramC).get_node_shared_ptr ();
545
+ // Both matches must really be Parameters: the casts below are unchecked static casts.
546
+ NPUW_ASSERT (ov::op::util::is_parameter (matched_nodeA));
547
+ NPUW_ASSERT (ov::op::util::is_parameter (matched_nodeC));
548
+
549
+ auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
550
+ auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);
551
+ // Only an i4 A combined with an f16 or f32 C is processed; any other match falls through to "return false".
552
+ if (ov::element::i4 == matched_paramA->get_element_type () &&
553
+ (ov::element::f16 == matched_paramC->get_element_type () ||
554
+ ov::element::f32 == matched_paramC->get_element_type ())) {
555
+ LOG_DEBUG (" Matched: " << matched_paramA << " , set element type to " << dcoff_type);
556
+ matched_paramA->set_element_type (dcoff_type); // retype A in place; done in every mode
557
+ // CAST_SCALE additionally takes the Multiply and the inner Convert out of the data path.
558
+ if (dcoff_mode == DCOffMode::CAST_SCALE) {
559
+ NPUW_ASSERT (dcoff_type == ov::element::f16); // this mode is only accepted with an f16 target type
560
+
561
+ LOG_DEBUG (" Matched: " << matched_paramC << " - parameter to remove..." );
562
+ LOG_BLOCK (); // NOTE(review): presumably scopes/indents the log lines that follow - confirm against the logging macros
563
+
564
+ // Extra transformation here:
565
+ // - remove Multiply + Intermediate Convert
566
+ // - mark paramC for removal.
567
+ // Convert will be reconnected to paramA directly.
568
+
569
+ // Record mapping from the Scale coeff parameter to the Real weight parameter
570
+ pref.get ().scales [matched_paramC] = matched_paramA;
571
+
572
+ // Disconnect Multiply and Convert from their outputs
573
+ auto matched_mulply = node_to_output.at (mulply).get_node_shared_ptr ();
574
+ auto matched_convrt = node_to_output.at (cvtA).get_node_shared_ptr (); // the intermediate Convert (cvtA), not the root
575
+ auto drop_outputs = [](std::shared_ptr<ov::Node> node) { // detaches every reader from every output of node
576
+ for (auto && node_outputs : node->outputs ()) {
577
+ for (auto && node_reader_port : node_outputs.get_target_inputs ()) { // NOTE(review): removing while iterating relies on get_target_inputs() returning a copy - confirm
578
+ node_outputs.remove_target_input (node_reader_port);
579
+ }
580
+ }
581
+ };
582
+ LOG_DEBUG (" Dropping the connections..." );
583
+ drop_outputs (matched_mulply);
584
+ drop_outputs (matched_convrt);
585
+ // Rewire last: input 0 of the root Convert is pointed at A directly.
586
+ LOG_DEBUG (" Reconnecting the Root..." );
587
+ auto matched_cvt = node_to_output.at (cvt).get_node_shared_ptr ();
588
+ matched_cvt->input (0 ).replace_source_output (matched_paramA);
589
+ }
590
+ LOG_DEBUG (" Done" );
591
+ }
592
+ return false ; // root node hasn't changed
593
+ };
594
+ // Register the pattern with cvt as its root and the callback above.
595
+ register_matcher (std::make_shared<opp::Matcher>(cvt, " TagDCOFFPassCWAI3" ), std::move (callback));
596
+ }
597
+
515
598
// ------------------------------------------------------------------------------
516
599
// Pattern: 4SymW16A for CWAI
517
600
//
0 commit comments