@@ -138,6 +138,23 @@ class convolution_backward_data_test
138
138
p.formats .weights_format , p.aalgorithm )),
139
139
" Format is not supported." );
140
140
141
+ SKIP_IF_GENERIC (
142
+ !(generic_check_format_tags (p.formats .src_format )
143
+ && generic_check_format_tags (p.formats .dst_format )
144
+ && (generic_check_format_tags (p.formats .weights_format )
145
+ || (impl::utils::one_of (
146
+ p.formats .weights_format ,
147
+ memory::format_tag::goiw,
148
+ memory::format_tag::goihw,
149
+ memory::format_tag::goidhw,
150
+ memory::format_tag::oiw,
151
+ memory::format_tag::oihw,
152
+ memory::format_tag::oidhw)))
153
+ && check_generic_dt<data_t_diff_src>()
154
+ && check_generic_dt<data_t_diff_dst>()
155
+ && check_generic_dt<data_t_wei>()),
156
+ " Format is not supported." );
157
+
141
158
catch_expected_failures (
142
159
[&]() { Test (); }, p.expect_to_fail , p.expected_status );
143
160
}
@@ -156,6 +173,14 @@ class convolution_backward_data_test
156
173
memory::format_tag::acdeb);
157
174
}
158
175
176
+ bool generic_check_format_tags (memory::format_tag tag) {
177
+ return impl::utils::one_of (tag, memory::format_tag::ab,
178
+ memory::format_tag::abc, memory::format_tag::abcd,
179
+ memory::format_tag::abcde, memory::format_tag::abcdef,
180
+ memory::format_tag::acb, memory::format_tag::acdb,
181
+ memory::format_tag::acdeb, memory::format_tag::any);
182
+ }
183
+
159
184
bool check_cuda_alg_format (memory::format_tag dst_fmt,
160
185
memory::format_tag wei_fmt, algorithm alg) {
161
186
bool res = dst_fmt == wei_fmt;
@@ -182,6 +207,14 @@ class convolution_backward_data_test
182
207
return res;
183
208
}
184
209
210
+ template <typename dt>
211
+ bool check_generic_dt () {
212
+ return impl::utils::one_of (data_traits<dt>::data_type,
213
+ memory::data_type::f32, memory::data_type::bf16,
214
+ memory::data_type::f16, memory::data_type::s32,
215
+ memory::data_type::s8, memory::data_type::u8);
216
+ }
217
+
185
218
void Test () {
186
219
auto p = ::testing::TestWithParam<
187
220
test_convolution_params_t >::GetParam ();
0 commit comments