@@ -145,71 +145,51 @@ def iterable():
 class LabelEncoder:
     def __init__(self,
-                 init_token=DEFAULT_INIT_TOKEN,
-                 eos_token=DEFAULT_EOS_TOKEN,
                  pad_token=DEFAULT_PAD_TOKEN,
                  unk_token=DEFAULT_UNK_TOKEN,
                  mask_token=DEFAULT_MASK_TOKEN,
                  maximum_length: int = None,
                  lower: bool = True,
-                 remove_diacriticals: bool = True,
-                 masked: bool = False
+                 remove_diacriticals: bool = True
                  ):
 
-        self.masked: bool = masked
-        self.init_token: str = init_token
-        self.eos_token: str = eos_token
         self.pad_token: str = pad_token
         self.unk_token: str = unk_token
         self.mask_token: str = mask_token
         self.space_token: str = " "
 
-        self.init_token_index: int = 0
-        self.eos_token_index: int = 1
         self.pad_token_index: int = 2
-        self.space_token_index: int = 3
-        self.mask_token_index: int = 4
-        self.unk_token_index: int = 5  # Put here because it isn't used in masked
+        self.space_token_index: int = 1
+        self.mask_token_index: int = 0
+        self.unk_token_index: int = 0  # Can share the mask index because unk is never used in masked sequences
 
         self.max_len: Optional[int] = maximum_length
         self.lower = lower
         self.remove_diacriticals = remove_diacriticals
 
         self.itos: Dict[int, str] = {
-            self.init_token_index: self.init_token,
-            self.eos_token_index: self.eos_token,
             self.pad_token_index: self.pad_token,
-            self.unk_token_index: self.unk_token
+            self.unk_token_index: self.unk_token,
+            self.space_token_index: self.space_token
         }  # Id to string for reversal
 
         self.stoi: Dict[str, int] = {
-            self.init_token: self.init_token_index,
-            self.eos_token: self.eos_token_index,
             self.pad_token: self.pad_token_index,
-            self.unk_token: self.unk_token_index
+            self.unk_token: self.unk_token_index,
+            self.space_token: self.space_token_index
         }  # String to ID
 
         # Mask dictionaries
         self.itom: Dict[int, str] = {
-            self.init_token_index: self.init_token,
-            self.eos_token_index: self.eos_token,
             self.pad_token_index: self.pad_token,
             self.mask_token_index: self.mask_token,
             self.space_token_index: self.space_token
         }
         self.mtoi: Dict[str, int] = {
-            self.init_token: self.init_token_index,
-            self.eos_token: self.eos_token_index,
             self.pad_token: self.pad_token_index,
             self.mask_token: self.mask_token_index,
             self.space_token: self.space_token_index
         }
-        self.use_init = True
-        self.use_eos = True
-
-    def encoding_parameters(self, use_init, use_eos):
-        self.use_init = use_init
-        self.use_eos = use_eos
 
     def __len__(self):
         return len(self.stoi)
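
Taken together, this hunk removes the init/eos machinery and collapses the special tokens into a fixed low-index scheme: mask and unk share index 0 (they never occur in the same vocabulary), space is 1 and pad is 2. A minimal standalone sketch of the resulting lookup tables; the literal token strings are placeholders, since the DEFAULT_* constants are defined outside this diff:

```python
# Sketch of the vocabulary layout after this change. The token strings
# are placeholder assumptions; the real DEFAULT_* values live elsewhere.
PAD, UNK, MASK, SPACE = "<PAD>", "<UNK>", "<MASK>", " "

# Character vocabulary (inputs): unk at 0, space at 1, pad at 2.
stoi = {UNK: 0, SPACE: 1, PAD: 2}
itos = {index: token for token, index in stoi.items()}

# Mask vocabulary (ground truth): mask reuses index 0, since unk
# never appears in masked sequences.
mtoi = {MASK: 0, SPACE: 1, PAD: 2}
itom = {index: token for token, index in mtoi.items()}

assert itos[0] == UNK and itom[0] == MASK  # same slot, different meaning
```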
@@ -279,6 +259,7 @@ def pad_and_tensorize(
         :param sentences: List of sentences where characters have been separated into a list and index encoded
         :param padding: padding required (None if every sentence is the same size)
+        :param reorder: List of indices used to reorder the sequence
         :param device: Torch device
         :return: Transformed batch into tensor
         """
@@ -310,36 +291,26 @@ def pad_and_tensorize(
     def gt_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
         """ Transform GT to numerical
 
-        :param sentence: Sequence of characters (can be a straight string)
-        :return: List of character indexes
+        :param sentence: Sequence of characters (can be a straight string) with spaces
+        :return: List of mask indexes
         """
-        if not self.masked:
-            return self.inp_to_numerical(sentence)
-        else:
-            obligatory_tokens = int(self.use_init) + int(self.use_eos)  # Tokens for init and end of string
-            init = [self.init_token_index] if self.use_init else []
-            eos = [self.eos_token_index] if self.use_eos else []
-            numericals = init + [
-                self.mask_token_index if ngram[1] != " " else self.space_token_index
-                for ngram in zip(*[sentence[i:] for i in range(2)])
-                if ngram[0] != " "
-            ] + [self.space_token_index] + eos
+        numericals = [
+            self.mask_token_index if ngram[1] != " " else self.space_token_index
+            for ngram in zip(*[sentence[i:] for i in range(2)])
+            if ngram[0] != " "
+        ] + [self.space_token_index]
 
-            return numericals, len(sentence) - sentence.count(" ") + obligatory_tokens
+        return numericals, len(sentence) - sentence.count(" ")
 
     def inp_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
-        """ Transform GT to numerical
+        """ Transform an input sentence to numerical
 
-        :param sentence: Sequence of characters (can be a straight string)
+        :param sentence: Sequence of characters (can be a straight string) without spaces
 
         :return: List of character indexes
         """
-        obligatory_tokens = int(self.use_init) + int(self.use_eos)  # Tokens for init and end of string
-        init = [self.init_token_index] if self.use_init else []
-        eos = [self.eos_token_index] if self.use_eos else []
-
         return (
-            init + [self.stoi.get(char, self.unk_token_index) for char in sentence] + eos,
-            len(sentence) + obligatory_tokens
+            [self.stoi.get(char, self.unk_token_index) for char in sentence],
+            len(sentence)
         )
 
     def reverse_batch(
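
The simplified gt_to_numerical keeps the bigram trick: zip(*[sentence[i:] for i in range(2)]) iterates over consecutive character pairs, so each non-space character is labelled according to whether the character following it is a space (word boundary) or not (mask), with a final space label for the last character. A self-contained sketch of the same logic, using the index values introduced above:

```python
# Self-contained sketch of the boundary labelling in gt_to_numerical.
# zip(*[s[i:] for i in range(2)]) is equivalent to zip(s, s[1:]):
# it yields the consecutive character pairs (bigrams) of s.
MASK_INDEX, SPACE_INDEX = 0, 1

def to_mask_indices(sentence: str):
    numericals = [
        MASK_INDEX if nxt != " " else SPACE_INDEX
        for cur, nxt in zip(sentence, sentence[1:])
        if cur != " "      # spaces themselves produce no label
    ] + [SPACE_INDEX]      # the last character always closes a word
    # One label per non-space character
    return numericals, len(sentence) - sentence.count(" ")

labels, length = to_mask_indices("ab cd")
assert labels == [0, 1, 0, 1] and length == 4
```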
@@ -355,9 +326,8 @@ def reverse_batch(
         with torch.cuda.device_of(batch):
             batch = batch.tolist()
 
-        if self.masked is True and masked is not None:
+        if masked is not None:
             if not isinstance(masked, list):
-
                 with torch.cuda.device_of(masked):
                     masked = masked.tolist()
@@ -371,9 +341,11 @@ def reverse_batch(
             ]
         else:
             masked = [
-                [self.init_token_index] + list(sentence) + [self.eos_token_index]
+                list(sentence)
                 for sentence in masked
             ]
+            print(ignore)
+
         return [
             [
                 tok
@@ -405,8 +377,7 @@ def reverse_batch(
     def transcribe_batch(self, batch: List[List[str]]):
         for sentence in batch:
-            end = len(sentence) if self.eos_token not in sentence else sentence.index(self.eos_token)
-            yield "".join(sentence[1:end])  # Remove SOS
+            yield "".join(sentence).strip()  # No SOS to remove anymore; strip surrounding whitespace instead
 
     def get_dataset(self, path, **kwargs):
         """
@@ -431,14 +402,11 @@ def dump(self) -> str:
             "itos": self.itos,
             "stoi": self.stoi,
             "params": {
-                "init_token": self.init_token,
-                "eos_token": self.eos_token,
                 "pad_token": self.pad_token,
                 "unk_token": self.unk_token,
                 "mask_token": self.mask_token,
                 "remove_diacriticals": self.remove_diacriticals,
-                "lower": self.lower,
-                "masked": self.masked
+                "lower": self.lower
             }
         })
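
After this change, dump() no longer records the init/eos tokens or the masked flag. A sketch of the shape of the serialized encoder; the vocabulary contents and token strings here are illustrative only:

```python
# Illustrative shape of the JSON emitted by dump() after this change;
# real vocabularies and token strings will differ.
import json

serialized = json.dumps({
    "itos": {0: "<UNK>", 1: " ", 2: "<PAD>"},
    "stoi": {"<UNK>": 0, " ": 1, "<PAD>": 2},
    "params": {
        "pad_token": "<PAD>",
        "unk_token": "<UNK>",
        "mask_token": "<MASK>",
        "remove_diacriticals": True,
        "lower": True
    }
})
print(serialized)
```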