Skip to content

Commit 9284ad1

Browse files
authored
fix: set lower length_ms as default (#66)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
1 parent bc58957 commit 9284ad1

File tree

8 files changed

+133
-148
lines changed

8 files changed

+133
-148
lines changed

BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ pybind_extension(
128128
deps = [
129129
":audio_lib",
130130
":context_lib",
131+
"@com_github_ggerganov_whisper//:common",
131132
],
132133
)
133134

examples/stream/stream.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,23 @@
11
"""Some streaming examples."""
22

3+
import os
34
import sys
45
import typing as t
56

67
import whispercpp as w
78

89

910
def main(**kwargs: t.Any):
11+
kwargs.pop("list_audio_devices")
12+
mname = kwargs.pop("model_name", os.getenv("GGML_MODEL", "tiny.en"))
1013
iterator: t.Iterator[str] | None = None
1114
try:
12-
iterator = w.Whisper.from_pretrained(kwargs["model_name"]).stream_transcribe(
13-
length_ms=kwargs["length_ms"],
14-
device_id=kwargs["device_id"],
15-
sample_rate=kwargs["sample_rate"],
16-
step_ms=kwargs["step_ms"],
17-
n_threads=kwargs["n_threads"],
18-
no_timestamp=True,
19-
)
15+
iterator = w.Whisper.from_pretrained(mname).stream_transcribe(**kwargs)
2016
finally:
2117
assert iterator is not None, "Something went wrong!"
22-
sys.stderr.writelines(f"- {it}\n" for it in iterator)
23-
sys.stderr.write("Transcriptions:\n")
18+
sys.stderr.writelines(
19+
["\nTranscription (line by line):\n"] + [f"{it}\n" for it in iterator]
20+
)
2421
sys.stderr.flush()
2522

2623

@@ -46,18 +43,31 @@ def main(**kwargs: t.Any):
4643
help="Sample rate of the audio device",
4744
default=w.api.SAMPLE_RATE,
4845
)
46+
parser.add_argument(
47+
"--n_threads",
48+
type=int,
49+
help="Number of threads to use for decoding",
50+
default=8,
51+
)
4952
parser.add_argument(
5053
"--step_ms",
5154
type=int,
5255
help="Step size of the audio buffer in milliseconds",
53-
default=500,
56+
default=2000,
5457
)
5558
parser.add_argument(
56-
"--n_threads",
59+
"--keep_ms",
5760
type=int,
58-
help="Number of threads to use for decoding",
59-
default=4,
61+
help="Length of the audio buffer to keep in milliseconds",
62+
default=200,
63+
)
64+
parser.add_argument(
65+
"--max_tokens",
66+
type=int,
67+
help="Maximum number of tokens to decode",
68+
default=32,
6069
)
70+
parser.add_argument("--audio_ctx", type=int, help="Audio context", default=0)
6171
parser.add_argument(
6272
"--list_audio_devices",
6373
action="store_true",

rules/deps.bzl

+3-3
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ def internal_deps():
8181
http_archive(
8282
name = "com_github_libsdl_sdl2",
8383
build_file = Label("//extern:sdl2.BUILD"),
84-
sha256 = "03ab539ff65f6f544969eb3fed138a3fd7224496aa8404eda5e8355877b6dca1",
85-
strip_prefix = "SDL-6c495a92f0bbc5637d565b5339afa943a78108f7",
86-
urls = ["https://github.com/libsdl-org/SDL/archive/6c495a92f0bbc5637d565b5339afa943a78108f7.zip"],
84+
sha256 = "e2ac043bd2b67be328f875043617b904a0bb7d277ba239fe8ac6b9c94b85cbac",
85+
strip_prefix = "SDL-dca3fd8307c2c9ebda8d8ea623bbbf19649f5e22",
86+
urls = ["https://github.com/libsdl-org/SDL/archive/dca3fd8307c2c9ebda8d8ea623bbbf19649f5e22.zip"],
8787
)
8888

8989
git_repository(

src/whispercpp/__init__.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
7373
params: api.Params
7474
no_state: bool
7575
basedir: str | None
76+
_transcript: list[str]
7677

7778
_context_initialized: bool = False
7879

@@ -112,6 +113,7 @@ def from_pretrained(
112113
)
113114
context.reset_timings()
114115
_context_initialized = not no_state
116+
_transcript = []
115117
_ref.__dict__.update(locals())
116118
return _ref
117119

@@ -149,6 +151,7 @@ def from_params(
149151
)
150152
context.reset_timings()
151153
_context_initialized = not no_state
154+
_transcript = []
152155
_ref.__dict__.update(locals())
153156
return _ref
154157

@@ -214,9 +217,8 @@ def stream_transcribe(
214217
device_id: int = 0,
215218
sample_rate: int | None = None,
216219
**kwargs: t.Any,
217-
) -> t.Generator[str, None, list[str]]:
218-
"""
219-
Streaming transcription from microphone. Note that this function is blocking.
220+
) -> list[str]:
221+
"""Streaming transcription from microphone. Note that this function is blocking.
220222
221223
Args:
222224
length_ms (int, optional): Length of audio to transcribe in milliseconds. Defaults to 5000.
@@ -227,30 +229,34 @@ def stream_transcribe(
227229
Returns:
228230
A list of all transcribed text from the given audio device.
229231
"""
230-
is_running = True
231-
232232
if sample_rate is None:
233233
sample_rate = api.SAMPLE_RATE
234-
length_ms = kwargs.pop("length_ms", 10000)
234+
if "length_ms" not in kwargs:
235+
kwargs["length_ms"] = 5000
236+
if "step_ms" not in kwargs:
237+
kwargs["step_ms"] = 700
238+
239+
if kwargs["step_ms"] < 500:
240+
raise ValueError("step_ms must be >= 500")
235241

236-
ac = audio.AudioCapture(length_ms)
242+
ac = audio.AudioCapture(kwargs["length_ms"])
237243
if not ac.init_device(device_id, sample_rate):
238244
raise RuntimeError("Failed to initialize audio capture device.")
239245

246+
self.params.on_new_segment(self._store_transcript_handler, self._transcript)
247+
240248
try:
241-
while is_running:
242-
is_running = audio.sdl_poll_events()
243-
if not is_running:
244-
break
245-
ac.stream_transcribe(
246-
self.context, self.params, length_ms=length_ms, **kwargs
247-
)
249+
ac.stream_transcribe(self.context, self.params, **kwargs)
248250
except KeyboardInterrupt:
249251
# handled from C++
250252
pass
251-
finally:
252-
yield from ac.transcript
253-
return ac.transcript
253+
return self._transcript
254+
255+
def _store_transcript_handler(self, ctx: api.Context, n_new: int, data: list[str]):
256+
segment = ctx.full_n_segments() - n_new
257+
while segment < ctx.full_n_segments():
258+
data.append(ctx.full_get_segment_text(segment))
259+
segment += 1
254260

255261

256262
__all__ = ["Whisper", "api", "utils", "audio"]

src/whispercpp/__init__.pyi

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Iterator
34
from typing import overload
4-
from typing import Generator
55
from typing import TYPE_CHECKING
66

77
from . import api as api
@@ -31,7 +31,7 @@ class Whisper:
3131
self, filename: str, num_proc: int = ..., strict: bool = ...
3232
) -> str: ...
3333
@overload
34-
def stream_transcribe(self) -> Generator[str, None, list[str]]: ...
34+
def stream_transcribe(self) -> Iterator[str]: ...
3535
@overload
3636
def stream_transcribe(
3737
self,
@@ -40,7 +40,7 @@ class Whisper:
4040
device_id: int = ...,
4141
sample_rate: int | None = ...,
4242
step_ms: int = ...,
43-
) -> Generator[str, None, list[str]]: ...
43+
) -> Iterator[str]: ...
4444
@classmethod
4545
@overload
4646
def from_pretrained(cls, model_name: str) -> Whisper: ...

0 commit comments

Comments
 (0)