Run Reverb using Huggingface Pipelines (Missing Model Files) #33
reverb_config.py
@@ -0,0 +1,115 @@
# Following https://huggingface.co/docs/transformers/en/custom_models
import math
from typing import Dict, List, Optional

import numpy as np
from transformers import PretrainedConfig


def cmvn(means: List[float], variance: List[float], count: int):
    """Calculate CMVN normalization vectors from accumulated stats.

    Returns:
        a list of [means, inverse standard deviations]
    """
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    # cmvn = np.array([means, variance])
    return [means, variance]


class ReverbConfig(PretrainedConfig):
    # Not sure what to put here, but it is also not required: model_type = "encoderdecoder"
    model_type = "reverb_asr"

    def __init__(
        self,
        input_dim: int = 80,
        output_dim: int = 10001,
cmvn_mean_stat : List[float] = [33596438528.0, 35418329088.0, 39182106624.0, 41983324160.0, 44419112960.0, 46015381504.0, 46934564864.0, 47058870272.0, 47288012800.0, 47522979840.0, 48491438080.0, 49308729344.0, 50230493184.0, 50796900352.0, 51020386304.0, 51297456128.0, 51333586944.0, 51126181888.0, 51455569920.0, 50636410880.0, 49947033600.0, 50365546496.0, 49383075840.0, 49540546560.0, 49066065920.0, 49236889600.0, 48820707328.0, 49071112192.0, 48968024064.0, 49024458752.0, 49202397184.0, 49374433280.0, 49620660224.0, 49947111424.0, 50326310912.0, 50717818880.0, 51046891520.0, 51345678336.0, 51655733248.0, 51505459200.0, 51813666816.0, 51577262080.0, 51776524288.0, 51754237952.0, 51918598144.0, 52158758912.0, 52405276672.0, 52596776960.0, 52639731712.0, 52631220224.0, 52443103232.0, 52315619328.0, 52219695104.0, 52178399232.0, 52083040256.0, 52064792576.0, 51980918784.0, 51824164864.0, 51550973952.0, 51002216448.0, 50422747136.0, 49847754752.0, 49474338816.0, 48997863424.0, 48617009152.0, 48309174272.0, 48084140032.0, 48095608832.0, 47965765632.0, 47909335040.0, 47780065280.0, 47762370560.0, 47757099008.0, 47731314688.0, 47574110208.0, 47336361984.0, 47009054720.0, 46283513856.0, 44821860352.0, 42771775488.0], | ||
cmvn_var_stat: List[float] = [360475131904.0, 401487724544.0, 484368646144.0, 548414357504.0, 608912080896.0, 651613241344.0, 678013698048.0, 683624693760.0, 689524047872.0, 695375822848.0, 722376851456.0, 746773872640.0, 774244204544.0, 791678353408.0, 798920015872.0, 807307444224.0, 808713453568.0, 802957754368.0, 812319899648.0, 788076953600.0, 767619497984.0, 777970712576.0, 748566544384.0, 751065628672.0, 736340869120.0, 739872473088.0, 727466704896.0, 734006083584.0, 731017904128.0, 732582576128.0, 737590444032.0, 742469861376.0, 749455671296.0, 758746972160.0, 769666121728.0, 781107331072.0, 790730506240.0, 799342002176.0, 808164917248.0, 803454713856.0, 812040585216.0, 804632395776.0, 809866821632.0, 808861499392.0, 813548044288.0, 820701954048.0, 828343779328.0, 834335604736.0, 835754590208.0, 835251011584.0, 829192929280.0, 824705744896.0, 821224734720.0, 819399753728.0, 816182853632.0, 815243788288.0, 812578177024.0, 807846281216.0, 799796035584.0, 784661544960.0, 770915631104.0, 756696285184.0, 746462183424.0, 734193254400.0, 724980072448.0, 717529612288.0, 711156563968.0, 710358204416.0, 706386919424.0, 704228884480.0, 700537110528.0, 699519008768.0, 699025129472.0, 698035535872.0, 693109391360.0, 686047887360.0, 676213948416.0, 655917645824.0, 616676458496.0, 563932168192.0], | ||
        cmvn_frame_num: int = 3519342927,
        encoder: str = "conformer",
        encoder_activation_type: str = "swish",
        encoder_attention_dropout_rate: float = 0.1,
        encoder_attention_heads: int = 8,
        encoder_causal: bool = True,
        encoder_cnn_module_kernel: int = 31,
        encoder_cnn_module_norm: str = "layer_norm",
        encoder_dropout_rate: float = 0.1,
        encoder_input_layer: str = "conv2d",
        encoder_linear_units: int = 2048,
        encoder_normalize_before: bool = True,
        encoder_num_blocks: int = 18,
        encoder_num_langs: int = 2,
        encoder_output_size: int = 640,
        encoder_pos_enc_layer_type: str = "rel_pos",
        encoder_positional_dropout_rate: float = 0.1,
        encoder_selfattention_layer_type: str = "rel_selfattn",
        encoder_use_cnn_module: bool = True,
        encoder_use_dynamic_chunk: bool = True,
        decoder: str = "lslbitransformer",
        decoder_attention_heads: int = 8,
        decoder_dropout_rate: float = 0.1,
        decoder_linear_units: int = 2048,
        decoder_num_blocks: int = 6,
        decoder_num_langs: int = 2,
        decoder_positional_dropout_rate: float = 0.1,
        decoder_r_num_blocks: int = 6,
        decoder_self_attention_dropout_rate: float = 0.1,
        decoder_src_attention_dropout_rate: float = 0.1,
        ctc_blank_id: int = 0,
        ctc_weight: float = 0.3,
        lsm_weight: float = 0.1,
        reverse_weight: float = 0.3,
        special_tokens: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.encoder = encoder
        self.encoder_activation_type = encoder_activation_type
        self.encoder_attention_dropout_rate = encoder_attention_dropout_rate
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_causal = encoder_causal
        self.encoder_cnn_module_kernel = encoder_cnn_module_kernel
        self.encoder_cnn_module_norm = encoder_cnn_module_norm
        self.encoder_dropout_rate = encoder_dropout_rate
        self.encoder_input_layer = encoder_input_layer
        self.encoder_linear_units = encoder_linear_units
        self.encoder_normalize_before = encoder_normalize_before
        self.encoder_num_blocks = encoder_num_blocks
        self.encoder_num_langs = encoder_num_langs
        self.encoder_output_size = encoder_output_size
        self.encoder_pos_enc_layer_type = encoder_pos_enc_layer_type
        self.encoder_positional_dropout_rate = encoder_positional_dropout_rate
        self.encoder_selfattention_layer_type = encoder_selfattention_layer_type
        self.encoder_use_cnn_module = encoder_use_cnn_module
        self.encoder_use_dynamic_chunk = encoder_use_dynamic_chunk
        self.decoder = decoder
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_dropout_rate = decoder_dropout_rate
        self.decoder_linear_units = decoder_linear_units
        self.decoder_num_blocks = decoder_num_blocks
        self.decoder_num_langs = decoder_num_langs
        self.decoder_positional_dropout_rate = decoder_positional_dropout_rate
        self.decoder_r_num_blocks = decoder_r_num_blocks
        self.decoder_self_attention_dropout_rate = decoder_self_attention_dropout_rate
        self.decoder_src_attention_dropout_rate = decoder_src_attention_dropout_rate
        self.ctc_blank_id = ctc_blank_id
        self.ctc_weight = ctc_weight
        self.lsm_weight = lsm_weight
        self.reverse_weight = reverse_weight
        if special_tokens is None:
            special_tokens = {
                "<blank>": 0,
                "<sos>": 2,
                "<eos>": 2,
                "<unk>": 1,
            }
        self.special_tokens = special_tokens
        self.cmvn_mean, self.cmvn_istd = cmvn(cmvn_mean_stat, cmvn_var_stat, cmvn_frame_num)
        self.inputs_to_logits_ratio = 1
        super().__init__(**kwargs)
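For context, a minimal usage sketch (not part of the diff) of the config class above: defaults can be overridden per instance, and the CMVN statistics are folded into mean / inverse-std vectors at construction time.

```python
# Sketch only: instantiate the config and inspect the derived CMVN vectors.
from reverb_config import ReverbConfig

config = ReverbConfig(encoder_num_blocks=12)          # override any default via kwargs
print(config.model_type)                              # -> "reverb_asr"
print(len(config.cmvn_mean), len(config.cmvn_istd))   # same length as the CMVN stats above
```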
reverb_hf.py
@@ -0,0 +1,98 @@
# Following https://huggingface.co/docs/transformers/en/custom_models

from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import LanguageSpecificBiTransformerDecoder
from wenet.transformer.encoder import ConformerEncoder

from reverb_config import ReverbConfig


class ReverbModel(PreTrainedModel):
    config_class = ReverbConfig
    main_input_name = "input_features"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        global_cmvn = GlobalCMVN(
            torch.Tensor(config.cmvn_mean),
            torch.Tensor(config.cmvn_istd),
        )
        encoder = ConformerEncoder(
            config.input_dim,
            global_cmvn=global_cmvn,
            activation_type=config.encoder_activation_type,
            attention_dropout_rate=config.encoder_attention_dropout_rate,
            attention_heads=config.encoder_attention_heads,
            causal=config.encoder_causal,
            cnn_module_kernel=config.encoder_cnn_module_kernel,
            cnn_module_norm=config.encoder_cnn_module_norm,
            dropout_rate=config.encoder_dropout_rate,
            input_layer=config.encoder_input_layer,
            linear_units=config.encoder_linear_units,
            normalize_before=config.encoder_normalize_before,
            num_blocks=config.encoder_num_blocks,
            num_langs=config.encoder_num_langs,
            output_size=config.encoder_output_size,
            pos_enc_layer_type=config.encoder_pos_enc_layer_type,
            positional_dropout_rate=config.encoder_positional_dropout_rate,
            selfattention_layer_type=config.encoder_selfattention_layer_type,
            use_cnn_module=config.encoder_use_cnn_module,
            use_dynamic_chunk=config.encoder_use_dynamic_chunk,
        )

        decoder = LanguageSpecificBiTransformerDecoder(
            config.output_dim,
            config.encoder_output_size,
            attention_heads=config.decoder_attention_heads,
            dropout_rate=config.decoder_dropout_rate,
            linear_units=config.decoder_linear_units,
            num_blocks=config.decoder_num_blocks,
            num_langs=config.decoder_num_langs,
            positional_dropout_rate=config.decoder_positional_dropout_rate,
            r_num_blocks=config.decoder_r_num_blocks,
            self_attention_dropout_rate=config.decoder_self_attention_dropout_rate,
            src_attention_dropout_rate=config.decoder_src_attention_dropout_rate,
        )

        ctc = CTC(
            config.output_dim,
            config.encoder_output_size,
            config.ctc_blank_id,
        )

        self.model = ASRModel(
            vocab_size=config.output_dim,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            special_tokens=config.special_tokens,
            ctc_weight=config.ctc_weight,
            lsm_weight=config.lsm_weight,
            reverse_weight=config.reverse_weight,
        )
        self.model.lsl_enc = True
        self.model.lsl_dec = True

    def forward(
        self,
        input_features=None,
        feats_lengths=None,
        labels=None,
        labels_lengths=None,
        **kwargs,
    ):
        output = self.model.hf_forward(
            input_features,
            feats_lengths=feats_lengths,
            labels=labels,
            labels_lengths=labels_lengths,
        )
        return Seq2SeqLMOutput(
            logits=output['ctc_probs'],
            loss=output['loss'],
        )
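A minimal loading sketch (not part of the diff), mirroring the registration calls in the run script below. It assumes a local "hf-reverb" directory containing config.json and converted weights, i.e. the "missing model files" referenced in the PR title.

```python
# Sketch only: register the custom classes, then load from a local checkpoint directory.
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

from reverb_config import ReverbConfig
from reverb_hf import ReverbModel

AutoConfig.register("reverb_asr", ReverbConfig)
AutoModelForSpeechSeq2Seq.register(ReverbConfig, ReverbModel)

model = AutoModelForSpeechSeq2Seq.from_pretrained("hf-reverb")  # needs config.json + weights
model.eval()
```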
reverb_processor.py
@@ -0,0 +1,132 @@
import json
from typing import List, Optional, Union

import numpy as np
import sentencepiece as spm
import torch
import torchaudio
from torchaudio.compliance import kaldi
from tqdm import tqdm
from transformers import BatchFeature, PreTrainedTokenizer, ProcessorMixin, SequenceFeatureExtractor
from transformers.utils import logging


logger = logging.get_logger(__name__)


class ReverbFeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["input_features"]

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        frame_length=25,
        frame_shift=10,
        chunk_length=15,
        padding_value=0.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            return_attention_mask=False,
            **kwargs,
        )
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.chunk_length = chunk_length
        self.max_chunk_size = 2051
        self._processor_class = "CTCWithLM"

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        device: Optional[str] = "cpu",
        sampling_rate: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                # No resampling is performed here, so reject mismatched sampling rates.
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [np.asarray([speech], dtype=np.float32) for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        if not is_batched:
            raw_speech = [np.asarray([raw_speech])]

        fbank_speech, feats_lengths = [], []
        for waveform in raw_speech:
            fbank_speech.append(
                kaldi.fbank(
                    torch.tensor(waveform),
                    num_mel_bins=self.feature_size,
                    frame_length=self.frame_length,
                    frame_shift=self.frame_shift,
                    dither=0.0,
                    energy_floor=0.0,
                    sample_frequency=self.sampling_rate,
                )
            )
            feats_lengths.append(fbank_speech[-1].shape[0])
        fbank_speech = BatchFeature({
            "input_features": fbank_speech,
            "feats_lengths": feats_lengths,
        })
        padded = self.pad(
            fbank_speech,
            padding="max_length",
            max_length=self.max_chunk_size,
        )
        return padded


class ReverbTokenizer(PreTrainedTokenizer):
    def __init__(
        self,
        model: str,
        # units: str,
        **kwargs,
    ):
        self.tokenizer = spm.SentencePieceProcessor(model)
        """self.units = dict()
        with open(units, 'r') as units_file:
            for line in tqdm(units_file.readlines()):
                token, id = line.split()
                self.units[int(id)] = token.replace('▁', ' ')"""

    def encode(
        self,
        text,
        **kwargs,
    ):
        return self.tokenizer.encode(text)

    def decode(
        self,
        token_ids,
        **kwargs,
    ):
        return self.tokenizer.decode(token_ids[token_ids.nonzero()[0]].tolist())
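A minimal standalone sketch (not part of the diff) of the processor classes above. The waveform is synthetic, and the tk.model path follows the run script below; exact padded shapes depend on max_chunk_size.

```python
# Sketch only: run the feature extractor and tokenizer outside the pipeline.
import numpy as np

from reverb_processor import ReverbFeatureExtractor, ReverbTokenizer

feature_extractor = ReverbFeatureExtractor()
waveform = np.zeros(16000, dtype=np.float32)             # one second of silence at 16 kHz
features = feature_extractor(waveform, sampling_rate=16000)
print(features["input_features"][0].shape)               # padded log-Mel filterbank features
print(features["feats_lengths"])                         # unpadded frame counts

tokenizer = ReverbTokenizer("hf-reverb/tk.model")         # SentencePiece model shipped with the checkpoint
print(tokenizer.encode("hello world"))                    # list of SentencePiece ids
```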
Example call script
@@ -0,0 +1,46 @@
import numpy as np

Review comment: do you plan to keep this example call script? maybe add some arguments then?

from pyctcdecode import build_ctcdecoder
import torch
import torchaudio
from transformers import pipeline
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

from reverb_hf import ReverbModel
from reverb_config import ReverbConfig
from reverb_processor import ReverbFeatureExtractor, ReverbTokenizer


AutoConfig.register("reverb_asr", ReverbConfig)
AutoModelForSpeechSeq2Seq.register(ReverbConfig, ReverbModel)
feature_extractor = ReverbFeatureExtractor(return_tensors='pt')
tokenizer = ReverbTokenizer(
    "hf-reverb/tk.model",
)
# Build the CTC decoder labels from the units file, mapping <blank> to the empty string.
decoder_ids = []
with open("hf-reverb/tk.units.txt", 'r') as units_file:
    for line in units_file:
        parts = line.split()
        if len(parts) == 0:
            continue
        token = parts[0]
        if token == '<blank>':
            token = ''
        decoder_ids.append(token)
decoder = build_ctcdecoder(decoder_ids)

Review comment: same as above, can we include this modification in the preprocessing step so the inference is as simple as possible?

transcribe = pipeline(
    "automatic-speech-recognition",
    model="hf-reverb",
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    framework='pt',
    device='cpu',  # crucial
    decoder=decoder,
    decoder_kwargs={"beam_width": 8, "token_min_logp": -10},
)
AUDIO_PATH = ""
waveform, sample_rate = torchaudio.load(AUDIO_PATH, normalize=False)
# print(waveform)
waveform = np.array(waveform.to(torch.float).reshape(-1))

chunk_size_samples = feature_extractor.chunk_length * sample_rate
for idx in range(0, len(waveform), chunk_size_samples):
    print(transcribe(waveform[idx: idx + chunk_size_samples])['text'])
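On the review comment about keeping this example script: one option would be a small argument parser so the hard-coded paths and decoding options become CLI flags. A sketch only; flag names and defaults are illustrative and not part of the PR.

```python
# Sketch only: possible CLI arguments for this example script.
import argparse

parser = argparse.ArgumentParser(
    description="Transcribe an audio file with Reverb through the HF ASR pipeline")
parser.add_argument("audio_path", help="path to the input audio file")
parser.add_argument("--model-dir", default="hf-reverb",
                    help="directory containing config.json, weights, tk.model and tk.units.txt")
parser.add_argument("--beam-width", type=int, default=8)
parser.add_argument("--token-min-logp", type=float, default=-10.0)
parser.add_argument("--chunk-length", type=int, default=15, help="chunk length in seconds")
args = parser.parse_args()
```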
Review comment: not sure if this is 100% possible, but could we include .json/.yaml (or any other format) with model files and load the whole configuration of this class?
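On the last comment: since ReverbConfig subclasses PretrainedConfig, the standard save_pretrained / from_pretrained round-trip should already cover this by writing a config.json next to the model files. A minimal sketch, with illustrative paths:

```python
# Sketch only: the configuration serializes to JSON via the PretrainedConfig machinery.
from transformers import AutoConfig

from reverb_config import ReverbConfig

AutoConfig.register("reverb_asr", ReverbConfig)

config = ReverbConfig()
config.save_pretrained("hf-reverb")                  # writes hf-reverb/config.json
reloaded = AutoConfig.from_pretrained("hf-reverb")   # dispatched via model_type "reverb_asr"
assert isinstance(reloaded, ReverbConfig)
```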