Commit d6fa6a5: rfc: dropout primitive attribute

File changed: rfcs/20230818-Dropout/README.md (+79 lines)

# Introducing Dropout primitive attribute

## Introduction

In many DNN and GNN models, [Dropout](https://en.wikipedia.org/wiki/Convolutional_neural_network#Dropout) is used to improve training results. In some cases this layer takes a significant amount of time, so to improve training performance we want to optimize it and, ultimately, fuse it with the preceding primitive.

This idea was [proposed](https://github.com/oneapi-src/oneDNN/pull/760) some time ago, and as a result of that discussion a primitive attribute was chosen as the way to close the performance gap.

## Proposal

The proposal adds functions to get and set the dropout attribute in the C API:

```c
/// Returns the parameters of a dropout attribute.
///
/// @param attr Primitive attributes.
/// @param enable_drop Dropout enable/disable flag.
/// @param mask_desc Output memory descriptor of a dropout mask.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout(
        const_dnnl_primitive_attr_t attr, uint8_t *enable_drop,
        const_dnnl_memory_desc_t *mask_desc);

/// Sets up the dropout primitive attribute.
///
/// @param attr Primitive attributes.
/// @param enable_drop Dropout enable/disable flag.
/// @param mask_desc Output memory descriptor of a dropout mask.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout(
        dnnl_primitive_attr_t attr, uint8_t enable_drop,
        const_dnnl_memory_desc_t mask_desc);
```

and for the C++ API:
```c++
/// Returns the parameters of a dropout attribute.
///
/// @param enabled Dropout enable/disable flag.
/// @param mask_desc Output memory descriptor of a dropout mask.
void get_dropout(bool &enabled, memory::desc &mask_desc) const {
    const_dnnl_memory_desc_t cdesc;
    uint8_t enabled_u8;
    error::wrap_c_api(
            dnnl_primitive_attr_get_dropout(get(), &enabled_u8, &cdesc),
            "could not get parameters of a dropout attribute");
    dnnl_memory_desc_t cloned_md = nullptr;
    error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
            "could not clone a memory descriptor");
    mask_desc = memory::desc(cloned_md);
    enabled = enabled_u8;
}

/// Sets up the dropout primitive attribute.
///
/// @param enabled Dropout enable/disable flag.
/// @param mask_desc Output memory descriptor of a dropout mask.
void set_dropout(bool enabled, const memory::desc &mask_desc) {
    error::wrap_c_api(
            dnnl_primitive_attr_set_dropout(get(), enabled, mask_desc.get()),
            "could not set dropout primitive attribute");
}
```
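
To show how the proposed C++ API could be used, here is a minimal sketch (not part of the proposal) that attaches a dropout mask descriptor to a matmul forward-training primitive so the dropout is fused into it. The shapes, the `u8` mask data type, and the helper function are illustrative assumptions.

```c++
#include "dnnl.hpp"
using namespace dnnl;

// Illustrative only: build a matmul primitive descriptor with the proposed
// dropout attribute attached. Shapes and the u8 mask data type are assumed;
// the RFC does not fix them.
matmul::primitive_desc make_fused_matmul(const engine &eng) {
    const memory::dim M = 128, K = 256, N = 512;
    memory::desc src_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);
    // The dropout mask covers every destination element.
    memory::desc mask_md({M, N}, memory::data_type::u8, memory::format_tag::ab);

    primitive_attr attr;
    attr.set_dropout(true, mask_md); // proposed C++ API from this RFC

    return matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);
}
```
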
The proposal also adds runtime dropout arguments: the output mask (which can be reused in the backward pass), the dropout probability, and the seed.

```c
/// Argument for the dropout output mask.
#define DNNL_ARG_ATTR_DROPOUT_MASK 16385

/// Argument for the dropout probability parameter.
#define DNNL_ARG_ATTR_DROPOUT_PROBABILITY 16386

/// Argument for the dropout seed.
#define DNNL_ARG_ATTR_DROPOUT_SEED 16387
```
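
A hedged sketch of how these runtime arguments might be supplied at execution time is shown below; the single-element `f32` probability and `s32` seed memories, the host pointers (CPU engine assumed), and the variable names are assumptions, since the RFC does not fix them.

```c++
#include "dnnl.hpp"
#include <cstdint>
using namespace dnnl;

// Illustrative only: execute a matmul created with the dropout attribute
// (for example, the primitive descriptor from the previous sketch) and pass
// the dropout runtime arguments alongside the regular ones.
void run_fused_matmul(const matmul::primitive_desc &pd, stream &strm,
        memory &src, memory &wei, memory &dst) {
    engine eng = pd.get_engine();

    // Memory for the mask output so the backward pass can reuse it; the u8
    // data type and shape mirror the mask descriptor set on the attribute.
    memory mask_mem({dst.get_desc().get_dims(), memory::data_type::u8,
                            memory::format_tag::ab},
            eng);

    // Scalar runtime parameters; data types are assumed, and a CPU engine is
    // assumed so host pointers can be wrapped directly.
    float prob = 0.5f;
    int32_t seed = 12345;
    memory prob_mem({{1}, memory::data_type::f32, memory::format_tag::a}, eng, &prob);
    memory seed_mem({{1}, memory::data_type::s32, memory::format_tag::a}, eng, &seed);

    matmul(pd).execute(strm,
            {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei}, {DNNL_ARG_DST, dst},
                    {DNNL_ARG_ATTR_DROPOUT_MASK, mask_mem},
                    {DNNL_ARG_ATTR_DROPOUT_PROBABILITY, prob_mem},
                    {DNNL_ARG_ATTR_DROPOUT_SEED, seed_mem}});
    // Wait so the stack-allocated scalars stay alive until execution finishes.
    strm.wait();
}
```
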
In most frameworks, the dropout operation is enabled only for the forward training pass, while for the backward pass a binary multiplication with the saved mask can be used. For forward inference, nothing should be done to the tensor.
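
As an illustration of that backward path, the following sketch (not part of the proposal) applies the saved mask to the incoming gradient with the existing binary (multiplication) primitive. It assumes the mask is stored in a data type the binary primitive accepts and already carries any `1/(1 - p)` scaling; otherwise an extra scale would be needed.

```c++
#include "dnnl.hpp"
using namespace dnnl;

// Illustrative only: dropout backward as an element-wise multiplication of
// the incoming gradient by the mask saved during the forward pass.
void dropout_backward(const engine &eng, stream &strm, memory &diff_dst,
        memory &mask, memory &diff_src) {
    binary::primitive_desc pd(eng, algorithm::binary_mul, diff_dst.get_desc(),
            mask.get_desc(), diff_src.get_desc());
    binary(pd).execute(strm, {{DNNL_ARG_SRC_0, diff_dst},
                                     {DNNL_ARG_SRC_1, mask},
                                     {DNNL_ARG_DST, diff_src}});
}
```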
