
Run Reverb using Huggingface Pipelines (Missing Model Files) #33

Open. Wants to merge 1 commit into main.
115 changes: 115 additions & 0 deletions examples/huggingface/reverb_config.py
@@ -0,0 +1,115 @@
# Following https://huggingface.co/docs/transformers/en/custom_models
import math
from typing import Dict, List, Optional
from transformers import PretrainedConfig
import numpy as np


def cmvn(means: List[float], variance: List[float], count: int):
""" Calculate cmvn from stats

Returns:
a numpy array of [means, vars]
"""
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
# cmvn = np.array([means, variance])
return [means, variance]


class ReverbConfig(PretrainedConfig):
Contributor comment:
not sure if this is 100% possible, but could we include .json/.yaml (or any other format) with model files and load the whole configuration of this class?

    # Not sure what to put but also not required: model_type = "encoderdecoder"
model_type = "reverb_asr"
def __init__(
self,
input_dim: int = 80,
output_dim: int = 10001,
cmvn_mean_stat : List[float] = [33596438528.0, 35418329088.0, 39182106624.0, 41983324160.0, 44419112960.0, 46015381504.0, 46934564864.0, 47058870272.0, 47288012800.0, 47522979840.0, 48491438080.0, 49308729344.0, 50230493184.0, 50796900352.0, 51020386304.0, 51297456128.0, 51333586944.0, 51126181888.0, 51455569920.0, 50636410880.0, 49947033600.0, 50365546496.0, 49383075840.0, 49540546560.0, 49066065920.0, 49236889600.0, 48820707328.0, 49071112192.0, 48968024064.0, 49024458752.0, 49202397184.0, 49374433280.0, 49620660224.0, 49947111424.0, 50326310912.0, 50717818880.0, 51046891520.0, 51345678336.0, 51655733248.0, 51505459200.0, 51813666816.0, 51577262080.0, 51776524288.0, 51754237952.0, 51918598144.0, 52158758912.0, 52405276672.0, 52596776960.0, 52639731712.0, 52631220224.0, 52443103232.0, 52315619328.0, 52219695104.0, 52178399232.0, 52083040256.0, 52064792576.0, 51980918784.0, 51824164864.0, 51550973952.0, 51002216448.0, 50422747136.0, 49847754752.0, 49474338816.0, 48997863424.0, 48617009152.0, 48309174272.0, 48084140032.0, 48095608832.0, 47965765632.0, 47909335040.0, 47780065280.0, 47762370560.0, 47757099008.0, 47731314688.0, 47574110208.0, 47336361984.0, 47009054720.0, 46283513856.0, 44821860352.0, 42771775488.0],
cmvn_var_stat: List[float] = [360475131904.0, 401487724544.0, 484368646144.0, 548414357504.0, 608912080896.0, 651613241344.0, 678013698048.0, 683624693760.0, 689524047872.0, 695375822848.0, 722376851456.0, 746773872640.0, 774244204544.0, 791678353408.0, 798920015872.0, 807307444224.0, 808713453568.0, 802957754368.0, 812319899648.0, 788076953600.0, 767619497984.0, 777970712576.0, 748566544384.0, 751065628672.0, 736340869120.0, 739872473088.0, 727466704896.0, 734006083584.0, 731017904128.0, 732582576128.0, 737590444032.0, 742469861376.0, 749455671296.0, 758746972160.0, 769666121728.0, 781107331072.0, 790730506240.0, 799342002176.0, 808164917248.0, 803454713856.0, 812040585216.0, 804632395776.0, 809866821632.0, 808861499392.0, 813548044288.0, 820701954048.0, 828343779328.0, 834335604736.0, 835754590208.0, 835251011584.0, 829192929280.0, 824705744896.0, 821224734720.0, 819399753728.0, 816182853632.0, 815243788288.0, 812578177024.0, 807846281216.0, 799796035584.0, 784661544960.0, 770915631104.0, 756696285184.0, 746462183424.0, 734193254400.0, 724980072448.0, 717529612288.0, 711156563968.0, 710358204416.0, 706386919424.0, 704228884480.0, 700537110528.0, 699519008768.0, 699025129472.0, 698035535872.0, 693109391360.0, 686047887360.0, 676213948416.0, 655917645824.0, 616676458496.0, 563932168192.0],
cmvn_frame_num: int = 3519342927,
encoder: str = "conformer",
encoder_activation_type: str = "swish",
encoder_attention_dropout_rate: float = 0.1,
encoder_attention_heads: int = 8,
encoder_causal: bool = True,
encoder_cnn_module_kernel: int = 31,
encoder_cnn_module_norm: str = "layer_norm",
encoder_dropout_rate: float = 0.1,
encoder_input_layer: str = "conv2d",
encoder_linear_units: int = 2048,
encoder_normalize_before: bool = True,
encoder_num_blocks: int = 18,
encoder_num_langs: int = 2,
encoder_output_size: int = 640,
encoder_pos_enc_layer_type: str = "rel_pos",
encoder_positional_dropout_rate: float = 0.1,
encoder_selfattention_layer_type: str = "rel_selfattn",
encoder_use_cnn_module: bool = True,
encoder_use_dynamic_chunk: bool = True,
decoder: str = "lslbitransformer",
decoder_attention_heads: int = 8,
decoder_dropout_rate: float = 0.1,
decoder_linear_units: int = 2048,
decoder_num_blocks: int = 6,
decoder_num_langs: int = 2,
decoder_positional_dropout_rate: float = 0.1,
decoder_r_num_blocks: int = 6,
decoder_self_attention_dropout_rate: float = 0.1,
decoder_src_attention_dropout_rate: float = 0.1,
ctc_blank_id: int = 0,
ctc_weight: float = 0.3,
lsm_weight: float = 0.1,
reverse_weight: float = 0.3,
special_tokens: Optional[Dict[str, int]] = None,
**kwargs,
):
self.input_dim = input_dim
self.output_dim = output_dim
self.encoder = encoder
self.encoder_activation_type = encoder_activation_type
self.encoder_attention_dropout_rate = encoder_attention_dropout_rate
self.encoder_attention_heads = encoder_attention_heads
self.encoder_causal = encoder_causal
self.encoder_cnn_module_kernel = encoder_cnn_module_kernel
self.encoder_cnn_module_norm = encoder_cnn_module_norm
self.encoder_dropout_rate = encoder_dropout_rate
self.encoder_input_layer = encoder_input_layer
self.encoder_linear_units = encoder_linear_units
self.encoder_normalize_before = encoder_normalize_before
self.encoder_num_blocks = encoder_num_blocks
self.encoder_num_langs = encoder_num_langs
self.encoder_output_size = encoder_output_size
self.encoder_pos_enc_layer_type = encoder_pos_enc_layer_type
self.encoder_positional_dropout_rate = encoder_positional_dropout_rate
self.encoder_selfattention_layer_type = encoder_selfattention_layer_type
self.encoder_use_cnn_module = encoder_use_cnn_module
self.encoder_use_dynamic_chunk = encoder_use_dynamic_chunk
self.decoder = decoder
self.decoder_attention_heads = decoder_attention_heads
self.decoder_dropout_rate = decoder_dropout_rate
self.decoder_linear_units = decoder_linear_units
self.decoder_num_blocks = decoder_num_blocks
self.decoder_num_langs = decoder_num_langs
self.decoder_positional_dropout_rate = decoder_positional_dropout_rate
self.decoder_r_num_blocks = decoder_r_num_blocks
self.decoder_self_attention_dropout_rate = decoder_self_attention_dropout_rate
self.decoder_src_attention_dropout_rate = decoder_src_attention_dropout_rate
self.ctc_blank_id = ctc_blank_id
self.ctc_weight = ctc_weight
self.lsm_weight = lsm_weight
self.reverse_weight = reverse_weight
if special_tokens is None:
special_tokens = {
"<blank>": 0,
"<sos>": 2,
"<eos>": 2,
"<unk>": 1,
}
self.special_tokens = special_tokens
self.cmvn_mean, self.cmvn_istd = cmvn(cmvn_mean_stat, cmvn_var_stat, cmvn_frame_num)
self.inputs_to_logits_ratio = 1
super().__init__(**kwargs)
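
A note on the contributor question above about shipping the configuration as a file: because ReverbConfig subclasses PretrainedConfig, the whole configuration (including the CMVN statistics, which are plain lists of floats) can already be round-tripped through a config.json stored next to the model files. A minimal sketch of that flow, assuming the files live in a local hf-reverb/ directory:

# Sketch only: persist and reload the full ReverbConfig as JSON alongside the
# model files. The "hf-reverb" directory name matches transcribe.py below.
from reverb_config import ReverbConfig

config = ReverbConfig()              # defaults, including the CMVN stats
config.save_pretrained("hf-reverb")  # writes hf-reverb/config.json

# Rebuild the identical config from disk later (or in another process):
reloaded = ReverbConfig.from_pretrained("hf-reverb")
assert reloaded.encoder_num_blocks == config.encoder_num_blocks

# A bare JSON file also works, without the directory convention:
# reloaded = ReverbConfig.from_json_file("hf-reverb/config.json")

YAML would need a small custom loader on top of this, since PretrainedConfig only serializes to JSON out of the box.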
98 changes: 98 additions & 0 deletions examples/huggingface/reverb_hf.py
@@ -0,0 +1,98 @@
# Following https://huggingface.co/docs/transformers/en/custom_models

from typing import List, Optional, Tuple, Union
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import LanguageSpecificBiTransformerDecoder
from wenet.transformer.encoder import ConformerEncoder
from reverb_config import ReverbConfig

class ReverbModel(PreTrainedModel):
config_class = ReverbConfig
main_input_name = "input_features"

def __init__(self, config):
super().__init__(config)
self.config = config
global_cmvn = GlobalCMVN(
torch.Tensor(config.cmvn_mean),
torch.Tensor(config.cmvn_istd),
)
encoder = ConformerEncoder(
config.input_dim,
global_cmvn=global_cmvn,
activation_type=config.encoder_activation_type,
attention_dropout_rate=config.encoder_attention_dropout_rate,
attention_heads=config.encoder_attention_heads,
causal=config.encoder_causal,
cnn_module_kernel=config.encoder_cnn_module_kernel,
cnn_module_norm=config.encoder_cnn_module_norm,
dropout_rate=config.encoder_dropout_rate,
input_layer=config.encoder_input_layer,
linear_units=config.encoder_linear_units,
normalize_before=config.encoder_normalize_before,
num_blocks=config.encoder_num_blocks,
num_langs=config.encoder_num_langs,
output_size=config.encoder_output_size,
pos_enc_layer_type=config.encoder_pos_enc_layer_type,
positional_dropout_rate=config.encoder_positional_dropout_rate,
selfattention_layer_type=config.encoder_selfattention_layer_type,
use_cnn_module=config.encoder_use_cnn_module,
use_dynamic_chunk=config.encoder_use_dynamic_chunk,
)

decoder = LanguageSpecificBiTransformerDecoder(
config.output_dim,
config.encoder_output_size,
attention_heads=config.decoder_attention_heads,
dropout_rate=config.decoder_dropout_rate,
linear_units=config.decoder_linear_units,
num_blocks=config.decoder_num_blocks,
num_langs=config.decoder_num_langs,
positional_dropout_rate=config.decoder_positional_dropout_rate,
r_num_blocks=config.decoder_r_num_blocks,
self_attention_dropout_rate=config.decoder_self_attention_dropout_rate,
src_attention_dropout_rate=config.decoder_src_attention_dropout_rate,
)

ctc = CTC(
config.output_dim,
config.encoder_output_size,
config.ctc_blank_id,
)

self.model = ASRModel(
vocab_size=config.output_dim,
encoder=encoder,
decoder=decoder,
ctc=ctc,
special_tokens=config.special_tokens,
ctc_weight=config.ctc_weight,
lsm_weight=config.lsm_weight,
reverse_weight=config.reverse_weight,
)
self.model.lsl_enc = True
self.model.lsl_dec = True

def forward(
self,
input_features=None,
feats_lengths=None,
labels=None,
labels_lengths=None,
**kwargs,
):
output = self.model.hf_forward(
input_features,
feats_lengths=feats_lengths,
labels=labels,
labels_lengths=labels_lengths,
)
return Seq2SeqLMOutput(
logits=output['ctc_probs'],
loss=output['loss'],
)
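
The "Missing Model Files" in the title ultimately means the hf-reverb directory needs a config.json plus weights that from_pretrained can locate. A hedged sketch of one way to produce them, assuming a WeNet-style Reverb checkpoint (reverb_asr_v1.pt is a hypothetical filename) whose state-dict keys match the wrapped ASRModel:

# Sketch only: export config + weights into hf-reverb/ so the pipeline call
# in transcribe.py can load them. The checkpoint path is hypothetical.
import torch

from reverb_config import ReverbConfig
from reverb_hf import ReverbModel

config = ReverbConfig()
model = ReverbModel(config)

# Load an existing WeNet/Reverb checkpoint into the wrapped ASRModel.
# strict=False surfaces any key mismatches instead of failing outright.
state_dict = torch.load("reverb_asr_v1.pt", map_location="cpu")
incompatible = model.model.load_state_dict(state_dict, strict=False)
print("missing:", incompatible.missing_keys)
print("unexpected:", incompatible.unexpected_keys)

# Writes hf-reverb/config.json and the model weights into hf-reverb/.
model.save_pretrained("hf-reverb")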
132 changes: 132 additions & 0 deletions examples/huggingface/reverb_processor.py
@@ -0,0 +1,132 @@
import json
from typing import List, Optional, Union
import numpy as np
import sentencepiece as spm
import torch
import torchaudio
from torchaudio.compliance import kaldi
from tqdm import tqdm
from transformers import BatchFeature, PreTrainedTokenizer, ProcessorMixin, SequenceFeatureExtractor
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ReverbFeatureExtractor(SequenceFeatureExtractor):
model_input_names = ["input_features"]
def __init__(
self,
feature_size=80,
sampling_rate=16000,
frame_length=25,
frame_shift=10,
chunk_length=15,
padding_value=0.0,
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=False,
**kwargs,
)
self.frame_length = frame_length
self.frame_shift = frame_shift
self.chunk_length = chunk_length
self.max_chunk_size = 2051
self._processor_class = "CTCWithLM"

def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
device: Optional[str] = "cpu",
sampling_rate: Optional[int] = None,
**kwargs,
) -> BatchFeature:
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
                )
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)

is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
if is_batched_numpy and len(raw_speech.shape) > 2:
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
is_batched = is_batched_numpy or (
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
)

if is_batched:
raw_speech = [np.asarray([speech], dtype=np.float32) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float32)

if not is_batched:
raw_speech = [np.asarray([raw_speech])]

fbank_speech, feats_lengths = [], []
for waveform in raw_speech:
fbank_speech.append(
kaldi.fbank(
torch.tensor(waveform),
num_mel_bins=self.feature_size,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=0.0,
energy_floor=0.0,
sample_frequency=self.sampling_rate,
)
)
feats_lengths.append(fbank_speech[-1].shape[0])
fbank_speech = BatchFeature({
"input_features": fbank_speech,
"feats_lengths": feats_lengths,
})
padded = self.pad(
fbank_speech,
padding="max_length",
max_length=self.max_chunk_size,
)
return padded


class ReverbTokenizer(PreTrainedTokenizer):
def __init__(
self,
model: str,
#units: str,
**kwargs,
):
self.tokenizer = spm.SentencePieceProcessor(model)
"""self.units = dict()
with open(units, 'r') as units_file:
for line in tqdm(units_file.readlines()):
token, id = line.split()
self.units[int(id)] = token.replace('▁', ' ')"""


def encode(
self,
text,
**kwargs
):
return self.tokenizer.encode(text)

def decode(
self,
token_ids,
**kwargs,
):
return self.tokenizer.decode(token_ids[token_ids.nonzero()[0]].tolist())
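
For reference, a small usage sketch of the two classes above, run in isolation; audio.wav is a placeholder path and hf-reverb/tk.model is the SentencePiece model referenced in transcribe.py:

# Sketch only: exercise the feature extractor and tokenizer directly.
import numpy as np
import torch
import torchaudio

from reverb_processor import ReverbFeatureExtractor, ReverbTokenizer

extractor = ReverbFeatureExtractor()
waveform, sample_rate = torchaudio.load("audio.wav", normalize=False)
waveform = waveform.to(torch.float).reshape(-1).numpy()

features = extractor(waveform, sampling_rate=sample_rate)
print(features.keys())            # input_features, feats_lengths
print(features["feats_lengths"])  # frame counts before padding to max_chunk_size

tokenizer = ReverbTokenizer("hf-reverb/tk.model")
ids = tokenizer.encode("hello world")
print(tokenizer.decode(np.array(ids)))  # drops blank (id 0) entries, then detokenizes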
46 changes: 46 additions & 0 deletions examples/huggingface/transcribe.py
@@ -0,0 +1,46 @@
import numpy as np
Contributor comment:
do you plan to keep this example call script? maybe add some arguments then?

from pyctcdecode import build_ctcdecoder
import torch
import torchaudio
from transformers import pipeline
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq
from reverb_hf import ReverbModel
from reverb_config import ReverbConfig
from reverb_processor import ReverbFeatureExtractor, ReverbTokenizer


AutoConfig.register("reverb_asr", ReverbConfig)
AutoModelForSpeechSeq2Seq.register(ReverbConfig, ReverbModel)
feature_extractor = ReverbFeatureExtractor(return_tensors='pt')
tokenizer = ReverbTokenizer(
"hf-reverb/tk.model",
)
decoder_ids = []
with open("hf-reverb/tk.units.txt", 'r') as units_file:
Contributor comment:
same as above, can we include this modification in the preprocessing step so the inference is as simple as possible?

    for line in units_file:
        parts = line.split()
        if not parts:
            continue
        token = parts[0]
        if token == '<blank>':
            token = ''
        decoder_ids.append(token)
decoder = build_ctcdecoder(decoder_ids)

transcribe = pipeline(
"automatic-speech-recognition",
model="hf-reverb",
feature_extractor=feature_extractor,
tokenizer=tokenizer,
framework='pt',
device='cpu', #crucial
decoder=decoder,
decoder_kwargs={"beam_width": 8, "token_min_logp": -10}
)
AUDIO_PATH = ""
waveform, sample_rate = torchaudio.load(AUDIO_PATH, normalize=False)
#print(waveform)
waveform = np.array(waveform.to(torch.float).reshape(-1))

chunk_size_samples = feature_extractor.chunk_length * sample_rate
for idx in range(0, len(waveform), chunk_size_samples):
    print(transcribe(waveform[idx:idx + chunk_size_samples])['text'])
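
On the contributor suggestion to add arguments to this example script, a possible argparse wrapper (flag names and defaults are illustrative, not part of the PR):

# Sketch only: argument handling for transcribe.py, as suggested in review.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description="Transcribe audio with the Reverb Hugging Face pipeline")
    parser.add_argument("audio", help="path to a 16 kHz mono audio file")
    parser.add_argument("--model-dir", default="hf-reverb",
                        help="directory containing config, weights, and tokenizer files")
    parser.add_argument("--beam-width", type=int, default=8)
    parser.add_argument("--token-min-logp", type=float, default=-10.0)
    parser.add_argument("--device", default="cpu")
    return parser.parse_args()


# Wiring it in would replace the hard-coded values above, e.g.:
#   args = parse_args()
#   AUDIO_PATH = args.audio
#   decoder_kwargs = {"beam_width": args.beam_width,
#                     "token_min_logp": args.token_min_logp}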