Skip to content

Commit d7f4fe1

Browse files
Fix file tagging to be more adapted to current system. Does not fix everything...
Squashed commit of the following: commit 4969a4b Author: Thibault Clérice <leponteineptique@gmail.com> Date: Wed Apr 20 13:25:26 2022 +0200 Fixed tagging probably commit 3858463 Author: Thibault Clérice <leponteineptique@gmail.com> Date: Wed Apr 20 11:19:03 2022 +0200 WIP
1 parent 0ea9830 commit d7f4fe1

File tree

3 files changed

+54
-53
lines changed

3 files changed

+54
-53
lines changed

boudams/cli.py

+14-33
Original file line numberDiff line numberDiff line change
@@ -408,45 +408,27 @@ def test(test_path, models, batch_size, device, debug, workers: int, avg: str):
408408
def tag(model, filename, device="cpu", batch_size=64):
409409
""" Tag all [FILENAME] using [MODEL]"""
410410
print("Loading the model.")
411-
model = BoudamsTagger.load(model)
411+
model = BoudamsTagger.load(model, device=device)
412412
model.eval()
413-
model.to(device)
414413
print("Model loaded.")
415-
remove_line = True
416-
spaces = re.compile(r"\s+")
417-
apos = re.compile(r"['’]")
418414
for file in tqdm.tqdm(filename):
419415
out_name = file.name.replace(".txt", ".tokenized.txt")
420416
content = file.read() # Could definitely be done a better way...
421-
if remove_line:
422-
content = spaces.sub("", content)
417+
if model.vocabulary.mode.name == "simple-space":
418+
content = re.sub(r"\s+", "", content)
419+
elif model.vocabulary.mode.NormalizeSpace:
420+
content = re.sub(r"\s+", " ", content)
423421
file.close()
424-
# Now, extract apostrophes, remove them, and reinject them
425-
apos_positions = [
426-
i
427-
for i in range(len(content))
428-
if content[i] in ["'", "’"]
429-
]
430-
content = apos.sub("", content)
431-
432422
with open(out_name, "w") as out_io:
433423
out = ''
434-
for tokenized_string in model.annotate_text(content, batch_size=batch_size, device=device):
435-
out = out + tokenized_string+" "
436-
437-
# Reinject apostrophes
438-
#out = 'Sainz Tiebauz fu nez en l evesché de Troies ; ses peres ot non Ernous et sa mere, Gile et furent fra'
439-
true_index = 0
440-
for i in range(len(out) + len(apos_positions)):
441-
if true_index in apos_positions:
442-
out = out[:i] + "'" + out[i:]
443-
true_index = true_index + 1
444-
else:
445-
if not out[i] == ' ':
446-
true_index = true_index + 1
447-
424+
for tokenized_string in model.annotate_text(
425+
content,
426+
batch_size=batch_size,
427+
device=device
428+
):
429+
out = out + tokenized_string + "\n"
448430
out_io.write(out)
449-
# print("--- File " + file.name + " has been tokenized")
431+
print("--- File " + file.name + " has been tokenized")
450432

451433

452434
@cli.command("tag-check")
@@ -458,11 +440,10 @@ def tag_check(config_model, content, device="cpu", batch_size=64):
458440
""" Tag all [FILENAME] using [MODEL]"""
459441
for model in config_model:
460442
click.echo(f"Loading the model {model}.")
461-
boudams = BoudamsTagger.load(model)
443+
boudams = BoudamsTagger.load(model, device=device)
462444
boudams.eval()
463-
boudams.to(device)
464445
click.echo(f"\t[X] Model loaded")
465-
click.echo("\n".join(boudams.annotate_text(content, splitter="([\.!\?]+)", batch_size=batch_size, device=device)))
446+
click.echo("\n".join(boudams.annotate_text(content, splitter=r"([\.!\?]+)", batch_size=batch_size, device=device)))
466447

467448

468449
@cli.command("graph")

boudams/modes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class SimpleSpaceMode:
2424
NormalizeSpace: bool = True
2525

