Commit 80156dd

Training with byte level BPE (AIShell) (#986)
* copy files from zipformer librispeech
* Add byte bpe training for aishell
* compile LG graph
* Support LG decoding
* Minor fixes
* black
* Minor fixes
* export & fix pretrain.py
* fix black
* Update RESULTS.md
* Fix export.py
1 parent 61ec3a7 commit 80156dd

30 files changed (+3992 −47 lines)

egs/aishell/ASR/RESULTS.md

+52 −1
@@ -2,6 +2,57 @@
@@ -2,6 +2,57 @@ (continued)

### Aishell training result(Stateless Transducer)

#### Pruned transducer stateless 7 (zipformer)

See <https://github.com/k2-fsa/icefall/pull/986>

[./pruned_transducer_stateless7_bbpe](./pruned_transducer_stateless7_bbpe)

**Note**: The modeling units are byte level BPEs

The best results I have gotten are:

Vocab size | Greedy search(dev & test) | Modified beam search(dev & test) | Fast beam search (dev & test) | Fast beam search LG (dev & test) | comments
-- | -- | -- | -- | -- | --
500 | 4.31 & 4.59 | 4.25 & 4.54 | 4.27 & 4.55 | 4.07 & 4.38 | --epoch 48 --avg 29

The training command:

```
export CUDA_VISIBLE_DEVICES="4,5,6,7"

./pruned_transducer_stateless7_bbpe/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --max-duration 800 \
  --bpe-model data/lang_bbpe_500/bbpe.model \
  --exp-dir pruned_transducer_stateless7_bbpe/exp \
  --lr-epochs 6 \
  --master-port 12535
```

The decoding command:

```
for m in greedy_search modified_beam_search fast_beam_search fast_beam_search_LG; do
  ./pruned_transducer_stateless7_bbpe/decode.py \
    --epoch 48 \
    --avg 29 \
    --exp-dir ./pruned_transducer_stateless7_bbpe/exp \
    --max-sym-per-frame 1 \
    --ngram-lm-scale 0.25 \
    --ilme-scale 0.2 \
    --bpe-model data/lang_bbpe_500/bbpe.model \
    --max-duration 2000 \
    --decoding-method $m
done
```

The pretrained model is available at: https://huggingface.co/pkufool/icefall_asr_aishell_pruned_transducer_stateless7_bbpe
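
For readers unfamiliar with byte level BPE: in this recipe, Chinese text is split per CJK character and byte-encoded into printable symbols before the sentencepiece model sees it; the new lexicon-preparation script below uses the same `byte_encode(tokenize_by_CJK_char(...))` call. A minimal sketch (not taken from the recipe), assuming icefall is on the Python path and the model path from the training command above exists:

```
import sentencepiece as spm

from icefall.byte_utils import byte_decode, byte_encode
from icefall.utils import tokenize_by_CJK_char

text = "今天天气很好"

# Split out each CJK character, then map its UTF-8 bytes to printable symbols.
encoded = byte_encode(tokenize_by_CJK_char(text))
print(encoded)

# Tokenize the byte-encoded string with the byte level BPE model.
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bbpe_500/bbpe.model")  # path assumed from the training command
print(sp.encode(encoded, out_type=str))

# byte_decode reverses the byte-to-symbol mapping
# (round-trips back to the CJK-tokenized text).
print(byte_decode(encoded))
```

The round trip through `byte_decode` is what makes it possible to map recognised byte pieces back to Chinese characters during decoding.
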
#### Pruned transducer stateless 3

See <https://github.com/k2-fsa/icefall/pull/436>
@@ -75,7 +126,7 @@ for epoch in 29; do
done
```

We provide the option of shallow fusion with a RNN language model. The pre-trained language model is
available at <https://huggingface.co/marcoyang/icefall-aishell-rnn-lm>. To decode with the language model,
please use the following command:

egs/aishell/ASR/local/compile_lg.py

+1

@@ -0,0 +1 @@
../../../librispeech/ASR/local/compile_lg.py
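
`compile_lg.py` itself is not reproduced on this page (it is a symlink into the librispeech recipe). Roughly, it composes the lexicon FST `L_disambig.pt` with a word-level n-gram G to produce the `LG` graph used by `fast_beam_search_LG` above. A simplified sketch with k2; the paths are assumptions, and details of the real script (such as removing disambiguation symbols) are omitted:

```
import k2
import torch

lang_dir = "data/lang_bbpe_500"  # assumed output directory

# Load the lexicon FST produced by the new lang-preparation script.
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

# Load a word-level n-gram LM in OpenFst text format (hypothetical path).
with open("data/lm/G_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# Compose L with G, then trim and determinize the result.
L = k2.arc_sort(L)
G = k2.arc_sort(G)
LG = k2.compose(L, G)
LG = k2.connect(LG)
LG = k2.determinize(LG)
LG = k2.arc_sort(LG)

torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
```
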

egs/aishell/ASR/local/prepare_char.py

+16 −1
@@ -33,6 +33,7 @@
    - tokens.txt
"""

import argparse
import re
from pathlib import Path
from typing import Dict, List
@@ -189,8 +190,22 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
    return tokens


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the bpe.model and words.txt
        """,
    )

    return parser.parse_args()


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)  # was: lang_dir = Path("data/lang_char")
    text_file = lang_dir / "text"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
+267

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang
#                                                    Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""

This script takes as input `lang_dir`, which should contain::

    - lang_dir/bbpe.model,
    - lang_dir/words.txt

and generates the following files in the directory `lang_dir`:

    - lexicon.txt
    - lexicon_disambig.txt
    - L.pt
    - L_disambig.pt
    - tokens.txt
"""

import argparse
from pathlib import Path
from typing import Dict, List, Tuple

import k2
import sentencepiece as spm
import torch
from prepare_lang import (
    Lexicon,
    add_disambig_symbols,
    add_self_loops,
    write_lexicon,
    write_mapping,
)

from icefall.byte_utils import byte_encode
from icefall.utils import str2bool, tokenize_by_CJK_char


def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format).

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go

    arcs = []

    # The blank symbol <blk> is defined in local/train_bpe_model.py
    assert token2id["<blk>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        pieces = [token2id[i] for i in pieces]

        for i in range(len(pieces) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, pieces[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last piece of this word
        i = len(pieces) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, pieces[i], w, 0])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def generate_lexicon(
    model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
    """Generate a lexicon from a BPE model.

    Args:
      model_file:
        Path to a sentencepiece model.
      words:
        A list of strings representing words.
      oov:
        The out of vocabulary word in lexicon.
    Returns:
      Return a tuple with two elements:
        - A dict whose keys are words and values are the corresponding
          word pieces.
        - A dict representing the token symbol, mapping from tokens to IDs.
    """
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_file))

    # Convert word to word piece IDs instead of word piece strings
    # to avoid OOV tokens.
    encode_words = [byte_encode(tokenize_by_CJK_char(w)) for w in words]
    words_pieces_ids: List[List[int]] = sp.encode(encode_words, out_type=int)

    # Now convert word piece IDs back to word piece strings.
    words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids]

    lexicon = []
    for word, pieces in zip(words, words_pieces):
        lexicon.append((word, pieces))

    lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))

    token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}

    return lexicon, token2id


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the bpe.model and words.txt
        """,
    )

    parser.add_argument(
        "--oov",
        type=str,
        default="<UNK>",
        help="The out of vocabulary word in lexicon.",
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!

        See "test/test_bpe_lexicon.py" for usage.
        """,
    )

    return parser.parse_args()


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    model_file = lang_dir / "bbpe.model"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

    words = word_sym_table.symbols

    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]

    for w in excluded:
        if w in words:
            words.remove(w)

    lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    next_token_id = max(token_sym_table.values()) + 1
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = next_token_id
        next_token_id += 1

    word_sym_table.add("#0")
    word_sym_table.add("<s>")
    word_sym_table.add("</s>")

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token_sym_table,
        word2id=word_sym_table,
    )

    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token_sym_table,
        word2id=word_sym_table,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if args.debug:
        labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")

        L.labels_sym = labels_sym
        L.aux_labels_sym = aux_labels_sym
        L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")

        L_disambig.labels_sym = labels_sym
        L_disambig.aux_labels_sym = aux_labels_sym
        L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")


if __name__ == "__main__":
    main()
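
To see the topology that `lexicon_to_fst_no_sil` builds, here is a small self-contained sketch with a made-up one-word lexicon and toy token/word tables (none of these symbols or IDs come from the recipe); it only needs k2 installed:

```
import k2

# Toy tables; <blk> and <eps> must map to 0, matching the asserts in
# lexicon_to_fst_no_sil.  All entries here are invented for illustration.
token2id = {"<blk>": 0, "▁": 1, "ni": 2, "hao": 3}
word2id = {"<eps>": 0, "你好": 1}

# A lexicon is a list of (word, pieces) pairs.
lexicon = [("你好", ["▁", "ni", "hao"])]

loop_state, next_state, eps = 0, 1, 0
arcs = []
for word, pieces in lexicon:
    cur_state = loop_state
    wid = word2id[word]
    pids = [token2id[p] for p in pieces]
    # The word label is emitted on the first piece; later pieces emit epsilon.
    for i in range(len(pids) - 1):
        arcs.append([cur_state, next_state, pids[i], wid if i == 0 else eps, 0])
        cur_state, next_state = next_state, next_state + 1
    # The last piece returns to the loop state.
    arcs.append([cur_state, loop_state, pids[-1], wid if len(pids) == 1 else eps, 0])

final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])

arcs = sorted(arcs, key=lambda a: a[0])
fsa_str = "\n".join(" ".join(str(x) for x in a) for a in arcs)
L = k2.Fsa.from_str(fsa_str, acceptor=False)
print(L)
```

Because the word label is emitted on the first piece and epsilon on the rest, composing such an L with a grammar G yields word-level output paths over byte-piece input labels.
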
