Skip to content

Commit 6e9f67b

Browse files
committedMar 20, 2025
xe: sdpa: refactor descriptor checks
1 parent 2fa9425 commit 6e9f67b

File tree

2 files changed

+36
-25
lines changed

2 files changed

+36
-25
lines changed
 

‎src/common/sdpa_test_iface.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create(
3232
bool invert_scale, dnnl_dim_t kv_head_number, bool causal_mask,
3333
const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr,
3434
const_dnnl_primitive_attr_t vs_attr) {
35-
if (auto err = sdpa_attr_check(query_desc, key_desc, value_desc, engine,
36-
attr, kq_attr, vs_attr)) {
37-
return err;
38-
}
35+
CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc,
36+
engine, attr, kq_attr, vs_attr));
37+
CHECK(sdpa_attr_check(
38+
query_desc, key_desc, value_desc, engine, attr, kq_attr, vs_attr));
39+
3940
dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc,
4041
key_desc, value_desc, dst_desc, mask_desc,
4142
(dnnl::impl::data_type_t)scale_dt, invert_scale, kv_head_number,

‎src/common/sdpa_utils.hpp

+31-21
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,35 @@ namespace impl {
4848
VCONDCHECK(primitive, create, check, sdpa, (cond), status::unimplemented, \
4949
msg, ##__VA_ARGS__);
5050

51+
static inline status_t sdpa_desc_check(const memory_desc_t *q_desc,
52+
const memory_desc_t *k_desc, const memory_desc_t *v_desc,
53+
const memory_desc_t *dst_desc, const memory_desc_t *attn_mask_md,
54+
const engine_t *engine, const primitive_attr_t *attr,
55+
const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) {
56+
int ndims = dst_desc->ndims;
57+
int r = ndims - 2, c = ndims - 1;
58+
VCHECK_SDPA_COND(utils::everyone_is(ndims, q_desc->ndims, k_desc->ndims,
59+
v_desc->ndims),
60+
"number of dimensions have to match. expected: %d q: %d k: %d v: "
61+
"%d",
62+
ndims, q_desc->ndims, k_desc->ndims, v_desc->ndims);
63+
64+
VCHECK_SDPA_COND(q_desc->dims[c] == k_desc->dims[r],
65+
"q_desc->dims[%d](%s) must match k_desc->dims[%d](%s)", c,
66+
md2dim_str(q_desc).c_str(), r, md2dim_str(k_desc).c_str());
67+
VCHECK_SDPA_COND(k_desc->dims[c] == v_desc->dims[r],
68+
"k_desc->dims[%d](%s) must match v_desc->dims[%d](%s)", c,
69+
md2dim_str(k_desc).c_str(), r, md2dim_str(v_desc).c_str());
70+
VCHECK_SDPA_COND(dst_desc->dims[r] == q_desc->dims[r],
71+
"dst_desc->dims[%d](%s) == q_desc->dims[%d](%s)", r,
72+
md2dim_str(dst_desc).c_str(), r, md2dim_str(q_desc).c_str());
73+
VCHECK_SDPA_COND(dst_desc->dims[c] == v_desc->dims[c],
74+
"dst_desc->dims[%d](%s) == v_desc->dims[%d](%s)", c,
75+
md2dim_str(dst_desc).c_str(), c, md2dim_str(v_desc).c_str());
76+
77+
return status::success;
78+
}
79+
5180
static inline status_t sdpa_attr_check(const memory_desc_t *q_desc,
5281
const memory_desc_t *k_desc, const memory_desc_t *v_desc,
5382
const engine_t *engine, const primitive_attr_t *attr,
@@ -140,32 +169,13 @@ static inline status_t create_sdpa_pd(
140169
const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr,
141170
const primitive_attr_t *vs_attr = nullptr) {
142171
CHECK(sdpa_attr_check(q_md, k_md, v_md, engine, attr, kq_attr, vs_attr));
172+
CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr,
173+
kq_attr, vs_attr));
143174

144175
auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md,
145176
scale_dt, invert_scale, kv_head_number, causal_mask, kq_attr,
146177
vs_attr);
147178

148-
int ndims = dst_md->ndims;
149-
int r = ndims - 2, c = ndims - 1;
150-
VCHECK_SDPA_COND(
151-
utils::everyone_is(ndims, q_md->ndims, k_md->ndims, v_md->ndims),
152-
"number of dimensions have to match. expected: %d q: %d k: %d v: "
153-
"%d",
154-
ndims, q_md->ndims, k_md->ndims, v_md->ndims);
155-
156-
VCHECK_SDPA_COND(q_md->dims[c] == k_md->dims[r],
157-
"q_md->dims[%d](%s) must match k_md->dims[%d](%s)", c,
158-
md2dim_str(q_md).c_str(), r, md2dim_str(k_md).c_str());
159-
VCHECK_SDPA_COND(k_md->dims[c] == v_md->dims[r],
160-
"k_md->dims[%d](%s) must match v_md->dims[%d](%s)", c,
161-
md2dim_str(k_md).c_str(), r, md2dim_str(v_md).c_str());
162-
VCHECK_SDPA_COND(dst_md->dims[r] == q_md->dims[r],
163-
"dst_md->dims[%d](%s) == q_md->dims[%d](%s)", r,
164-
md2dim_str(dst_md).c_str(), r, md2dim_str(q_md).c_str());
165-
VCHECK_SDPA_COND(dst_md->dims[c] == v_md->dims[c],
166-
"dst_md->dims[%d](%s) == v_md->dims[%d](%s)", c,
167-
md2dim_str(dst_md).c_str(), c, md2dim_str(v_md).c_str());
168-
169179
primitive_attr_t sdpa_attr = attr ? *attr : default_attr();
170180

171181
primitive_desc_iterator_t it(

0 commit comments

Comments
 (0)