@@ -545,13 +545,18 @@ status_t infer_dnnl_binary_output_shape(op_t *n,
545
545
}
546
546
}
547
547
548
- // TODO(GX): revisit this function to correct logic, check if shape is given
549
548
status_t infer_dnnl_sdpa_output_shape (op_t *n,
550
549
std::vector<logical_tensor_t *> &inputs,
551
550
std::vector<logical_tensor_t *> &outputs) {
551
+ // [batch_size, num_heads_q, seq_len_q, head_size_qk]
552
552
auto query = logical_tensor_wrapper_t (inputs[0 ]);
553
+ // [batch_size, num_heads_q, head_size_qk, seq_len_kv,]
553
554
auto key = logical_tensor_wrapper_t (inputs[1 ]);
554
- auto value = logical_tensor_wrapper_t (inputs[1 ]);
555
+ // [batch_size, num_heads_v, seq_len_kv, head_size_v]
556
+ auto value = logical_tensor_wrapper_t (inputs[2 ]);
557
+ // [batch_size, num_heads_q, seq_len_q, head_size_v]
558
+ auto out0 = logical_tensor_wrapper_t (outputs[0 ]);
559
+
555
560
dims query_dims = query.vdims ();
556
561
dims key_dims = key.vdims ();
557
562
dims value_dims = value.vdims ();
@@ -563,7 +568,36 @@ status_t infer_dnnl_sdpa_output_shape(op_t *n,
563
568
op_t::kind2str (n->get_kind ()).c_str (), dims2str (query_dims).c_str (),
564
569
dims2str (key_dims).c_str (), dims2str (value_dims).c_str ());
565
570
566
- dims inferred_output_shape = query_dims;
571
+ VCHECK_INVALID_SHAPE ((query_dims.size () == 4 ),
572
+ " %s, only support 4D input for all Q/K/V. input0 dimension: %s, "
573
+ " input1 dimension: %s, input2 dimension: %s " ,
574
+ op_t::kind2str (n->get_kind ()).c_str (),
575
+ std::to_string (query_dims.size ()).c_str (),
576
+ std::to_string (key_dims.size ()).c_str (),
577
+ std::to_string (value_dims.size ()).c_str ());
578
+
579
+ VCHECK_INVALID_SHAPE ((query_dims[3 ] == key_dims[2 ]),
580
+ " %s, query head size should be match with key head size. query "
581
+ " dims: %s, Key dims: %s" ,
582
+ op_t::kind2str (n->get_kind ()).c_str (), dims2str (query_dims).c_str (),
583
+ dims2str (key_dims).c_str ());
584
+
585
+ VCHECK_INVALID_SHAPE ((key_dims[3 ] == value_dims[2 ]),
586
+ " %s, key sequence length should be match with value sequence "
587
+ " length. key dims: %s, value dims: %s " ,
588
+ op_t::kind2str (n->get_kind ()).c_str (), dims2str (key_dims).c_str (),
589
+ dims2str (value_dims).c_str ());
590
+
591
+ dims inferred_output_shape;
592
+ inferred_output_shape
593
+ = {query_dims[0 ], query_dims[1 ], query_dims[2 ], value_dims[3 ]};
594
+
595
+ if (out0.ndims () != -1 ) {
596
+ VCHECK_INVALID_SHAPE (validate (inferred_output_shape, out0.vdims ()),
597
+ " %s, inferred out shape and output shape are not compatible" ,
598
+ op_t::kind2str (n->get_kind ()).c_str ());
599
+ }
600
+
567
601
set_shape_and_strides (*outputs[0 ], inferred_output_shape);
568
602
return status::success;
569
603
}
0 commit comments