@@ -449,30 +449,50 @@ def annotate(self, texts: List[str], batch_size=32, device: str = "cpu"):
449
449
for index in range (len (translations )):
450
450
yield "" .join (translations [order .index (index )])
451
451
452
- def annotate_text (self , string , splitter = r"([⁊\W\d]+)" , batch_size = 32 , device : str = "cpu" ):
453
- splitter = re .compile (splitter )
454
- splits = splitter .split (string )
455
-
456
- tempList = splits + ["" ] * 2
457
- strings = ["" .join (tempList [n :n + 2 ]) for n in range (0 , len (splits ), 2 )]
458
- strings = list (filter (lambda x : x .strip (), strings ))
452
+ @staticmethod
453
+ def _apply_max_size (tokens : str , size : int ):
454
+ # Use finditer when applied to things with spaces ?
455
+ # [(m.start(0), m.end(0)) for m in re.finditer(pattern, string)] ?
456
+ current = []
457
+ for tok in re .split (r"(\s+)" , tokens ):
458
+ if not tok :
459
+ continue
460
+ current .append (tok )
461
+ string_size = len ("" .join (current ))
462
+ if string_size > size :
463
+ yield "" .join (current [:- 1 ])
464
+ current = current [- 1 :]
465
+ elif string_size == size :
466
+ yield "" .join (current )
467
+ current = []
468
+ if current :
469
+ yield "" .join (current )
470
+
471
def annotate_text(self, single_sentence, splitter: Optional[str] = None, batch_size=32, device: str = "cpu", rolling=True):
    """Split *single_sentence* into sentence-like pieces, cap their length,
    and yield the model's annotation for each.

    :param single_sentence: raw text to annotate.
    :param splitter: regex (with a capturing group) used to split the text;
        defaults to runs of sentence-ending punctuation.
        # ToDo: mode-specific default splitter?
    :param batch_size: forwarded to :meth:`annotate`.
    :param device: forwarded to :meth:`annotate`.
    :param rolling: NOTE(review): currently unused — kept for interface
        stability; presumably reserved for a rolling-window split. Confirm.
    """
    pattern = re.compile(splitter if splitter is not None else r"([\.!\?]+)")
    sentences = [piece for piece in pattern.split(single_sentence) if piece.strip()]

    max_size = self._maximum_sentence_size
    if max_size:
        # This is currently quite limitating: pieces are only re-split on
        # whitespace (see _apply_max_size), so a chunk can still end mid-word
        # with no way to "correct" it. A rolling system (cut in the middle of
        # the maximum sentence size?) would be needed to fix that.
        capped = []
        for sentence in sentences:
            if len(sentence) > max_size:
                capped.extend(self._apply_max_size(sentence, max_size))
            else:
                capped.append(sentence)
        sentences = capped

    yield from self.annotate(sentences, batch_size=batch_size, device=device)
473
493
474
494
@classmethod
475
- def load (cls , fpath = "./model.boudams_model" ):
495
+ def load (cls , fpath = "./model.boudams_model" , device = None ):
476
496
with tarfile .open (utils .ensure_ext (fpath , 'boudams_model' ), 'r' ) as tar :
477
497
settings = json .loads (utils .get_gzip_from_tar (tar , 'settings.json.zip' ))
478
498
@@ -487,7 +507,7 @@ def load(cls, fpath="./model.boudams_model"):
487
507
tar .extract ('state_dict.pt' , path = tmppath )
488
508
dictpath = os .path .join (tmppath , 'state_dict.pt' )
489
509
# Strict false for predict (nll_weight is removed)
490
- obj .load_state_dict (torch .load (dictpath ), strict = False )
510
+ obj .load_state_dict (torch .load (dictpath , map_location = device ), strict = False )
491
511
492
512
obj .eval ()
493
513
0 commit comments