Proposal for Instance Normalization
===================================

## 0. Summary (TLDR)

Instance normalization is another way to normalize data that solves some of
Batch normalization's issues. Instance normalization is used in 3D U-Net and Stable
Diffusion. This RFC proposes to add a Group normalization primitive, which is a
generalization of Instance normalization, to cover both the 3D U-Net and Stable
Diffusion models.


## 1. Introduction

Instance normalization was introduced as a replacement for Batch normalization in
real-time image generation workloads for both inference and training[[#1]][1].

- Batch normalization:
  - Mean:

    ![Mean](batch_norm_mean.png)
  - Variance:

    ![Variance](batch_norm_variance.png)
  - Normalization:

    ![Normalization](batch_norm.png)

- Instance normalization:
  - Mean:

    ![Mean](instance_norm_mean.png)
  - Variance:

    ![Variance](instance_norm_variance.png)
  - Normalization:

    ![Normalization](instance_norm.png)

The only difference between these two normalization algorithms is that Instance
normalization computes separate Mean and Variance values for each instance in the
batch.

There is another normalization called Group normalization[[#2]][2]. This
normalization is a generalization of Instance normalization that splits channels
into groups:
- Mean:

  ![Mean](group_norm_mean.png)
- Variance:

  ![Variance](group_norm_variance.png)
- Normalization:

  ![Normalization](group_norm.png)

**Notes**:
- The channels dimension must be divisible by the number of groups.
- Group normalization is equal to Instance normalization when `groups = channels`.
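
For reference, the statistics of the three algorithms can be written as follows
(a sketch of the standard definitions, assuming an `N x C x H x W` tensor `x`,
`G` groups of `C/G` channels, and per-channel scale/shift parameters `gamma` and
`beta`):

```
Batch norm:     \mu_c       = \frac{1}{NHW} \sum_{n,h,w} x_{nchw}
                \sigma^2_c  = \frac{1}{NHW} \sum_{n,h,w} (x_{nchw} - \mu_c)^2

Instance norm:  \mu_{nc}      = \frac{1}{HW} \sum_{h,w} x_{nchw}
                \sigma^2_{nc} = \frac{1}{HW} \sum_{h,w} (x_{nchw} - \mu_{nc})^2

Group norm:     \mu_{ng}      = \frac{1}{(C/G)HW} \sum_{c \in g} \sum_{h,w} x_{nchw}
                \sigma^2_{ng} = \frac{1}{(C/G)HW} \sum_{c \in g} \sum_{h,w} (x_{nchw} - \mu_{ng})^2

In all cases:   y_{nchw} = \gamma_c \frac{x_{nchw} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_c
```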

### 1.1. Frameworks support

#### OpenVino

OpenVino doesn't support Instance normalization as part of its OpSet, but it has a
transformation[[#3]][3] that converts a PyTorch InstanceNormalization into an OpenVino
BatchNormalization. This transformation adds a few extra nodes to the graph to make
this conversion possible.

#### TensorFlow

TensorFlow supports Instance normalization as part of TensorFlow SIG Addons (tensorflow_addons)[[#4]][4].
This is a repository of community contributions that implement new functionality
not available in core TensorFlow[[#5]][5].

TensorFlow also supports Group normalization[[#6]][6].

#### PyTorch

PyTorch supports Instance normalization[[#7]][7].

PyTorch also supports Group normalization[[#8]][8].
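
As a small illustration of the `groups = channels` equivalence using the PyTorch
modules referenced above (the shapes below are toy values chosen only for this
sketch):

```python
import torch

# Toy shapes chosen for illustration only.
n, c, h, w = 2, 8, 5, 5
x = torch.randn(n, c, h, w)

# Instance normalization: statistics per (sample, channel) pair.
inorm = torch.nn.InstanceNorm2d(num_features=c, affine=False)

# Group normalization with one channel per group degenerates to Instance normalization.
gnorm = torch.nn.GroupNorm(num_groups=c, num_channels=c, affine=False)

assert torch.allclose(inorm(x), gnorm(x), atol=1e-5)
```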

### 1.2. Models

Instance normalization is used in:
- TensorFlow-based 3D U-Net[[#9]][9]. According to projections, Instance normalization
  takes up to 25% of model inference time in the int8 case and up to 10% in the f32
  case, which is significant.
- U-Net-based Stable Diffusion[[#10]][10]. This model also uses Group normalization.

## 2. Optimizing Instance normalization using oneDNN

There are a few ways to optimize Instance normalization using oneDNN:

### 2.1. Batch normalization

#### 2.1.1. A loop over a Batch normalization primitive

Instance normalization can be supported via a sequence of Batch normalization
primitives where each of them normalizes a single instance:
```
src = array(n, c, h, w)
dst = array(n, c, h, w)
gamma = array(c)
beta = array(c)

# Emulation of Instance normalization
for i in range(n):
    src_i = src.offset(i, 0, 0, 0)
    dst_i = dst.offset(i, 0, 0, 0)
    dst_i = batch_norm(src_i, gamma, beta)
```
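
A minimal runnable NumPy sketch of this emulation. The `batch_norm` and
`instance_norm` helpers below are NumPy references written only for this
illustration (not oneDNN primitives), and the sizes and `eps` value are arbitrary:

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    # Reference Batch normalization: statistics per channel, over (N, H, W).
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return gamma.reshape(1, -1, 1, 1) * (x - mean) / np.sqrt(var + eps) \
            + beta.reshape(1, -1, 1, 1)

def instance_norm(x, gamma, beta, eps=1e-5):
    # Reference Instance normalization: statistics per (sample, channel), over (H, W).
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return gamma.reshape(1, -1, 1, 1) * (x - mean) / np.sqrt(var + eps) \
            + beta.reshape(1, -1, 1, 1)

n, c, h, w = 2, 4, 5, 5  # toy sizes for illustration
x = np.random.rand(n, c, h, w).astype(np.float32)
gamma = np.random.rand(c).astype(np.float32)
beta = np.random.rand(c).astype(np.float32)

# Emulation: a loop over Batch normalization, one instance at a time.
dst = np.empty_like(x)
for i in range(n):
    dst[i:i + 1] = batch_norm(x[i:i + 1], gamma, beta)

assert np.allclose(dst, instance_norm(x, gamma, beta), atol=1e-5)
```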

Pros:
- No changes to oneDNN.

Cons:
- In the case of a small feature size there is a potential performance penalty
  because the Batch normalization primitive does not utilize the batch dimension
  to split work across cores.

#### 2.1.2. A Batch normalization primitive with joined batch and feature dimensions

Another way to support Instance normalization using Batch normalization is to
join the batch and feature dimensions, similar to the OpenVino approach:
```
src = array(n, c, h, w)
dst = array(n, c, h, w)
gamma = array(c)
beta = array(c)

# Emulation of Instance normalization
src_new = src.reshape(1, n * c, h, w)
dst_new = dst.reshape(1, n * c, h, w)
gamma_new = gamma.broadcast(n, c)
beta_new = beta.broadcast(n, c)
dst_new = batch_norm(src_new, gamma_new, beta_new)
```
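
A minimal NumPy sketch of this second emulation (again with a NumPy `batch_norm`
reference rather than the oneDNN primitive; toy sizes and `eps` are arbitrary),
illustrating the extra broadcast of gamma and beta:

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    # Reference Batch normalization: per-channel statistics over (N, H, W).
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return gamma.reshape(1, -1, 1, 1) * (x - mean) / np.sqrt(var + eps) \
            + beta.reshape(1, -1, 1, 1)

n, c, h, w = 2, 4, 5, 5  # toy sizes for illustration
x = np.random.rand(n, c, h, w).astype(np.float32)
gamma = np.random.rand(c).astype(np.float32)
beta = np.random.rand(c).astype(np.float32)

# Collapse batch and channels into a single "channel" dimension of size n * c.
src_new = x.reshape(1, n * c, h, w)
# gamma and beta must be broadcast (tiled) across the batch to cover n * c channels.
gamma_new = np.tile(gamma, n)
beta_new = np.tile(beta, n)

dst = batch_norm(src_new, gamma_new, beta_new).reshape(n, c, h, w)

# Per-(sample, channel) statistics give the same result as Instance normalization.
ref = (x - x.mean(axis=(2, 3), keepdims=True)) \
        / np.sqrt(x.var(axis=(2, 3), keepdims=True) + 1e-5)
ref = gamma.reshape(1, -1, 1, 1) * ref + beta.reshape(1, -1, 1, 1)
assert np.allclose(dst, ref, atol=1e-5)
```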

Pros:
- No changes to oneDNN.

Cons:
- A potential performance penalty, because gamma and beta must be broadcast on
  the forward pass and reduced on the backward pass.
- Data must be reordered into `nchw`, since in `nhwc` the `n` and `c` dimensions
  can't be collapsed.

### 2.2. Layer normalization

Layer normalization is another type of normalization used in machine translation
workloads. It computes normalization across the features dimension, which is the
last dimension in a tensor with dimensions `[t, n, c]`. Instance normalization can
be performed using Layer normalization if the original tensor `[n, c, h, w]` is
reshaped to `[n, c, hw]`. However, gamma and beta can't be applied as part of Layer
normalization, because they would be applied across `hw` and not `c`:
```
src = array(n, c, h, w)
dst = array(n, c, h, w)
gamma = array(c)
beta = array(c)

# Emulation of Instance normalization
src_new = src.reshape(n, c, h*w)
dst_new = dst.reshape(n, c, h*w)
gamma_new = gamma.reshape(1, c, 1)
beta_new = beta.reshape(1, c, 1)
dst_new = gamma_new * layer_norm(src_new) + beta_new
```
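
A minimal NumPy sketch of the Layer-normalization-based emulation, showing the
separate pass needed to apply gamma and beta (the `layer_norm` helper is a NumPy
reference over the last dimension, not the oneDNN primitive; sizes and `eps` are
arbitrary):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Reference Layer normalization over the last dimension, without scale/shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

n, c, h, w = 2, 4, 5, 5  # toy sizes for illustration
x = np.random.rand(n, c, h, w).astype(np.float32)
gamma = np.random.rand(c).astype(np.float32)
beta = np.random.rand(c).astype(np.float32)

# Normalize over h*w for each (sample, channel) pair...
src_new = x.reshape(n, c, h * w)
normed = layer_norm(src_new)
# ...and apply gamma/beta per channel in a separate pass over memory.
dst = (gamma.reshape(1, c, 1) * normed + beta.reshape(1, c, 1)).reshape(n, c, h, w)

# Matches a direct Instance normalization.
ref = (x - x.mean(axis=(2, 3), keepdims=True)) \
        / np.sqrt(x.var(axis=(2, 3), keepdims=True) + 1e-5)
ref = gamma.reshape(1, -1, 1, 1) * ref + beta.reshape(1, -1, 1, 1)
assert np.allclose(dst, ref, atol=1e-5)
```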

Pros:
- No changes to oneDNN.

Cons:
- A performance penalty due to an additional pass over memory to apply the gamma
  and beta parameters.

### 2.3. A Dedicated primitive

#### 2.3.1. Instance normalization

A dedicated primitive would ease integration of oneDNN into frameworks that
support Instance normalization.

API:
```c
// include/oneapi/dnnl/dnnl.h

dnnl_status_t DNNL_API dnnl_instance_normalization_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, float epsilon, unsigned flags,
        const_dnnl_primitive_attr_t attr);

dnnl_status_t DNNL_API dnnl_instance_normalization_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t src_desc, float epsilon, unsigned flags,
        const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);
```

Pros:
- Straightforward integration.

Cons:
- The applicability of Instance normalization is very limited.


#### 2.3.2. Group normalization

Instance normalization is a special case of Group normalization where the number
of groups is equal to the number of channels. Group normalization was introduced
as an attempt to remove Batch normalization's requirement for a sufficiently large
batch size during training[[#2]][2].

```
src = array(n, c, h, w)
dst = array(n, c, h, w)
gamma = array(c)
beta = array(c)

# Emulation of Instance normalization
dst = group_norm(src, gamma, beta, n_groups=c)
```
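
A minimal NumPy reference for Group normalization, illustrating that `n_groups = c`
reproduces Instance normalization (the `group_norm` helper below is a NumPy sketch
written for this illustration, not the proposed oneDNN API; sizes and `eps` are
arbitrary):

```python
import numpy as np

def group_norm(x, gamma, beta, n_groups, eps=1e-5):
    # Reference Group normalization: statistics per (sample, group), where each
    # group covers c / n_groups channels and all spatial positions.
    n, c, h, w = x.shape
    assert c % n_groups == 0
    xg = x.reshape(n, n_groups, c // n_groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    y = ((xg - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    return gamma.reshape(1, -1, 1, 1) * y + beta.reshape(1, -1, 1, 1)

n, c, h, w = 2, 8, 5, 5  # toy sizes for illustration
x = np.random.rand(n, c, h, w).astype(np.float32)
gamma = np.random.rand(c).astype(np.float32)
beta = np.random.rand(c).astype(np.float32)

# With one channel per group this is exactly Instance normalization.
dst = group_norm(x, gamma, beta, n_groups=c)

ref = (x - x.mean(axis=(2, 3), keepdims=True)) \
        / np.sqrt(x.var(axis=(2, 3), keepdims=True) + 1e-5)
ref = gamma.reshape(1, -1, 1, 1) * ref + beta.reshape(1, -1, 1, 1)
assert np.allclose(dst, ref, atol=1e-5)
```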

API:
```c
// include/oneapi/dnnl/dnnl.h

dnnl_status_t DNNL_API dnnl_group_normalization_forward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
        const_dnnl_memory_desc_t dst_desc, int groups, float epsilon,
        unsigned flags, const_dnnl_primitive_attr_t attr);

dnnl_status_t DNNL_API dnnl_group_normalization_backward_primitive_desc_create(
        dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
        dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
        const_dnnl_memory_desc_t diff_dst_desc,
        const_dnnl_memory_desc_t src_desc, int groups, float epsilon,
        unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
        const_dnnl_primitive_attr_t attr);
```

Pros:
- Straightforward integration.

Cons:
- The `groups` parameter makes the algorithm very flexible, which widens the scope
  of optimizations and validation. As a result, we might end up with a sub-optimal
  implementation for all configurations except `n_groups = 1` and `n_groups = c`.


## 3. Experiments

A loop over Batch normalization is the only way that is flexible enough to
support both `nchw` and `nhwc` tags as well as scale and shift parameters. The
next experiment estimates how much we can gain if we add a dedicated Instance
normalization primitive compared to a loop over Batch normalization.

Here is an example of Instance normalization implemented using the following approaches:
1. Reshape `[n, c, h, w] -> [n, c, h*w]` + Layer normalization;
2. A loop over a Batch normalization primitive created with batch size equal to 1;
3. Reshape `[n, c, h, w] -> [1, n*c, h, w]` + Batch normalization.

<details>
<summary>Source code</summary>
https://github.com/igorsafo/oneDNN/blob/b9e08b84ad03d837a7def82dab1816577f7fb18e/rfcs/20230315-instance-normalization/instance_normalization.cpp#L1-L288
</details>

We use this example to benchmark the different approaches on cases from 3D U-Net:
| Shape                  | lnorm    | bnorm: loop | bnorm: collapsed dims | best of 3 vs bnorm: loop (speedup) |
| ---------------------- | -------- | ----------- | --------------------- | ---------------------------------- |
| [6, 32, 160, 224, 224] | 0.277857 | 0.292688    | 0.292637              | 1.053376377                        |
| [6, 256, 20, 28, 28]   | 0.002201 | 0.004927    | 0.004552              | 2.238527942                        |
| [6, 320, 10, 14, 14]   | 0.000051 | 0.000777    | 0.000686              | 15.23529412                        |
| [6, 128, 6, 7, 7]      | 0.000012 | 0.000133    | 0.000053              | 11.08333333                        |

The performance is collected on a 28-core Intel(R) Xeon(R) Platinum 8280 CPU using the following command:
```sh
$ numactl --physcpubind=0-27 -l ./build/examples/primitives-instance-normalization-cpp
```

A loop over Batch normalization is slower than a single call to either Layer or
Batch normalization. One of the important differences is that Instance
normalization has `batch * channel` independent normalizations that can be split
across cores, while a loop over Batch normalization only utilizes parallelism
across the `channel` dimension, which is limited. As the number of cores
increases, the difference becomes bigger.


## 4. Proposal

The recommendation is to introduce a Group normalization primitive because:
- it will cover Instance normalization needs for 3D U-Net:
  - ease of integration
  - additional performance
- it will cover Stable Diffusion:
  - Instance normalization in the U-Net part
  - Group normalization in the Attention part


## References

1. [Instance Normalization: The Missing Ingredient for Fast Stylization][1]
2. [Group normalization][2]
3. [OpenVino PyTorch frontend Instance normalization][3]
4. [TensorFlow Instance normalization][4]
5. [TensorFlow Addons][5]
6. [TensorFlow Group normalization][6]
7. [PyTorch Instance normalization][7]
8. [PyTorch Group normalization][8]
9. [IntelAI models: 3D U-Net][9]
10. [The Annotated Diffusion Model][10]

[1]: https://arxiv.org/pdf/1607.08022v3.pdf
[2]: https://arxiv.org/pdf/1803.08494.pdf
[3]: https://github.com/openvinotoolkit/openvino/blob/c09b2ff8b10aa344eedb28ce24ed9a6eeef5e9fb/src/frontends/pytorch/src/op/instance_norm.cpp#L48
[4]: https://www.tensorflow.org/addons/api_docs/python/tfa/layers/InstanceNormalization
[5]: https://www.tensorflow.org/addons
[6]: https://www.tensorflow.org/api_docs/python/tf/keras/layers/GroupNormalization
[7]: https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html
[8]: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
[9]: https://github.com/IntelAI/models/blob/ff7d9c7041590a78fbc8885fddc0d74d5c2564dd/models/image_segmentation/tensorflow/3d_unet/inference/fp32/unet3d/training.py#L45
[10]: https://huggingface.co/blog/annotated-diffusion

---

EOD