@@ -48,6 +48,35 @@ namespace impl {
48
48
VCONDCHECK (primitive, create, check, sdpa, (cond), status::unimplemented, \
49
49
msg, ##__VA_ARGS__);
50
50
51
+ static inline status_t sdpa_desc_check (const memory_desc_t *q_desc,
52
+ const memory_desc_t *k_desc, const memory_desc_t *v_desc,
53
+ const memory_desc_t *dst_desc, const memory_desc_t *attn_mask_md,
54
+ const engine_t *engine, const primitive_attr_t *attr,
55
+ const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) {
56
+ int ndims = dst_desc->ndims ;
57
+ int r = ndims - 2 , c = ndims - 1 ;
58
+ VCHECK_SDPA_COND (utils::everyone_is (ndims, q_desc->ndims , k_desc->ndims ,
59
+ v_desc->ndims ),
60
+ " number of dimensions have to match. expected: %d q: %d k: %d v: "
61
+ " %d" ,
62
+ ndims, q_desc->ndims , k_desc->ndims , v_desc->ndims );
63
+
64
+ VCHECK_SDPA_COND (q_desc->dims [c] == k_desc->dims [r],
65
+ " q_desc->dims[%d](%s) must match k_desc->dims[%d](%s)" , c,
66
+ md2dim_str (q_desc).c_str (), r, md2dim_str (k_desc).c_str ());
67
+ VCHECK_SDPA_COND (k_desc->dims [c] == v_desc->dims [r],
68
+ " k_desc->dims[%d](%s) must match v_desc->dims[%d](%s)" , c,
69
+ md2dim_str (k_desc).c_str (), r, md2dim_str (v_desc).c_str ());
70
+ VCHECK_SDPA_COND (dst_desc->dims [r] == q_desc->dims [r],
71
+ " dst_desc->dims[%d](%s) == q_desc->dims[%d](%s)" , r,
72
+ md2dim_str (dst_desc).c_str (), r, md2dim_str (q_desc).c_str ());
73
+ VCHECK_SDPA_COND (dst_desc->dims [c] == v_desc->dims [c],
74
+ " dst_desc->dims[%d](%s) == v_desc->dims[%d](%s)" , c,
75
+ md2dim_str (dst_desc).c_str (), c, md2dim_str (v_desc).c_str ());
76
+
77
+ return status::success;
78
+ }
79
+
51
80
static inline status_t sdpa_attr_check (const memory_desc_t *q_desc,
52
81
const memory_desc_t *k_desc, const memory_desc_t *v_desc,
53
82
const engine_t *engine, const primitive_attr_t *attr,
@@ -140,32 +169,13 @@ static inline status_t create_sdpa_pd(
140
169
const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr ,
141
170
const primitive_attr_t *vs_attr = nullptr ) {
142
171
CHECK (sdpa_attr_check (q_md, k_md, v_md, engine, attr, kq_attr, vs_attr));
172
+ CHECK (sdpa_desc_check (q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr,
173
+ kq_attr, vs_attr));
143
174
144
175
auto sdpa_desc = create_sdpa_desc (q_md, k_md, v_md, dst_md, attn_mask_md,
145
176
scale_dt, invert_scale, kv_head_number, causal_mask, kq_attr,
146
177
vs_attr);
147
178
148
- int ndims = dst_md->ndims ;
149
- int r = ndims - 2 , c = ndims - 1 ;
150
- VCHECK_SDPA_COND (
151
- utils::everyone_is (ndims, q_md->ndims , k_md->ndims , v_md->ndims ),
152
- " number of dimensions have to match. expected: %d q: %d k: %d v: "
153
- " %d" ,
154
- ndims, q_md->ndims , k_md->ndims , v_md->ndims );
155
-
156
- VCHECK_SDPA_COND (q_md->dims [c] == k_md->dims [r],
157
- " q_md->dims[%d](%s) must match k_md->dims[%d](%s)" , c,
158
- md2dim_str (q_md).c_str (), r, md2dim_str (k_md).c_str ());
159
- VCHECK_SDPA_COND (k_md->dims [c] == v_md->dims [r],
160
- " k_md->dims[%d](%s) must match v_md->dims[%d](%s)" , c,
161
- md2dim_str (k_md).c_str (), r, md2dim_str (v_md).c_str ());
162
- VCHECK_SDPA_COND (dst_md->dims [r] == q_md->dims [r],
163
- " dst_md->dims[%d](%s) == q_md->dims[%d](%s)" , r,
164
- md2dim_str (dst_md).c_str (), r, md2dim_str (q_md).c_str ());
165
- VCHECK_SDPA_COND (dst_md->dims [c] == v_md->dims [c],
166
- " dst_md->dims[%d](%s) == v_md->dims[%d](%s)" , c,
167
- md2dim_str (dst_md).c_str (), c, md2dim_str (v_md).c_str ());
168
-
169
179
primitive_attr_t sdpa_attr = attr ? *attr : default_attr ();
170
180
171
181
primitive_desc_iterator_t it (
0 commit comments