Skip to content

Commit d6147c2

Browse files
mitruska and mmikolajcz authored
[Op][PT FE] Enable ISTFT for Pytorch Frontend (#28743)
### Details:
- Enable ISTFT for Pytorch Frontend
- Adjust shape_infer for the case with odd value for frame_size (round down to even)

### Tickets:
- 159383

---------

Co-authored-by: Mateusz Mikolajczyk <mateusz.mikolajczyk@intel.com>
1 parent 1a6536a commit d6147c2

File tree

6 files changed

+368
-6
lines changed

6 files changed

+368
-6
lines changed

src/core/reference/src/op/istft.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ void istft(const float* in_data,
3333
const auto num_frames = data_shape[frames_axis];
3434

3535
const auto signal_length = (num_frames - 1) * frame_step + frame_size;
36-
const int64_t final_signal_length = length > 0 ? length : (center ? (signal_length - frame_size) : signal_length);
36+
const int64_t final_signal_length =
37+
length > 0 ? length : (center ? (signal_length - (frame_size & ~1)) : signal_length);
3738
std::fill(final_result, final_result + batch_size * final_signal_length, 0.f);
3839

3940
std::vector<float> mid_result(batch_size * signal_length, 0.f);

src/core/shape_inference/include/istft_shape_inference.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ std::vector<TRShape> shape_infer(const ISTFT* op,
111111

112112
const int64_t frames_axis = 1 + (is_data_3D ? 0 : 1);
113113
const TDim& num_frames_dim = data_shape[frames_axis];
114-
TDim signal_length = (num_frames_dim - 1) * frame_step_val;
115-
if (!op->get_center()) {
116-
signal_length += frame_size_val;
114+
TDim signal_length = (num_frames_dim - 1) * frame_step_val + frame_size_val;
115+
if (op->get_center()) {
116+
signal_length = signal_length - (frame_size_val & ~1);
117117
}
118118
output_shapes[0][0] = std::move(signal_length);
119119
}

src/core/tests/type_prop/istft.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ INSTANTIATE_TEST_SUITE_P(
135135
std::make_tuple(PartialShape{1, 48}, PartialShape{16}, 16, 16, false, PartialShape{1, 9, 3, 2}),
136136
std::make_tuple(PartialShape{2, 48}, PartialShape{8}, 16, 4, false, PartialShape{2, 9, 9, 2}),
137137
std::make_tuple(PartialShape{2, 9}, PartialShape{5}, 9, 100, false, PartialShape{2, 5, 1, 2}),
138-
std::make_tuple(PartialShape{2, 0}, PartialShape{5}, 9, 100, true, PartialShape{2, 5, 1, 2}),
138+
std::make_tuple(PartialShape{2, 1}, PartialShape{5}, 9, 100, true, PartialShape{2, 5, 1, 2}),
139139
std::make_tuple(PartialShape{4, 47},
140140
PartialShape{7},
141141
11,
@@ -151,7 +151,7 @@ INSTANTIATE_TEST_SUITE_P(
151151
3,
152152
false,
153153
PartialShape{{2, 4}, 6, {1, -1}, 2}),
154-
std::make_tuple(PartialShape{{2, 4}, {-1, -1}},
154+
std::make_tuple(PartialShape{{2, 4}, {1, -1}},
155155
PartialShape{7},
156156
11,
157157
3,
+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/op/istft.hpp"
6+
7+
#include "openvino/frontend/complex_type_mark.hpp"
8+
#include "openvino/frontend/pytorch/node_context.hpp"
9+
#include "openvino/op/broadcast.hpp"
10+
#include "openvino/op/constant.hpp"
11+
#include "openvino/op/convert_like.hpp"
12+
#include "openvino/op/divide.hpp"
13+
#include "openvino/op/unsqueeze.hpp"
14+
#include "utils.hpp"
15+
16+
namespace ov {
17+
namespace frontend {
18+
namespace pytorch {
19+
namespace op {
20+
21+
using namespace ov::op;
22+
23+
/// @brief Translates `aten::istft` into an OpenVINO `v16::ISTFT` node.
///
/// Schema:
///   aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None,
///               bool center=True, bool normalized=False, bool? onesided=None, int? length=None,
///               bool return_complex=False)
///
/// Optional inputs that are None are replaced as follows:
///   - hop_length -> n_fft / 4
///   - win_length -> n_fft
///   - window     -> a 1D tensor of ones with win_length elements, in the element type of the input
///   - center     -> true, normalized -> false, onesided -> true, return_complex -> false
///
/// @param context Node context giving access to the inputs of the PyTorch operation.
/// @return A single output: the result of the ISTFT node.
/// @throws A conversion failure (via PYTORCH_OP_CONVERSION_CHECK) when onesided is false or
///         return_complex is true.
OutputVector translate_istft(const NodeContext& context) {
    // NOTE(review): the last argument presumably permits complex-marked inputs - confirm against utils.hpp.
    num_inputs_check(context, 2, 10, true);

    // If the input arrives wrapped in a ComplexTypeMark, use the tensor underneath the mark.
    auto input = context.get_input(0);
    if (const auto complex_type_mark = as_type_ptr<ComplexTypeMark>(input.get_node_shared_ptr())) {
        input = complex_type_mark->input_value(0);
    }

    const auto n_fft = context.get_input(1);

    ov::Output<ov::Node> hop_length;
    if (!context.input_is_none(2)) {
        hop_length = context.get_input(2);
    } else {
        // Default: floor(n_fft / 4). The constant is converted to the type of n_fft before dividing.
        const auto four = context.mark_node(std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, 4));
        const auto four_cast = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(four, n_fft));
        hop_length = context.mark_node(std::make_shared<ov::op::v1::Divide>(n_fft, four_cast));
    }

    // win_length is only needed to build the default window; it defaults to n_fft.
    const ov::Output<ov::Node> win_length = context.input_is_none(3) ? n_fft : context.get_input(3);

    ov::Output<ov::Node> window;
    if (!context.input_is_none(4)) {
        window = context.get_input(4);
    } else {
        // Default window: scalar 1 (in the input's element type) broadcast to the 1D shape [win_length].
        const auto one = context.mark_node(std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{}, 1));
        const auto one_cast = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(one, input));
        const auto zero = context.mark_node(std::make_shared<ov::op::v0::Constant>(ov::element::i32, Shape{1}, 0));
        const auto win_length_cast =
            context.mark_node(std::make_shared<ov::op::v0::Convert>(win_length, ov::element::i64));
        const auto win_len_vec = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(win_length_cast, zero));
        window = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(one_cast, win_len_vec));
    }

    const bool center = context.input_is_none(5) ? true : context.const_input<bool>(5);
    const bool normalized = context.input_is_none(6) ? false : context.const_input<bool>(6);

    const bool onesided = context.input_is_none(7) ? true : context.const_input<bool>(7);
    PYTORCH_OP_CONVERSION_CHECK(onesided, "aten::istft conversion is currently supported with onesided=True only.");

    // The ISTFT node built below produces a real-valued signal, and PyTorch documents return_complex=True
    // as incompatible with onesided=True (the only mode accepted above). Marking the real output as complex
    // would mislabel the tensor for downstream translators, so reject the request explicitly instead.
    const bool return_complex = context.input_is_none(9) ? false : context.const_input<bool>(9);
    PYTORCH_OP_CONVERSION_CHECK(!return_complex,
                                "aten::istft conversion is currently supported with return_complex=False only.");

    // Perform ISTFT; the optional `length` input selects the constructor that takes an explicit signal length.
    ov::Output<ov::Node> istft;
    if (context.input_is_none(8)) {
        istft = context.mark_node(std::make_shared<v16::ISTFT>(input, window, n_fft, hop_length, center, normalized));
    } else {
        const auto signal_length = context.get_input(8);
        istft = context.mark_node(
            std::make_shared<v16::ISTFT>(input, window, n_fft, hop_length, signal_length, center, normalized));
    }

    return {istft};
}
103+
} // namespace op
104+
} // namespace pytorch
105+
} // namespace frontend
106+
} // namespace ov

src/frontends/pytorch/src/op_table.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ OP_CONVERTER(translate_index_select);
125125
OP_CONVERTER(translate_instance_norm);
126126
OP_CONVERTER(translate_int);
127127
OP_CONVERTER(translate_inverse);
128+
OP_CONVERTER(translate_istft);
128129
OP_CONVERTER(translate_is_nonzero);
129130
OP_CONVERTER(translate_layer_norm);
130131
OP_CONVERTER(translate_len);
@@ -523,6 +524,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
523524
{"aten::Int", op::translate_int},
524525
{"aten::IntImplicit", op::translate_int},
525526
{"aten::is_grad_enabled", op::return_false_scalar},
527+
{"aten::istft", op::translate_istft},
526528
{"aten::is_nonzero", op::translate_is_nonzero},
527529
{"aten::isfinite", op::translate_1to1_match_1_inputs<opset10::IsFinite>},
528530
{"aten::isinf", op::translate_1to1_match_1_inputs<opset10::IsInf>},

0 commit comments

Comments
 (0)