|
1 | 1 | /*******************************************************************************
|
2 |
| -* Copyright 2019-2022 Intel Corporation |
| 2 | +* Copyright 2019-2025 Intel Corporation |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
38 | 38 | using namespace dnnl;
|
39 | 39 |
|
40 | 40 | void simple_net(engine::kind engine_kind) {
|
41 |
| - using tag = memory::format_tag; |
42 |
| - using dt = memory::data_type; |
43 |
| - |
44 | 41 | auto eng = engine(engine_kind, 0);
|
45 | 42 | stream s(eng);
|
46 | 43 |
|
@@ -79,27 +76,36 @@ void simple_net(engine::kind engine_kind) {
|
79 | 76 | conv_bias[i] = sinf((float)i);
|
80 | 77 |
|
81 | 78 | // create memory for user data
|
82 |
| - auto conv_user_src_memory |
83 |
| - = memory({{conv_src_tz}, dt::f32, tag::nchw}, eng); |
| 79 | + auto conv_user_src_memory = memory( |
| 80 | + {{conv_src_tz}, memory::data_type::f32, memory::format_tag::nchw}, |
| 81 | + eng); |
84 | 82 | write_to_dnnl_memory(net_src.data(), conv_user_src_memory);
|
85 | 83 |
|
86 | 84 | auto conv_user_weights_memory
|
87 |
| - = memory({{conv_weights_tz}, dt::f32, tag::oihw}, eng); |
| 85 | + = memory({{conv_weights_tz}, memory::data_type::f32, |
| 86 | + memory::format_tag::oihw}, |
| 87 | + eng); |
88 | 88 | write_to_dnnl_memory(conv_weights.data(), conv_user_weights_memory);
|
89 | 89 |
|
90 |
| - auto conv_user_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng); |
| 90 | + auto conv_user_bias_memory = memory( |
| 91 | + {{conv_bias_tz}, memory::data_type::f32, memory::format_tag::x}, |
| 92 | + eng); |
91 | 93 | write_to_dnnl_memory(conv_bias.data(), conv_user_bias_memory);
|
92 | 94 |
|
93 | 95 | // create memory descriptors for bfloat16 convolution data w/ no specified
|
94 | 96 | // format tag(`any`)
|
95 | 97 | // tag `any` lets a primitive(convolution in this case)
|
96 | 98 | // chose the memory format preferred for best performance.
|
97 |
| - auto conv_src_md = memory::desc({conv_src_tz}, dt::bf16, tag::any); |
98 |
| - auto conv_weights_md = memory::desc({conv_weights_tz}, dt::bf16, tag::any); |
99 |
| - auto conv_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any); |
| 99 | + auto conv_src_md = memory::desc( |
| 100 | + {conv_src_tz}, memory::data_type::bf16, memory::format_tag::any); |
| 101 | + auto conv_weights_md = memory::desc({conv_weights_tz}, |
| 102 | + memory::data_type::bf16, memory::format_tag::any); |
| 103 | + auto conv_dst_md = memory::desc( |
| 104 | + {conv_dst_tz}, memory::data_type::bf16, memory::format_tag::any); |
100 | 105 | // here bias data type is set to bf16.
|
101 | 106 | // additionally, f32 data type is supported for bf16 convolution.
|
102 |
| - auto conv_bias_md = memory::desc({conv_bias_tz}, dt::bf16, tag::any); |
| 107 | + auto conv_bias_md = memory::desc( |
| 108 | + {conv_bias_tz}, memory::data_type::bf16, memory::format_tag::any); |
103 | 109 |
|
104 | 110 | // create a convolution primitive descriptor
|
105 | 111 |
|
@@ -225,11 +231,13 @@ void simple_net(engine::kind engine_kind) {
|
225 | 231 | memory::dims pool_padding = {0, 0};
|
226 | 232 |
|
227 | 233 | // create memory for pool dst data in user format
|
228 |
| - auto pool_user_dst_memory |
229 |
| - = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng); |
| 234 | + auto pool_user_dst_memory = memory( |
| 235 | + {{pool_dst_tz}, memory::data_type::f32, memory::format_tag::nchw}, |
| 236 | + eng); |
230 | 237 |
|
231 | 238 | // create pool dst memory descriptor in format any for bfloat16 data type
|
232 |
| - auto pool_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any); |
| 239 | + auto pool_dst_md = memory::desc( |
| 240 | + {pool_dst_tz}, memory::data_type::bf16, memory::format_tag::any); |
233 | 241 |
|
234 | 242 | // create a pooling primitive descriptor
|
235 | 243 | auto pool_pd = pooling_forward::primitive_desc(eng, prop_kind::forward,
|
@@ -269,14 +277,17 @@ void simple_net(engine::kind engine_kind) {
|
269 | 277 | net_diff_dst[i] = sinf((float)i);
|
270 | 278 |
|
271 | 279 | // create memory for user diff dst data stored in float data type
|
272 |
| - auto pool_user_diff_dst_memory |
273 |
| - = memory({{pool_dst_tz}, dt::f32, tag::nchw}, eng); |
| 280 | + auto pool_user_diff_dst_memory = memory( |
| 281 | + {{pool_dst_tz}, memory::data_type::f32, memory::format_tag::nchw}, |
| 282 | + eng); |
274 | 283 | write_to_dnnl_memory(net_diff_dst.data(), pool_user_diff_dst_memory);
|
275 | 284 |
|
276 | 285 | // Backward pooling
|
277 | 286 | // create memory descriptors for pooling
|
278 |
| - auto pool_diff_src_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any); |
279 |
| - auto pool_diff_dst_md = memory::desc({pool_dst_tz}, dt::bf16, tag::any); |
| 287 | + auto pool_diff_src_md = memory::desc( |
| 288 | + {lrn_data_tz}, memory::data_type::bf16, memory::format_tag::any); |
| 289 | + auto pool_diff_dst_md = memory::desc( |
| 290 | + {pool_dst_tz}, memory::data_type::bf16, memory::format_tag::any); |
280 | 291 |
|
281 | 292 | // backward primitive descriptor needs to hint forward descriptor
|
282 | 293 | auto pool_bwd_pd = pooling_backward::primitive_desc(eng,
|
@@ -305,7 +316,8 @@ void simple_net(engine::kind engine_kind) {
|
305 | 316 | {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
|
306 | 317 |
|
307 | 318 | // Backward lrn
|
308 |
| - auto lrn_diff_dst_md = memory::desc({lrn_data_tz}, dt::bf16, tag::any); |
| 319 | + auto lrn_diff_dst_md = memory::desc( |
| 320 | + {lrn_data_tz}, memory::data_type::bf16, memory::format_tag::any); |
309 | 321 | const auto &lrn_diff_src_md = lrn_diff_dst_md;
|
310 | 322 |
|
311 | 323 | // create backward lrn primitive descriptor
|
@@ -335,8 +347,10 @@ void simple_net(engine::kind engine_kind) {
|
335 | 347 | {DNNL_ARG_WORKSPACE, lrn_workspace_memory}});
|
336 | 348 |
|
337 | 349 | // Backward relu
|
338 |
| - auto relu_diff_src_md = memory::desc({relu_data_tz}, dt::bf16, tag::any); |
339 |
| - auto relu_diff_dst_md = memory::desc({relu_data_tz}, dt::bf16, tag::any); |
| 350 | + auto relu_diff_src_md = memory::desc( |
| 351 | + {relu_data_tz}, memory::data_type::bf16, memory::format_tag::any); |
| 352 | + auto relu_diff_dst_md = memory::desc( |
| 353 | + {relu_data_tz}, memory::data_type::bf16, memory::format_tag::any); |
340 | 354 | auto relu_src_md = conv_pd.dst_desc();
|
341 | 355 |
|
342 | 356 | // create backward relu primitive_descriptor
|
@@ -367,14 +381,20 @@ void simple_net(engine::kind engine_kind) {
|
367 | 381 | // create user format diff weights and diff bias memory for float data type
|
368 | 382 |
|
369 | 383 | auto conv_user_diff_weights_memory
|
370 |
| - = memory({{conv_weights_tz}, dt::f32, tag::nchw}, eng); |
371 |
| - auto conv_diff_bias_memory = memory({{conv_bias_tz}, dt::f32, tag::x}, eng); |
| 384 | + = memory({{conv_weights_tz}, memory::data_type::f32, |
| 385 | + memory::format_tag::nchw}, |
| 386 | + eng); |
| 387 | + auto conv_diff_bias_memory = memory( |
| 388 | + {{conv_bias_tz}, memory::data_type::f32, memory::format_tag::x}, |
| 389 | + eng); |
372 | 390 |
|
373 | 391 | // create memory descriptors for bfloat16 convolution data
|
374 |
| - auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::bf16, tag::any); |
375 |
| - auto conv_diff_weights_md |
376 |
| - = memory::desc({conv_weights_tz}, dt::bf16, tag::any); |
377 |
| - auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::bf16, tag::any); |
| 392 | + auto conv_bwd_src_md = memory::desc( |
| 393 | + {conv_src_tz}, memory::data_type::bf16, memory::format_tag::any); |
| 394 | + auto conv_diff_weights_md = memory::desc({conv_weights_tz}, |
| 395 | + memory::data_type::bf16, memory::format_tag::any); |
| 396 | + auto conv_diff_dst_md = memory::desc( |
| 397 | + {conv_dst_tz}, memory::data_type::bf16, memory::format_tag::any); |
378 | 398 |
|
379 | 399 | // use diff bias provided by the user
|
380 | 400 | auto conv_diff_bias_md = conv_diff_bias_memory.get_desc();
|
|
0 commit comments