2626
def __init__(self, masks: Dict[str, int] = None):
27-
self.name = "Default"
27+
self.name = "simple-space"
2828
self.masks_to_index: Dict[str, int] = masks or {
2929
DEFAULT_PAD_TOKEN: 0,
3030
DEFAULT_MASK_TOKEN: 1,
@@ -139,7 +139,7 @@ def computer_wer(self, confusion_matrix):
139139

140140
class AdvancedSpaceMode(SimpleSpaceMode):
141141
def __init__(self, masks: Dict[str, int] = None):
142-
self.name = "Default"
142+
self.name = "advanced-space"
143143
self.masks_to_index: Dict[str, int] = masks or {
144144
DEFAULT_PAD_TOKEN: 0,
145145
DEFAULT_MASK_TOKEN: 1,

boudams/tagger.py

+38-18
Original file line numberDiff line numberDiff line change
@@ -449,30 +449,50 @@ def annotate(self, texts: List[str], batch_size=32, device: str = "cpu"):
449449
for index in range(len(translations)):
450450
yield "".join(translations[order.index(index)])
451451

452-
def annotate_text(self, string, splitter=r"([⁊\W\d]+)", batch_size=32, device: str = "cpu"):
453-
splitter = re.compile(splitter)
454-
splits = splitter.split(string)
455-
456-
tempList = splits + [""] * 2
457-
strings = ["".join(tempList[n:n + 2]) for n in range(0, len(splits), 2)]
458-
strings = list(filter(lambda x: x.strip(), strings))
452+
@staticmethod
453+
def _apply_max_size(tokens: str, size: int):
454+
# Use finditer when applied to things with spaces ?
455+
# [(m.start(0), m.end(0)) for m in re.finditer(pattern, string)] ?
456+
current = []
457+
for tok in re.split(r"(\s+)", tokens):
458+
if not tok:
459+
continue
460+
current.append(tok)
461+
string_size = len("".join(current))
462+
if string_size > size:
463+
yield "".join(current[:-1])
464+
current = current[-1:]
465+
elif string_size == size:
466+
yield "".join(current)
467+
current = []
468+
if current:
469+
yield "".join(current)
470+
471+
def annotate_text(self, single_sentence, splitter: Optional[str] = None, batch_size=32, device: str = "cpu", rolling=True):
472+
if splitter is None:
473+
# ToDo: Mode specific splitter ?
474+
splitter = r"([\.!\?]+)"
459475

476+
splitter = re.compile(splitter)
477+
sentences = [tok for tok in splitter.split(single_sentence) if tok.strip()]
478+
460479
if self._maximum_sentence_size:
480+
# This is currently quite limiting.
481+
# If the end token ends with a W and not a WB, there is no way to "correct it"
482+
# We'd need a rolling system: cut in the middle of maximum sentence size ?
461483
treated = []
462484
max_size = self._maximum_sentence_size
463-
for string in strings:
464-
if len(string) > max_size:
465-
treated.extend([
466-
"".join(string[n:n + max_size])
467-
for n in range(0, len(string), max_size)
468-
])
485+
for single_sentence in sentences:
486+
if len(single_sentence) > max_size:
487+
treated.extend(self._apply_max_size(single_sentence, max_size))
469488
else:
470-
treated.append(string)
471-
strings = treated
472-
yield from self.annotate(strings, batch_size=batch_size, device=device)
489+
treated.append(single_sentence)
490+
sentences = treated
491+
492+
yield from self.annotate(sentences, batch_size=batch_size, device=device)
473493

474494
@classmethod
475-
def load(cls, fpath="./model.boudams_model"):
495+
def load(cls, fpath="./model.boudams_model", device=None):
476496
with tarfile.open(utils.ensure_ext(fpath, 'boudams_model'), 'r') as tar:
477497
settings = json.loads(utils.get_gzip_from_tar(tar, 'settings.json.zip'))
478498

@@ -487,7 +507,7 @@ def load(cls, fpath="./model.boudams_model"):
487507
tar.extract('state_dict.pt', path=tmppath)
488508
dictpath = os.path.join(tmppath, 'state_dict.pt')
489509
# Strict false for predict (nll_weight is removed)
490-
obj.load_state_dict(torch.load(dictpath), strict=False)
510+
obj.load_state_dict(torch.load(dictpath, map_location=device), strict=False)
491511

492512
obj.eval()
493513

0 commit comments

Comments
 (0)