@@ -2657,16 +2657,14 @@ struct sdpa_executable_t : public op_executable_t {
2657
2657
2658
2658
auto scale_dt = impl::data_type::undef;
2659
2659
size_t idx = 3 ;
2660
- with_scale_ = op->get_attr <bool >(
2661
- dnnl::impl::graph::dnnl_impl::op_attr::with_scale);
2660
+ with_scale_ = op->get_attr <bool >(op_attr::with_scale);
2662
2661
if (with_scale_)
2663
2662
scale_dt = op->get_input_value (idx++)
2664
2663
->get_logical_tensor ()
2665
2664
.data_type ;
2666
2665
2667
2666
dnnl::memory::desc md_mask;
2668
- with_mask_ = op->get_attr <bool >(
2669
- dnnl::impl::graph::dnnl_impl::op_attr::with_mask);
2667
+ with_mask_ = op->get_attr <bool >(op_attr::with_mask);
2670
2668
if (with_mask_)
2671
2669
md_mask = make_dnnl_memory_desc (
2672
2670
op->get_input_value (idx++)->get_logical_tensor ());
@@ -2675,13 +2673,10 @@ struct sdpa_executable_t : public op_executable_t {
2675
2673
attr.set_scratchpad_mode (dnnl::scratchpad_mode::user);
2676
2674
attr.set_fpmath_mode (
2677
2675
static_cast <dnnl::fpmath_mode>(mgr.get_fpmath_mode ().mode_ ));
2678
- if (op->has_attr (
2679
- dnnl::impl::graph::dnnl_impl::op_attr::is_invert_scale))
2680
- is_invert_scale_ = op->get_attr <bool >(
2681
- dnnl::impl::graph::dnnl_impl::op_attr::is_invert_scale);
2676
+ if (op->has_attr (op_attr::is_invert_scale))
2677
+ is_invert_scale_ = op->get_attr <bool >(op_attr::is_invert_scale);
2682
2678
2683
- is_causal_mask_ = op->get_attr <bool >(
2684
- dnnl::impl::graph::dnnl_impl::op_attr::with_causal);
2679
+ is_causal_mask_ = op->get_attr <bool >(op_attr::with_causal);
2685
2680
2686
2681
dim_t kv_head_number
2687
2682
= op->get_input_value (1 )->get_logical_tensor ().dims [1 ];
@@ -2692,15 +2687,11 @@ struct sdpa_executable_t : public op_executable_t {
2692
2687
is_initialized_ = false ;
2693
2688
} else {
2694
2689
status_t s = sdpa_pd_->create_primitive (sdpa_prim_, p_engine.get ());
2695
- if (s != dnnl::impl::status::success) {
2696
- is_initialized_ = false ;
2697
- } else {
2698
- is_initialized_ = true ;
2699
- }
2690
+ is_initialized_ = s == status::success ? true : false ;
2700
2691
}
2701
2692
}
2702
2693
2703
- bool is_initialized () { return is_initialized_; }
2694
+ bool is_initialized () const { return is_initialized_; }
2704
2695
2705
2696
void execute (const stream &stream,
2706
2697
const std::unordered_map<int , memory> &args) const override {
0 commit comments