@@ -512,6 +512,88 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco
512
512
register_matcher (std::make_shared<opp::Matcher>(reshpe, " TagDCOFFReshape2" ), std::move (callback));
513
513
}
514
514
515
+ // Pattern: Phi-3 4SymW16A/GPTQ for CWAI
516
+ //
517
+ // FIXME: Think how it can be unified with the above
518
+ //
519
+ // "tensor" "scale" > "tensor"
520
+ // Param:A Param:C > Param:A
521
+ // i4 f16|f32 > f16
522
+ // : : > :
523
+ // V : > V
524
+ // Convert : > Convert
525
+ // f16|f32 : > f32
526
+ // : : >
527
+ // V V >
528
+ // Multiply >
529
+ // f16|f32 >
530
+ // : >
531
+ // : >
532
+ // V >
533
+ // Convert
534
+
535
+ DCOFFPassCWAI3::DCOFFPassCWAI3 (DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
536
+ auto paramA = opp::wrap_type<ov::op::v0::Parameter>();
537
+ auto paramC = opp::wrap_type<ov::op::v0::Parameter>();
538
+ auto cvtA = opp::wrap_type<ov::op::v0::Convert>({paramA});
539
+ auto mulply = opp::wrap_type<ov::op::v1::Multiply>({cvtA, paramC});
540
+ auto cvt = opp::wrap_type<ov::op::v0::Convert>({mulply});
541
+
542
+ auto callback = [=](ov::pass::pattern::Matcher& m) {
543
+ auto & node_to_output = m.get_pattern_value_map ();
544
+ auto matched_nodeA = node_to_output.at (paramA).get_node_shared_ptr ();
545
+ auto matched_nodeC = node_to_output.at (paramC).get_node_shared_ptr ();
546
+
547
+ NPUW_ASSERT (ov::op::util::is_parameter (matched_nodeA));
548
+ NPUW_ASSERT (ov::op::util::is_parameter (matched_nodeC));
549
+
550
+ auto matched_paramA = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeA);
551
+ auto matched_paramC = std::static_pointer_cast<ov::op::v0::Parameter>(matched_nodeC);
552
+
553
+ if (ov::element::i4 == matched_paramA->get_element_type () &&
554
+ (ov::element::f16 == matched_paramC->get_element_type () ||
555
+ ov::element::f32 == matched_paramC->get_element_type ())) {
556
+ LOG_DEBUG (" Matched: " << matched_paramA << " , set element type to " << dcoff_type);
557
+ matched_paramA->set_element_type (dcoff_type);
558
+
559
+ if (dcoff_mode == DCOffMode::CAST_SCALE) {
560
+ NPUW_ASSERT (dcoff_type == ov::element::f16);
561
+
562
+ LOG_DEBUG (" Matched: " << matched_paramC << " - parameter to remove..." );
563
+ LOG_BLOCK ();
564
+
565
+ // Extra transformation here:
566
+ // - remove Multiply + Intermediate Convert
567
+ // - mark paramC for removal.
568
+ // Convert will be reconnected to paramA directly.
569
+
570
+ pref.get ().scales [matched_paramC] = std::move (matched_paramA);
571
+ // Disconnect Multiply and Convert from their outputs
572
+ auto matched_mulply = node_to_output.at (mulply).get_node_shared_ptr ();
573
+ auto matched_convrt = node_to_output.at (cvtA).get_node_shared_ptr ();
574
+ auto drop_outputs = [](std::shared_ptr<ov::Node> node) {
575
+ for (auto && node_outputs : node->outputs ()) {
576
+ for (auto && node_reader_port : node_outputs.get_target_inputs ()) {
577
+ node_outputs.remove_target_input (node_reader_port);
578
+ }
579
+ }
580
+ };
581
+ LOG_DEBUG (" Dropping the connections..." );
582
+ drop_outputs (std::move (matched_mulply));
583
+ drop_outputs (std::move (matched_convrt));
584
+
585
+ LOG_DEBUG (" Reconnecting the Root..." );
586
+ auto matched_cvt = node_to_output.at (cvt).get_node_shared_ptr ();
587
+ matched_cvt->input (0 ).replace_source_output (matched_paramA);
588
+ }
589
+ LOG_DEBUG (" Done" );
590
+ }
591
+ return false ; // root node hasn't changed
592
+ };
593
+
594
+ register_matcher (std::make_shared<opp::Matcher>(cvt, " TagDCOFFPassCWAI3" ), std::move (callback));
595
+ }
596
+
515
597
// ------------------------------------------------------------------------------
516
598
// Pattern: 4SymW16A for CWAI
517
599
//
0 commit comments