Skip to content

Commit a2f5367

Browse files
committed
graph: dnnl: catch sdpa creation error for fallback
1 parent 786664f commit a2f5367

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

src/graph/backend/dnnl/op_executable.hpp

+5 -2
Original file line number | Diff line number | Diff line change
@@ -2685,9 +2685,12 @@ struct sdpa_executable_t : public op_executable_t {
26852685

26862686
dim_t kv_head_number
26872687
= op->get_input_value(1)->get_logical_tensor().dims[1];
2688-
create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(),
2689-
md_v.get(), md_dst.get(), md_mask.get(), scale_dt,
2688+
status_t s = create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(),
2689+
md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), scale_dt,
26902690
is_invert_scale_, kv_head_number, is_causal_mask_, attr.get());
2691+
if (s != dnnl::impl::status::success) {
2692+
throw std::runtime_error("create_sdpa_pd failed");
2693+
}
26912694

26922695
sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
26932696
}

src/graph/backend/dnnl/passes/compile_ops.cpp

+12 -6
Original file line number | Diff line number | Diff line change
@@ -59,14 +59,20 @@ status_t compile_ops(std::shared_ptr<subgraph_t> &sg) {
5959
auto cur_op = op->shared_from_this();
6060
auto creator = opm->get_additional_item<executable_creator_func>(
6161
"executable_creator");
62-
std::shared_ptr<op_executable_t> exec
63-
= creator(cur_op, p_engine, mgr, pd_cache);
6462

65-
VCHECK_COMPILE_OPS(exec != nullptr, status::invalid_graph_op,
66-
"unimplemented op, can't compile op %s",
67-
op->get_name().c_str());
63+
try {
64+
std::shared_ptr<op_executable_t> exec
65+
= creator(cur_op, p_engine, mgr, pd_cache);
66+
VCHECK_COMPILE_OPS(exec != nullptr, status::invalid_graph_op,
67+
"unimplemented op, can't compile op %s",
68+
op->get_name().c_str());
6869

69-
sg->execs_.emplace_back(exec);
70+
sg->execs_.emplace_back(exec);
71+
} catch (const std::runtime_error &e) {
72+
VCHECK_COMPILE_OPS(false, status::unimplemented,
73+
"failed to create executable for op %s: %s",
74+
op->get_name().c_str(), e.what());
75+
}
7076
sg->is_constant_.push_back(op->has_attr(op_attr::is_constant)
7177
&& op->get_attr<bool>(op_attr::is_constant));
7278
return status::success;

0 commit comments

Comments
 (0)