
Commit a8999db

generic: fix several gtest issues

1 parent: cfe12d8

39 files changed, +406 -2 lines

src/gpu/generic/ref_concat.hpp (+3)

@@ -129,6 +129,9 @@ struct ref_concat_t : public gpu::primitive_t {
     }
 
     status_t execute(const exec_ctx_t &ctx) const override {
+        if (memory_desc_wrapper(pd()->dst_md()).size() == 0)
+            return status::success;
+
         using namespace memory_tracking::names;
         impl::engine_t *engine = ctx.stream()->engine();
         const auto n = pd()->n_inputs();
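
For context on what the new guard catches: a memory descriptor with any zero dimension describes a zero-volume tensor, so its size() is 0 and there is nothing to compute. A minimal standalone sketch against the public oneDNN C++ API (hypothetical repro, not part of this commit):

#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    // A zero channel dimension yields a zero-volume tensor; primitives with
    // such a destination can return success without launching a kernel.
    memory::desc md({2, 0, 3, 3}, memory::data_type::f32, memory::format_tag::nchw);
    return md.get_size() == 0 ? 0 : 1; // get_size() reports 0 bytes
}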

src/gpu/generic/sycl/ref_binary.cpp (+3)

@@ -55,12 +55,15 @@ status_t ref_binary_t::pd_t::init_conf() {
 }
 
 status_t ref_binary_t::init(impl::engine_t *engine) {
+    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
+
     const auto kid = ::sycl::get_kernel_id<binary_kernel_vec_t>();
     CHECK(create_kernel(engine, kid, &kernel_));
     return status::success;
 }
 
 status_t ref_binary_t::execute(const exec_ctx_t &ctx) const {
+    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
 
     ctx.zero_pad_output(DNNL_ARG_TO);

src/gpu/generic/sycl/ref_convolution.cpp (+4)

@@ -74,6 +74,8 @@ status_t ref_convolution_fwd_t::init(impl::engine_t *engine) {
 }
 
 status_t ref_convolution_fwd_t::execute(const exec_ctx_t &ctx) const {
+    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
+
     parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
         convolution_kernel_fwd_t convolution_kernel(pd()->conf_, cgh, ctx);

@@ -134,6 +136,8 @@ status_t ref_convolution_bwd_data_t::init(impl::engine_t *engine) {
 }
 
 status_t ref_convolution_bwd_data_t::execute(const exec_ctx_t &ctx) const {
+    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
+
     parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
         convolution_kernel_bwd_data_t convolution_kernel(pd()->conf_, cgh, ctx);

src/gpu/generic/sycl/ref_layer_normalizations.cpp (+4)

@@ -91,6 +91,8 @@ status_t ref_layer_normalization_fwd_t::init(impl::engine_t *engine) {
 
 status_t ref_layer_normalization_fwd_t::execute_forward(
         const exec_ctx_t &ctx) const {
+    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
+
     if (pd()->stats_are_src()) {
         return parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
             layer_normalization_fwd_kernel_vec_t layer_normalization_fwd_kernel(

@@ -163,6 +165,8 @@ status_t ref_layer_normalization_bwd_t::init(impl::engine_t *engine) {
 
 status_t ref_layer_normalization_bwd_t::execute_backward(
         const exec_ctx_t &ctx) const {
+    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
+
     if (pd()->conf_.use_scale || pd()->conf_.use_shift) {
         auto status = parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
             auto nelems_A = memory_desc_wrapper(pd()->src_md(0)).nelems();

src/gpu/generic/sycl/ref_matmul.cpp (+2)

@@ -142,6 +142,8 @@ status_t ref_matmul_t::init(impl::engine_t *engine) {
 }
 
 status_t ref_matmul_t::execute(const exec_ctx_t &ctx) const {
+    if (memory_desc_wrapper(pd()->dst_md()).size() == 0) return status::success;
+
     sycl_matmul_conf_t conf = pd()->conf_;
     if (pd()->any_runtime_params_) {
         const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());

src/gpu/generic/sycl/ref_pooling.cpp (+9)

@@ -74,6 +74,11 @@ status_t ref_pooling_fwd_t::init(impl::engine_t *engine) {
 }
 
 status_t ref_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
+    // XXX: Add support for 0-dim src
+    for (auto &md : {pd()->src_md(), pd()->dst_md()}) {
+        if (memory_desc_wrapper(md).size() == 0) return status::success;
+    }
+
     return parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
         auto nelems_A = memory_desc_wrapper(pd()->src_md(0)).nelems();
         pooling_fwd_kernel_vec_t pooling_fwd_kernel(pd()->conf_, cgh, ctx);

@@ -135,6 +140,10 @@ status_t ref_pooling_bwd_t::init(impl::engine_t *engine) {
 }
 
 status_t ref_pooling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
+    for (auto &md : {pd()->diff_src_md(), pd()->diff_dst_md()}) {
+        if (memory_desc_wrapper(md).size() == 0) return status::success;
+    }
+
     return parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
         auto nelems_A = memory_desc_wrapper(pd()->diff_src_md(0)).nelems();
         pooling_bwd_kernel_vec_t pooling_bwd_kernel(pd()->conf_, cgh, ctx);

src/gpu/generic/sycl/ref_prelu.hpp (+4, -2)

@@ -55,7 +55,8 @@ struct ref_prelu_fwd_t : public gpu::generic::sycl::primitive_t {
                     && (src_md(0)->format_desc.blocking.inner_nblks == 0)
                     && (weights_md(0)->format_desc.blocking.inner_nblks == 0)
                     && md_dims_in_range(src_md())
-                    && md_dims_in_range(weights_md());
+                    && md_dims_in_range(weights_md())
+                    && attr()->has_default_values();
 
             if (!ok) return status::unimplemented;
             return init_conf();

@@ -98,7 +99,8 @@ struct ref_prelu_bwd_t : public gpu::generic::sycl::primitive_t {
                     && diff_src_md(0)->data_type == src_md(0)->data_type
                     && diff_weights_md(0)->data_type == weights_md(0)->data_type
                     && md_dims_in_range(diff_src_md())
-                    && md_dims_in_range(weights_md());
+                    && md_dims_in_range(weights_md())
+                    && attr()->has_default_values();
 
             if (!ok) return status::unimplemented;

src/gpu/generic/sycl/ref_shuffle.hpp (+2)

@@ -55,6 +55,8 @@ struct ref_shuffle_t : public gpu::generic::sycl::primitive_t {
                     && IMPLICATION(is_fwd(),
                             src_md(0)->format_desc.blocking.inner_nblks == 0)
                     && attr()->has_default_values()
+                    && memory_desc_wrapper(src_data_md)
+                            == memory_desc_wrapper(dst_data_md)
                     && md_dims_in_range(src_md());
             if (!ok) return status::unimplemented;
             return init_conf();

src/gpu/generic/sycl/ref_softmax.cpp (+4)

@@ -52,6 +52,8 @@ status_t ref_sycl_softmax_fwd_t::init(impl::engine_t *engine) {
 }
 
 status_t ref_sycl_softmax_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
+    if (pd()->has_zero_dim_memory()) return status::success;
+
     return parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
         softmax_fwd_kernel_vec_t softmax_fwd_kernel_(pd()->conf_, cgh, ctx);

@@ -82,6 +84,8 @@ status_t ref_sycl_softmax_bwd_t::init(impl::engine_t *engine) {
 }
 
 status_t ref_sycl_softmax_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
+    if (pd()->has_zero_dim_memory()) return status::success;
+
     return parallel_for(ctx, kernel_, [&](::sycl::handler &cgh) {
         softmax_bwd_kernel_vec_t softmax_bwd_kernel(pd()->conf_, cgh, ctx);

src/gpu/generic/sycl/ref_softmax.hpp (+13)

@@ -48,6 +48,7 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
                     && sycl_post_ops_t::post_ops_ok(attr(), true, false)
                     && set_default_formats() == status::success
                     && attr_.set_default_formats(dst_md()) == status::success
+                    && check_formats(src_md(), dst_md())
                     && md_dims_in_range(src_md());
 
             if (!ok) return status::unimplemented;

@@ -70,6 +71,15 @@ struct ref_sycl_softmax_fwd_t : public gpu::generic::sycl::primitive_t {
             return utils::one_of(src, data_type::f32, data_type::bf16,
                     data_type::f16, data_type::s8, data_type::u8);
         }
+
+        static bool check_formats(const memory_desc_wrapper &src,
+                const memory_desc_wrapper &dst) {
+            for (const auto &mdw : {src, dst}) {
+                if (!mdw.is_plain()) return false;
+            }
+
+            return true;
+        }
     };
 
     status_t init(impl::engine_t *engine) override;

@@ -101,6 +111,9 @@ struct ref_sycl_softmax_bwd_t : public gpu::generic::sycl::primitive_t {
                     && dst_md()->data_type == diff_dst_md()->data_type
                     && attr()->has_default_values()
                     && set_default_formats() == status::success
+                    && memory_desc_wrapper(diff_src_md()).is_plain()
+                    && memory_desc_wrapper(diff_dst_md()).is_plain()
+                    && memory_desc_wrapper(dst_md()).is_plain()
                     && md_dims_in_range(diff_dst_md());
 
             if (!ok) return status::unimplemented;

tests/gtests/dnnl_test_macros.hpp (+15)

@@ -69,6 +69,21 @@
 #define SKIP_FOR_LOOP_HIP(cond, message)
 #endif
 
+#ifdef DNNL_SYCL_GENERIC
+#define SKIP_IF_GENERIC(cond, message) \
+    do { \
+        SKIP_IF(get_test_engine_kind() == engine::kind::gpu && (cond), \
+                (message)); \
+    } while (0)
+
+#define SKIP_FOR_LOOP_GENERIC(cond, message) \
+    SKIP_FOR_LOOP( \
+            get_test_engine_kind() == engine::kind::gpu && (cond), (message));
+#else
+#define SKIP_IF_GENERIC(cond, message)
+#define SKIP_FOR_LOOP_GENERIC(cond, message)
+#endif
+
 #define TEST_F_(test_fixture, test_name) TEST_F(test_fixture, test_name)
 
 #define CPU_TEST_F(test_fixture, test_name) \
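
As a usage sketch, the new macros read like the existing CUDA/HIP variants; the test diffs below call SKIP_IF_GENERIC from SetUp() in exactly this shape:

    SKIP_IF_GENERIC(!generic_check_format_tag(p.src_tag), "Unsupported format tag");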

tests/gtests/test_batch_normalization.cpp (+10)

@@ -59,6 +59,11 @@ bool hip_check_format_tag(tag first_tag, Rest... rest_tags) {
     return hip_check_format_tag(rest_tags...);
 }
 
+bool generic_check_format_tag(tag atag) {
+    return impl::utils::one_of(atag, tag::ncw, tag::nchw, tag::ncdhw, tag::nwc,
+            tag::nhwc, tag::ndhwc, tag::any);
+}
+
 class batch_normalization_test_t
     : public ::testing::TestWithParam<batch_normalization_test_params_t> {
 private:

@@ -80,6 +85,11 @@ class batch_normalization_test_t
         SKIP_IF_HIP(!hip_check_format_tag(p.src_tag, p.dst_tag),
                 "Unsupported format tag");
 
+        SKIP_IF_GENERIC(
+                !generic_check_format_tag(p.src_tag), "Unsupported format tag");
+        SKIP_IF_GENERIC(
+                !generic_check_format_tag(p.dst_tag), "Unsupported format tag");
+
         SKIP_IF_CUDA(p.src_dt != p.dst_dt && p.src_dt != dt::undef
                         && p.dst_dt != dt::undef,
                 "Unsupported different data types for source and "

tests/gtests/test_binary.cpp (+10)

@@ -69,6 +69,8 @@ class binary_test_t : public ::testing::TestWithParam<binary_test_params_t> {
                     "Unsupported source format tag");
             SKIP_IF_HIP(!hip_check_format_tag(tag),
                     "Unsupported source format tag");
+            SKIP_IF_GENERIC(!generic_check_format_tag(tag),
+                    "Unsupported source format tag");
         }
         SKIP_IF_CUDA(!cuda_check_format_tag(p.dst_format),
                 "Unsupported destination format tag");

@@ -101,6 +103,14 @@ class binary_test_t : public ::testing::TestWithParam<binary_test_params_t> {
         return atag == tag::abcd || atag == tag::acdb;
     }
     bool hip_check_format_tag(tag atag) { return atag == tag::abcd; }
+    bool generic_check_format_tag(tag atag) {
+        return impl::utils::one_of(atag, tag::a, tag::ab, tag::abc, tag::abcd,
+                tag::abcde, tag::abcdef, tag::abdec, tag::acb, tag::acbde,
+                tag::acbdef, tag::acdb, tag::acdeb, tag::ba, tag::bac,
+                tag::bacd, tag::bca, tag::bcda, tag::bcdea, tag::cba, tag::cdba,
+                tag::cdeba, tag::decab, tag::defcab, tag::Ab32a, tag::aBc32b,
+                tag::any);
+    }
 
     void Test() {
         auto eng = get_test_engine();

tests/gtests/test_concat.cpp (+12)

@@ -99,6 +99,14 @@ class concat_test_t : public ::testing::TestWithParam<concat_test_params_t> {
                 dnnl_aBcde4b);
     }
 
+    bool generic_supported_format_tag(memory::format_tag tag) {
+        return impl::utils::one_of(tag, dnnl_a, dnnl_ab, dnnl_abc, dnnl_abcd,
+                dnnl_abcde, dnnl_abcdef, dnnl_abdec, dnnl_acb, dnnl_acbde,
+                dnnl_acbdef, dnnl_acdb, dnnl_acdeb, dnnl_ba, dnnl_bac,
+                dnnl_bacd, dnnl_bca, dnnl_bcda, dnnl_bcdea, dnnl_cba, dnnl_cdba,
+                dnnl_cdeba, dnnl_decab, dnnl_defcab);
+    }
+
     void SetUp() override {
         auto data_type = data_traits<data_t>::data_type;
         SKIP_IF_HIP(true, "Concat operator is not supported");

@@ -109,10 +117,14 @@ class concat_test_t : public ::testing::TestWithParam<concat_test_params_t> {
         for (size_t i = 0; i < p.srcs_cds.size(); i++) {
             SKIP_IF_CUDA(!cuda_supported_format_tag(p.srcs_format[i]),
                     "Unsupported format tag");
+            SKIP_IF_GENERIC(!generic_supported_format_tag(p.srcs_format[i]),
+                    "Unsupported format tag");
         }
 
         SKIP_IF_CUDA(!cuda_supported_format_tag(p.dst_format),
                 "Unsupported format tag");
+        SKIP_IF_GENERIC(!generic_supported_format_tag(p.dst_format),
+                "Unsupported format tag");
         catch_expected_failures(
                 [&]() { Test(); }, p.expect_to_fail, p.expected_status, false);
     }
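
These skips pair with the ref_concat.hpp guard at the top of this commit: a concat whose tensors carry a zero dimension should now succeed as a no-op rather than launch a kernel. A hypothetical public-API repro (an assumption for illustration, not one of the commit's tests):

#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0); // any engine kind works for the sketch
    memory::desc src0({0, 2}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc src1({0, 3}, memory::data_type::f32, memory::format_tag::ab);
    // Concatenate along dim 1; dst is {0, 5}, so its size() is 0.
    auto cpd = concat::primitive_desc(eng, 1, {src0, src1});
    return cpd.dst_desc().get_size() == 0 ? 0 : 1;
}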

tests/gtests/test_convolution_backward_data_common.hpp (+33)

@@ -138,6 +138,23 @@ class convolution_backward_data_test
                         p.formats.weights_format, p.aalgorithm)),
                 "Format is not supported.");
 
+        SKIP_IF_GENERIC(
+                !(generic_check_format_tags(p.formats.src_format)
+                        && generic_check_format_tags(p.formats.dst_format)
+                        && (generic_check_format_tags(p.formats.weights_format)
+                                || (impl::utils::one_of(
+                                        p.formats.weights_format,
+                                        memory::format_tag::goiw,
+                                        memory::format_tag::goihw,
+                                        memory::format_tag::goidhw,
+                                        memory::format_tag::oiw,
+                                        memory::format_tag::oihw,
+                                        memory::format_tag::oidhw)))
+                        && check_generic_dt<data_t_diff_src>()
+                        && check_generic_dt<data_t_diff_dst>()
+                        && check_generic_dt<data_t_wei>()),
+                "Format is not supported.");
+
         catch_expected_failures(
                 [&]() { Test(); }, p.expect_to_fail, p.expected_status);
     }

@@ -156,6 +173,14 @@ class convolution_backward_data_test
                 memory::format_tag::acdeb);
     }
 
+    bool generic_check_format_tags(memory::format_tag tag) {
+        return impl::utils::one_of(tag, memory::format_tag::ab,
+                memory::format_tag::abc, memory::format_tag::abcd,
+                memory::format_tag::abcde, memory::format_tag::abcdef,
+                memory::format_tag::acb, memory::format_tag::acdb,
+                memory::format_tag::acdeb, memory::format_tag::any);
+    }
+
     bool check_cuda_alg_format(memory::format_tag dst_fmt,
             memory::format_tag wei_fmt, algorithm alg) {
         bool res = dst_fmt == wei_fmt;

@@ -182,6 +207,14 @@ class convolution_backward_data_test
         return res;
     }
 
+    template <typename dt>
+    bool check_generic_dt() {
+        return impl::utils::one_of(data_traits<dt>::data_type,
+                memory::data_type::f32, memory::data_type::bf16,
+                memory::data_type::f16, memory::data_type::s32,
+                memory::data_type::s8, memory::data_type::u8);
+    }
+
     void Test() {
         auto p = ::testing::TestWithParam<
                 test_convolution_params_t>::GetParam();
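
check_generic_dt maps the test's C++ element type to a oneDNN data type via data_traits (from the gtest common headers) before whitelisting it. A rough sketch of how such a trait works (simplified assumption, not the actual dnnl_test_common.hpp definition):

#include <cstdint>
#include "oneapi/dnnl/dnnl.hpp"

// Maps a C++ element type to the corresponding dnnl::memory::data_type.
template <typename T> struct data_traits_sketch;
template <> struct data_traits_sketch<float> {
    static constexpr auto data_type = dnnl::memory::data_type::f32;
};
template <> struct data_traits_sketch<std::int8_t> {
    static constexpr auto data_type = dnnl::memory::data_type::s8;
};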

tests/gtests/test_convolution_backward_weights_common.hpp (+38)

@@ -170,6 +170,28 @@ class convolution_backward_weights_test
                         p.formats.weights_format, p.aalgorithm)),
                 "Format is not supported.");
 
+        SKIP_IF_GENERIC(
+                !(generic_check_format_tags(p.formats.src_format)
+                        && generic_check_format_tags(p.formats.dst_format)
+                        && (generic_check_format_tags(p.formats.weights_format)
+                                || (impl::utils::one_of(
+                                        p.formats.weights_format,
+                                        memory::format_tag::goiw,
+                                        memory::format_tag::goihw,
+                                        memory::format_tag::goidhw,
+                                        memory::format_tag::oiw,
+                                        memory::format_tag::oihw,
+                                        memory::format_tag::oidhw,
+                                        memory::format_tag::bacd,
+                                        memory::format_tag::bcda,
+                                        memory::format_tag::acbde,
+                                        memory::format_tag::iohw,
+                                        memory::format_tag::hwigo)))
+                        && check_generic_dt<data_t_src>()
+                        && check_generic_dt<data_t_diff_dst>()
+                        && check_generic_dt<data_t_diff_weights>()),
+                "Format is not supported.");
+
         catch_expected_failures(
                 [&]() { Test(); }, p.expect_to_fail, p.expected_status);
     }

@@ -214,6 +236,22 @@ class convolution_backward_weights_test
         return res;
     }
 
+    bool generic_check_format_tags(memory::format_tag tag) {
+        return impl::utils::one_of(tag, memory::format_tag::ab,
+                memory::format_tag::abc, memory::format_tag::abcd,
+                memory::format_tag::abcde, memory::format_tag::abcdef,
+                memory::format_tag::acb, memory::format_tag::acdb,
+                memory::format_tag::acdeb, memory::format_tag::any);
+    }
+
+    template <typename dt>
+    bool check_generic_dt() {
+        return impl::utils::one_of(data_traits<dt>::data_type,
+                memory::data_type::f32, memory::data_type::bf16,
+                memory::data_type::f16, memory::data_type::s32,
+                memory::data_type::s8, memory::data_type::u8);
+    }
+
     void Test() {
         auto p = ::testing::TestWithParam<
                 test_convolution_params_t>::GetParam();
