-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathtest_stft.py
152 lines (125 loc) · 6.52 KB
/
test_stft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestSTFT(PytorchLayerTest):
    """Layer tests for ``aten::stft`` where the window is passed as a model input."""

    def _prepare_input(self, win_length, signal_shape, rand_data=False, out_dtype="float32"):
        """Create the ``(signal, window)`` input pair for the STFT model.

        Args:
            win_length: Number of samples in the Hann window.
            signal_shape: Shape of the generated signal; its last entry is the
                number of samples.
            rand_data: If true, draw the signal from a standard normal
                distribution; otherwise build a deterministic two-tone signal.
            out_dtype: NumPy dtype that both returned arrays are cast to.

        Returns:
            Tuple ``(signal, window)``.
        """
        import numpy as np

        if not rand_data:
            n_samples = signal_shape[-1]
            midpoint = n_samples // 2
            time = np.linspace(0, 1, n_samples)
            # Five sine cycles over the whole span, with a ten-cycle sine added
            # to the second half only, so the two halves differ in frequency
            # content.
            tone = np.sin(2 * np.pi * 5 * time)
            tone[midpoint:] += np.sin(2 * np.pi * 10 * time[midpoint:])
            # The same 1-D tone is replicated across any leading dimensions.
            signal = np.broadcast_to(tone, signal_shape).astype(out_dtype)
        else:
            signal = np.random.randn(*signal_shape).astype(out_dtype)

        hann = np.hanning(win_length)
        return signal, hann.astype(out_dtype)

    def create_model(self, n_fft, hop_length, win_length, normalized):
        """Return ``(model, None, "aten::stft")`` for a non-centered, one-sided STFT.

        The model's ``forward`` takes ``(x, window)`` and returns the STFT with
        ``return_complex=False``.
        """
        import torch

        class aten_stft(torch.nn.Module):
            def __init__(self, n_fft, hop_length, win_length, normalized):
                super().__init__()
                self.n_fft, self.hop_length = n_fft, hop_length
                self.win_length, self.normalized = win_length, normalized

            def forward(self, x, window):
                # center=False, so pad_mode does not affect the result here.
                return torch.stft(
                    x,
                    self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window=window,
                    center=False,
                    pad_mode="reflect",
                    normalized=self.normalized,
                    onesided=True,
                    return_complex=False,
                )

        return aten_stft(n_fft, hop_length, win_length, normalized), None, "aten::stft"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("trace_model", [True, False])
    @pytest.mark.parametrize("signal_shape", [(1, 256), (2, 128), (128,)])
    @pytest.mark.parametrize(("n_fft", "hop_length", "window_size"), [
        (16, 4, 16),
        (32, 32, 32),
        (32, 16, 24),
        (24, 32, 20),
        (128, 128, 128),
    ])
    @pytest.mark.parametrize("normalized", [True, False])
    def test_stft(self, n_fft, hop_length, window_size, signal_shape, normalized, ie_device, precision, ir_version, trace_model):
        """Compare OpenVINO's converted STFT against PyTorch for each parameter set."""
        if ie_device == "GPU":
            pytest.xfail(reason="STFT op is not supported on GPU yet")
        model_and_op = self.create_model(n_fft, hop_length, window_size, normalized)
        input_kwargs = {"win_length": window_size, "signal_shape": signal_shape}
        self._test(*model_and_op, ie_device, precision, ir_version,
                   kwargs_to_prepare_input=input_kwargs, trace_model=trace_model)
class TestSTFTAttrs(PytorchLayerTest):
    """Layer tests for ``aten::stft`` attribute combinations with no window input."""

    def _prepare_input(self, out=False, out_dtype="float32"):
        """Return a one-element tuple holding a random ``(2, 512)`` signal.

        Args:
            out: Accepted for signature compatibility; not used.
            out_dtype: NumPy dtype the signal is cast to.
        """
        import numpy as np
        signal = np.random.randn(2, 512).astype(out_dtype)
        return (signal,)

    def create_model_with_attrs(self, n_fft, hop_length, win_length, center, pad_mode, normalized, onesided, return_complex):
        """Build an STFT module whose ``torch.stft`` arguments are all configurable.

        Returns:
            ``(model, ref_net, op_name)`` with ``ref_net`` set to ``None`` and
            ``op_name`` set to ``"aten::stft"``.
        """
        import torch

        class aten_stft_attrs(torch.nn.Module):
            def __init__(self, n_fft, hop_length, win_length, center, pad_mode, normalized, onesided, return_complex):
                super(aten_stft_attrs, self).__init__()
                self.n_fft = n_fft
                self.hop_length = hop_length
                self.win_length = win_length
                self.window = None  # window=None lets torch.stft use its default window
                self.center = center
                self.pad_mode = pad_mode
                self.normalized = normalized
                self.onesided = onesided
                self.return_complex = return_complex

            def forward(self, x):
                stft = torch.stft(
                    x,
                    self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center,
                    pad_mode=self.pad_mode,
                    normalized=self.normalized,
                    onesided=self.onesided,
                    return_complex=self.return_complex,
                )
                # A complex result is returned through view_as_real so the
                # model output is a real tensor in both modes.
                if self.return_complex:
                    return torch.view_as_real(stft)
                else:
                    return stft

        ref_net = None
        return aten_stft_attrs(n_fft, hop_length, win_length, center, pad_mode, normalized, onesided, return_complex), ref_net, "aten::stft"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize(("trace_model"), [True, False])
    @pytest.mark.parametrize(("n_fft", "hop_length", "win_length", "center", "pad_mode", "normalized", "onesided", "return_complex"), [
        [16, 4, 16, False, "reflect", False, True, False],  # default window
        [16, 4, 14, True, "reflect", False, True, False],  # center True, reflect padding
        # Previously an exact duplicate of the case above; covers constant padding instead.
        [16, 4, 14, True, "constant", False, True, False],  # center True, constant padding
        [16, 4, 14, True, "replicate", False, True, False],  # center True, replicate padding
        [16, 4, 14, False, "replicate", False, True, False],  # center False
        [16, None, 16, False, "reflect", False, True, False],  # hop_length None
        [16, None, None, False, "reflect", False, True, False],  # hop & win length None
        [16, 4, None, False, "reflect", False, True, False],  # win_length None
        [16, 4, 16, False, "reflect", True, True, False],  # normalized True
        [16, 4, 16, False, "reflect", False, True, True],  # return_complex True
        # Unsupported cases:
        [16, 4, 16, False, "reflect", False, False, False],  # onesided False
    ])
    def test_stft_not_supported_attrs(self, n_fft, hop_length, win_length, center, pad_mode, normalized, onesided, return_complex, ie_device, precision, ir_version, trace_model):
        """Check ``aten::stft`` conversion across attribute combinations.

        Known-unsupported combinations are marked xfail instead of being removed:
        GPU execution, scripted (non-traced) models with ``center=True``, and
        ``onesided=False``.
        """
        if ie_device == "GPU":
            pytest.xfail(reason="STFT op is not supported on GPU yet")
        if center and not trace_model:
            pytest.xfail(
                reason="torch stft uses list() for `center` subgraph before aten::stft, that leads to error: No conversion rule found for operations: aten::list")
        if not onesided:
            pytest.xfail(
                reason="aten::stft conversion is currently supported with onesided=True only")
        self._test(*self.create_model_with_attrs(n_fft, hop_length, win_length, center, pad_mode, normalized, onesided, return_complex), ie_device, precision,
                   ir_version, kwargs_to_prepare_input={}, trace_model=trace_model)