@@ -186,9 +186,6 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
186
186
#define CHECK_ARG (x ) ok = ok && (x)
187
187
#define CHECK_MASK (mask_name, mask_field ) \
188
188
CHECK_ARG (IMPLICATION ((bool )(~mask & (mask_name)), (mask_field).defined ()))
189
- CHECK_MASK (smask_t ::scales, scales_);
190
- CHECK_MASK (smask_t ::zero_points, zero_points_);
191
- CHECK_MASK (smask_t ::post_ops, post_ops_);
192
189
CHECK_MASK (smask_t ::rnn_data_qparams, rnn_data_qparams_);
193
190
CHECK_MASK (smask_t ::rnn_weights_qparams, rnn_weights_qparams_);
194
191
CHECK_MASK (smask_t ::rnn_weights_projection_qparams,
@@ -200,6 +197,8 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
200
197
201
198
status_t post_ops_t::append_sum (
202
199
float scale, int32_t zero_point, data_type_t dt) {
200
+ if (is_runtime_value (scale)) return invalid_arguments;
201
+
203
202
entry_.emplace_back ();
204
203
auto &e = entry_.back ();
205
204
e.kind = primitive_kind::sum;
@@ -213,6 +212,9 @@ status_t post_ops_t::append_eltwise(
213
212
float scale, alg_kind_t alg, float alpha, float beta) {
214
213
if (!math::is_eltwise_ok (data_type::f32, alg, alpha, beta))
215
214
return invalid_arguments;
215
+ if (is_runtime_value (scale)) return invalid_arguments;
216
+ if (is_runtime_value (alpha)) return invalid_arguments;
217
+ if (is_runtime_value (beta)) return invalid_arguments;
216
218
217
219
entry_.emplace_back ();
218
220
auto &e = entry_.back ();
@@ -310,27 +312,6 @@ status_t post_ops_t::append_prelu(int mask) {
310
312
return success;
311
313
}
312
314
313
- bool post_ops_t::defined () const {
314
- for (int idx = 0 ; idx < len (); ++idx) {
315
- auto kind = entry_[idx].kind ;
316
- if (kind == primitive_kind::sum) {
317
- if (is_runtime_value (entry_[idx].sum .scale )) return false ;
318
- } else if (kind == primitive_kind::eltwise) {
319
- const auto &e = entry_[idx].eltwise ;
320
- if (is_runtime_value (e.scale ) || is_runtime_value (e.alpha )
321
- || is_runtime_value (e.beta ))
322
- return false ;
323
- } else if (utils::one_of (kind, primitive_kind::binary,
324
- primitive_kind::prelu,
325
- primitive_kind::convolution)) {
326
- // binary is always defined
327
- } else {
328
- assert (!" unreachable" );
329
- }
330
- }
331
- return true ;
332
- }
333
-
334
315
status_t post_ops_t::set_default_formats (const memory_desc_t *dst_md) {
335
316
for (int idx = 0 ; idx < len (); ++idx) {
336
317
if (!contain (primitive_kind::binary, idx)) continue ;
0 commit comments