-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeepspeech2_train.py
88 lines (77 loc) · 3.12 KB
/
deepspeech2_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from load_dataset import load_hf_dataset
from data_collators import MelSpectrogramDataCollator
from deepspeech2_model import DeepSpeech2Config, LightDeepSpeech2
from vocab import Vocab
from tokenizer import Tokenizer
# Hyper-parameters
train_conf = {
    'exp_name': 'deepspeech/test1',
    'epochs': 500,
    'batch_size': 96,
    'accumulate_grad_batches': 1,
    'learning_rate': 1e-3,
    # Fraction of a training epoch between validation runs (passed to
    # Trainer.val_check_interval) and the basis of the logging interval.
    # 0.25 -> validation runs 4 times per epoch.
    'log_interval': 0.25,
    # Lightning's EarlyStopping patience counts validation checks, not epochs:
    # with 4 checks per epoch (log_interval=0.25), 20 checks ~= 5 epochs.
    'early_stopping_patience': 20,
    'train_dataloader_workers': 4,
    'test_dataloader_workers': 2,
}
# Config
sampling_rate = 16000  # passed to load_hf_dataset for both splits


def _build_dataloaders(tokenizer):
    """Load the train/test splits, tokenize transcripts, and wrap them in DataLoaders.

    Args:
        tokenizer: project Tokenizer whose ``tokenize`` is applied to the
            ``'sentence'`` column in batches.

    Returns:
        ``(train_dataloader, test_dataloader)``. The test split is the one
        later passed to ``trainer.fit`` as the validation set.
    """
    train_dataset = load_hf_dataset('train', sampling_rate=sampling_rate, with_features=True)
    test_dataset = load_hf_dataset('test', sampling_rate=sampling_rate, with_features=True)

    def tokenize_fn(batch):
        # Store the tokenizer output under 'labels'; the raw 'sentence'
        # column is dropped via remove_columns below.
        return {'labels': tokenizer.tokenize(batch['sentence'])}

    train_dataset = train_dataset.map(tokenize_fn, batched=True, batch_size=32, remove_columns=['sentence'])
    test_dataset = test_dataset.map(tokenize_fn, batched=True, batch_size=32, remove_columns=['sentence'])

    # Create dataloaders (one collator instance shared by both).
    data_collator = MelSpectrogramDataCollator()
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_conf['batch_size'],
        collate_fn=data_collator,
        shuffle=True,
        num_workers=train_conf['train_dataloader_workers'],
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=train_conf['batch_size'],
        collate_fn=data_collator,
        num_workers=train_conf['test_dataloader_workers'],
    )
    return train_dataloader, test_dataloader


def _build_callbacks():
    """Return the Trainer callbacks: best-val_cer checkpointing, per-step LR
    monitoring, and early stopping on val_cer."""
    lr_monitor = LearningRateMonitor(logging_interval='step')
    early_stopping = EarlyStopping(monitor="val_cer", mode="min", patience=train_conf['early_stopping_patience'])
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor='val_cer',
        mode='min',
        filename='checkpoint-{epoch:02d}-{val_cer:.2f}',
    )
    return [checkpoint_callback, lr_monitor, early_stopping]


def main():
    """Build the tokenizer, model config, data and Lightning Trainer, then train."""
    # Load the vocab and tokenizer
    vocab = Vocab.from_json('deepspeech2_vocab.json')
    tokenizer = Tokenizer(vocab)

    # Create the config
    config = DeepSpeech2Config(
        n_mels=80,
        hidden_units=384,
        tokenizer=tokenizer,
        # Training Parameters
        learning_rate=train_conf['learning_rate'],
    )

    # Seed Python's `random`, NumPy and torch together (torch.manual_seed
    # alone leaves the other generators unseeded).
    L.seed_everything(48)

    train_dataloader, test_dataloader = _build_dataloaders(tokenizer)
    callbacks = _build_callbacks()

    # Optimizer steps per epoch. len(train_dataloader) includes the final
    # partial batch (drop_last is False), which flooring
    # len(dataset) / batch_size would miss.
    steps_per_epoch = len(train_dataloader) // train_conf['accumulate_grad_batches']
    log_steps = max(1, int(steps_per_epoch * train_conf['log_interval']))

    # Train
    model = LightDeepSpeech2(config)
    trainer = L.Trainer(
        default_root_dir=f'exps/{train_conf["exp_name"]}',
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        max_epochs=train_conf['epochs'],
        accumulate_grad_batches=train_conf['accumulate_grad_batches'],
        log_every_n_steps=log_steps,
        val_check_interval=train_conf['log_interval'],
        enable_model_summary=False,
        callbacks=callbacks,
        num_sanity_val_steps=0,
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )


# DataLoader workers (num_workers > 0) re-import the main module under the
# "spawn" start method (Windows, macOS); without this guard each worker would
# re-run the whole training script.
if __name__ == "__main__":
    main()