
Commit e0ded0f

graph: dnnl: comments fix

1 parent 38c5fc9 commit e0ded0f

3 files changed (+8, -21 lines)

src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp (-4)

@@ -192,10 +192,6 @@ status_t sdp_primitive_v1_kernel_t::ocl_execute_impl(const stream_t *g_stream,
         const std::vector<tensor_t> &inputs,
         const std::vector<tensor_t> &outputs,
         const std::vector<cl_event> &cl_deps, cl_event *ret_event) {
-    // sdp_primitive_v1_kernel_t only supports Intel GPU.
-#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
-    return status::unimplemented;
-#endif
     auto deps = cl_deps;
     cl_event returned_event {};
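For context, the block removed above is a compile-time vendor guard: on non-Intel builds the OpenCL execute path used to report status::unimplemented before touching any state. A minimal sketch of that pattern follows; the macro values, type names, and function body are hypothetical stand-ins, not oneDNN's actual build configuration.

    #include <iostream>

    // Hypothetical stand-ins for build-system-defined vendor macros.
    #define VENDOR_INTEL 1
    #define VENDOR_OTHER 2
    #define GPU_VENDOR VENDOR_INTEL

    enum class status_t { success, unimplemented };

    status_t ocl_execute_sketch() {
        // Compile-time early-out: on non-Intel builds this function reduces
        // to a stub that reports the feature as unimplemented.
    #if GPU_VENDOR != VENDOR_INTEL
        return status_t::unimplemented;
    #endif
        // ... vendor-specific execution would go here ...
        return status_t::success;
    }

    int main() {
        std::cout << (ocl_execute_sketch() == status_t::success
                        ? "ok" : "unimplemented") << "\n";
    }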

src/graph/backend/dnnl/op_executable.hpp (+7, -16)

@@ -2657,16 +2657,14 @@ struct sdpa_executable_t : public op_executable_t {

         auto scale_dt = impl::data_type::undef;
         size_t idx = 3;
-        with_scale_ = op->get_attr<bool>(
-                dnnl::impl::graph::dnnl_impl::op_attr::with_scale);
+        with_scale_ = op->get_attr<bool>(op_attr::with_scale);
         if (with_scale_)
             scale_dt = op->get_input_value(idx++)
                                ->get_logical_tensor()
                                .data_type;

         dnnl::memory::desc md_mask;
-        with_mask_ = op->get_attr<bool>(
-                dnnl::impl::graph::dnnl_impl::op_attr::with_mask);
+        with_mask_ = op->get_attr<bool>(op_attr::with_mask);
         if (with_mask_)
             md_mask = make_dnnl_memory_desc(
                     op->get_input_value(idx++)->get_logical_tensor());
@@ -2675,13 +2673,10 @@ struct sdpa_executable_t : public op_executable_t {
         attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
         attr.set_fpmath_mode(
                 static_cast<dnnl::fpmath_mode>(mgr.get_fpmath_mode().mode_));
-        if (op->has_attr(
-                    dnnl::impl::graph::dnnl_impl::op_attr::is_invert_scale))
-            is_invert_scale_ = op->get_attr<bool>(
-                    dnnl::impl::graph::dnnl_impl::op_attr::is_invert_scale);
+        if (op->has_attr(op_attr::is_invert_scale))
+            is_invert_scale_ = op->get_attr<bool>(op_attr::is_invert_scale);

-        is_causal_mask_ = op->get_attr<bool>(
-                dnnl::impl::graph::dnnl_impl::op_attr::with_causal);
+        is_causal_mask_ = op->get_attr<bool>(op_attr::with_causal);

         dim_t kv_head_number
                 = op->get_input_value(1)->get_logical_tensor().dims[1];
@@ -2692,15 +2687,11 @@ struct sdpa_executable_t : public op_executable_t {
             is_initialized_ = false;
         } else {
             status_t s = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
-            if (s != dnnl::impl::status::success) {
-                is_initialized_ = false;
-            } else {
-                is_initialized_ = true;
-            }
+            is_initialized_ = s == status::success ? true : false;
         }
     }

-    bool is_initialized() { return is_initialized_; }
+    bool is_initialized() const { return is_initialized_; }

     void execute(const stream &stream,
             const std::unordered_map<int, memory> &args) const override {
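The repeated edits in this file all rely on the same C++ name-lookup rule: assuming this header already sits inside namespace dnnl::impl::graph::dnnl_impl (which the short form compiling implies), unqualified lookup walks the enclosing namespaces, so op_attr::with_scale names exactly the same entity as the fully qualified path. A self-contained illustration with hypothetical attribute values:

    #include <iostream>

    namespace dnnl::impl::graph::dnnl_impl {
    namespace op_attr {
    // Hypothetical attribute keys; the real ones are defined by the backend.
    constexpr int with_scale = 100;
    constexpr int with_mask = 101;
    } // namespace op_attr

    void lookup_demo() {
        // Both spellings resolve to the same constant: unqualified lookup
        // finds op_attr through the enclosing dnnl_impl namespace.
        std::cout << op_attr::with_scale << " == "
                  << dnnl::impl::graph::dnnl_impl::op_attr::with_scale << "\n";
    }
    } // namespace dnnl::impl::graph::dnnl_impl

    int main() {
        dnnl::impl::graph::dnnl_impl::lookup_demo();
    }

The is_initialized() change in the last hunk is an independent const-correctness fix: marking the accessor const lets it be called through const references and pointers to the executable.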

src/graph/backend/dnnl/passes/transform.cpp (+1, -1)

@@ -4272,7 +4272,7 @@ status_t fuse_sdpa(std::shared_ptr<subgraph_t> &sg) {
             sdpa_op->set_attr<bool>(op_attr::is_invert_scale,
                     (alg == dnnl::algorithm::binary_div));
         }
-        // hanlde explicit mask
+        // handle explicit mask
         else if (alg == dnnl::algorithm::binary_add) {
             auto mask_val = op->get_input_value(1);
             mask_val->remove_consumer(*op, 1);
