@@ -571,6 +571,103 @@ status_t insert_reshape_for_ndx2d_matmul(std::shared_ptr<subgraph_t> &sg) {
571
571
return infer_shape (sg);
572
572
}
573
573
574
+ status_t insert_reshape_for_sdpa (std::shared_ptr<subgraph_t > &sg) {
575
+ subgraph_rewriter_t rewriter (sg);
576
+
577
+ for (auto &cur_op : sg->get_ops ()) {
578
+ if (cur_op->get_kind () != op_kind::dnnl_sdpa) continue ;
579
+
580
+ int32_t query_ndims
581
+ = cur_op->get_input_value (0 )->get_logical_tensor ().ndims ;
582
+ if (query_ndims != 5 ) continue ;
583
+
584
+ // Insert reshape for Query
585
+ auto query_dims = logical_tensor_wrapper_t (
586
+ cur_op->get_input_value (0 )->get_logical_tensor ())
587
+ .vdims ();
588
+ dims expected_query_dims {
589
+ query_dims[0 ], -1 , query_dims[3 ], query_dims[4 ]};
590
+ op_ptr reshape_query = std::make_shared<op_t >(op_kind::dnnl_reshape);
591
+ reshape_query->set_attr <bool >(op_attr::special_zero, false );
592
+ reshape_query->set_attr <std::vector<int64_t >>(
593
+ op_attr::shape, expected_query_dims);
594
+ rewriter.insert_op_before (reshape_query, cur_op, 0 );
595
+
596
+ // Insert reshape for Key
597
+ auto key_dims = logical_tensor_wrapper_t (
598
+ cur_op->get_input_value (1 )->get_logical_tensor ())
599
+ .vdims ();
600
+ dims expected_key_dims {key_dims[0 ], -1 , key_dims[3 ], key_dims[4 ]};
601
+ op_ptr reshape_key = std::make_shared<op_t >(op_kind::dnnl_reshape);
602
+ reshape_key->set_attr <bool >(op_attr::special_zero, false );
603
+ reshape_key->set_attr <std::vector<int64_t >>(
604
+ op_attr::shape, expected_key_dims);
605
+ rewriter.insert_op_before (reshape_key, cur_op, 1 );
606
+
607
+ // Insert reshape for value
608
+ auto value_dims = logical_tensor_wrapper_t (
609
+ cur_op->get_input_value (2 )->get_logical_tensor ())
610
+ .vdims ();
611
+ dims expected_value_dims {
612
+ value_dims[0 ], -1 , value_dims[3 ], value_dims[4 ]};
613
+ op_ptr reshape_value = std::make_shared<op_t >(op_kind::dnnl_reshape);
614
+ reshape_value->set_attr <bool >(op_attr::special_zero, false );
615
+ reshape_value->set_attr <std::vector<int64_t >>(
616
+ op_attr::shape, expected_value_dims);
617
+ rewriter.insert_op_before (reshape_value, cur_op, 2 );
618
+
619
+ // Insert reshape for scale
620
+ if (cur_op->get_attr <bool >(op_attr::with_scale)) {
621
+ int32_t scale_ndims
622
+ = cur_op->get_input_value (3 )->get_logical_tensor ().ndims ;
623
+ if (scale_ndims == 5 ) {
624
+ auto scale_dims = logical_tensor_wrapper_t (
625
+ cur_op->get_input_value (3 )->get_logical_tensor ())
626
+ .vdims ();
627
+ dims expected_scale_dims {
628
+ scale_dims[0 ], -1 , scale_dims[3 ], scale_dims[4 ]};
629
+ op_ptr reshape_scale
630
+ = std::make_shared<op_t >(op_kind::dnnl_reshape);
631
+ reshape_scale->set_attr <bool >(op_attr::special_zero, false );
632
+ reshape_scale->set_attr <std::vector<int64_t >>(
633
+ op_attr::shape, expected_scale_dims);
634
+ rewriter.insert_op_before (reshape_scale, cur_op, 3 );
635
+ }
636
+ }
637
+ // Insert reshape for mask
638
+ if (cur_op->get_attr <bool >(op_attr::with_mask)) {
639
+ int32_t mask_ndims
640
+ = cur_op->get_input_value (4 )->get_logical_tensor ().ndims ;
641
+ if (mask_ndims == 5 ) {
642
+ auto mask_dims = logical_tensor_wrapper_t (
643
+ cur_op->get_input_value (4 )->get_logical_tensor ())
644
+ .vdims ();
645
+ dims expected_mask_dims {
646
+ mask_dims[0 ], -1 , mask_dims[3 ], mask_dims[4 ]};
647
+ op_ptr reshape_mask
648
+ = std::make_shared<op_t >(op_kind::dnnl_reshape);
649
+ reshape_mask->set_attr <bool >(op_attr::special_zero, false );
650
+ reshape_mask->set_attr <std::vector<int64_t >>(
651
+ op_attr::shape, expected_mask_dims);
652
+ rewriter.insert_op_before (reshape_mask, cur_op, 4 );
653
+ }
654
+ }
655
+
656
+ // Insert reshape for output
657
+ auto output_dims = logical_tensor_wrapper_t (
658
+ cur_op->get_output_value (0 )->get_logical_tensor ())
659
+ .vdims ();
660
+ dims expected_output_dims {output_dims};
661
+ op_ptr reshape_output = std::make_shared<op_t >(op_kind::dnnl_reshape);
662
+ reshape_output->set_attr <bool >(op_attr::special_zero, false );
663
+ reshape_output->set_attr <std::vector<int64_t >>(
664
+ op_attr::shape, expected_output_dims);
665
+ rewriter.insert_op_after (reshape_output, cur_op, 0 );
666
+ }
667
+ rewriter.run ();
668
+ return infer_shape (sg);
669
+ }
670
+
574
671
status_t insert_unsqueeze_and_squeeze_for_matmul (
575
672
std::shared_ptr<subgraph_t > &sg) {
576
673
subgraph_rewriter_t rewriter (sg);
0 commit comments