Commit 3d611e0

Removing SOS and EOS

1 parent: 588e40d

5 files changed: 34 additions, 82 deletions

boudams/cli.py (-2)

@@ -168,14 +168,12 @@ def train(config_files, epochs, batch_size, device, debug):
     for config_file in config_files:
         config = json.load(config_file)
 
-        masked = config["model"].startswith("linear")
         train_path, dev_path, test_path = config["datasets"]["train"],\
                                           config["datasets"]["dev"],\
                                           config["datasets"]["test"]
 
         vocabulary = LabelEncoder(
             maximum_length=config.get("max_sentence_size", None),
-            masked=masked,
             remove_diacriticals=config["label_encoder"].get("normalize", True),
             lower=config["label_encoder"].get("lower", True)
         )
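
For reference, the loop above consumes JSON configs of roughly this shape; the keys come from the code, while the values below are placeholders:

config = {
    "model": "linear",  # no longer used to toggle a `masked` flag
    "max_sentence_size": 150,
    "datasets": {
        "train": "data/train.tsv",
        "dev": "data/dev.tsv",
        "test": "data/test.tsv",
    },
    "label_encoder": {"normalize": True, "lower": True},
}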

boudams/encoder.py (+27, -59)

@@ -145,71 +145,51 @@ def iterable():
 
 class LabelEncoder:
     def __init__(self,
-                 init_token=DEFAULT_INIT_TOKEN,
-                 eos_token=DEFAULT_EOS_TOKEN,
                  pad_token=DEFAULT_PAD_TOKEN,
                  unk_token=DEFAULT_UNK_TOKEN,
                  mask_token=DEFAULT_MASK_TOKEN,
                  maximum_length: int = None,
                  lower: bool = True,
-                 remove_diacriticals: bool = True,
-                 masked: bool = False
+                 remove_diacriticals: bool = True
                  ):
 
-        self.masked: bool = masked
-        self.init_token: str = init_token
-        self.eos_token: str = eos_token
         self.pad_token: str = pad_token
         self.unk_token: str = unk_token
         self.mask_token: str = mask_token
         self.space_token: str = " "
 
-        self.init_token_index: int = 0
-        self.eos_token_index: int = 1
         self.pad_token_index: int = 2
-        self.space_token_index: int = 3
-        self.mask_token_index: int = 4
-        self.unk_token_index: int = 5  # Put here because it isn't used in masked
+        self.space_token_index: int = 1
+        self.mask_token_index: int = 0
+        self.unk_token_index: int = 0  # Put here because it isn't used in masked
 
         self.max_len: Optional[int] = maximum_length
         self.lower = lower
         self.remove_diacriticals = remove_diacriticals
 
         self.itos: Dict[int, str] = {
-            self.init_token_index: self.init_token,
-            self.eos_token_index: self.eos_token,
             self.pad_token_index: self.pad_token,
-            self.unk_token_index: self.unk_token
+            self.unk_token_index: self.unk_token,
+            self.space_token_index: self.space_token
         }  # Id to string for reversal
 
         self.stoi: Dict[str, int] = {
-            self.init_token: self.init_token_index,
-            self.eos_token: self.eos_token_index,
             self.pad_token: self.pad_token_index,
-            self.unk_token: self.unk_token_index
+            self.unk_token: self.unk_token_index,
+            self.space_token: self.space_token_index
         }  # String to ID
 
         # Mask dictionaries
         self.itom: Dict[int, str] = {
-            self.init_token_index: self.init_token,
-            self.eos_token_index: self.eos_token,
             self.pad_token_index: self.pad_token,
             self.mask_token_index: self.mask_token,
             self.space_token_index: self.space_token
         }
         self.mtoi: Dict[str, int] = {
-            self.init_token: self.init_token_index,
-            self.eos_token: self.eos_token_index,
             self.pad_token: self.pad_token_index,
             self.mask_token: self.mask_token_index,
             self.space_token: self.space_token_index
         }
-        self.use_init = True
-        self.use_eos = True
-
-    def encoding_parameters(self, use_init, use_eos):
-        self.use_init = use_init
-        self.use_eos = use_eos
 
     def __len__(self):
         return len(self.stoi)
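
With this hunk the special tokens collapse onto a compact index space, and the mask and unknown tokens deliberately share index 0 (unk is never produced in the masked setup, per the inline comment). As a plain-Python summary of the new layout:

# Special-token indexes after this commit:
MASK  = 0  # mask_token_index; doubles as unk_token_index, unused when masking
SPACE = 1  # space_token_index
PAD   = 2  # pad_token_index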
@@ -279,6 +259,7 @@ def pad_and_tensorize(
 
         :param sentences: List of sentences where characters have been separated into a list and index encoded
        :param padding: padding required (None if every sentence in the same size)
+        :param reorder: List of index to reorder the sequence
        :param device: Torch device
        :return: Transformed batch into tensor
        """
@@ -310,36 +291,26 @@ def pad_and_tensorize(
     def gt_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
         """ Transform GT to numerical
 
-        :param sentence: Sequence of characters (can be a straight string)
-        :return: List of character indexes
+        :param sentence: Sequence of characters (can be a straight string) with spaces
+        :return: List of mask indexes
         """
-        if not self.masked:
-            return self.inp_to_numerical(sentence)
-        else:
-            obligatory_tokens = int(self.use_init) + int(self.use_eos)  # Tokens for init and end of string
-            init = [self.init_token_index] if self.use_init else []
-            eos = [self.eos_token_index] if self.use_eos else []
-            numericals = init + [
-                self.mask_token_index if ngram[1] != " " else self.space_token_index
-                for ngram in zip(*[sentence[i:] for i in range(2)])
-                if ngram[0] != " "
-            ] + [self.space_token_index] + eos
+        numericals = [
+            self.mask_token_index if ngram[1] != " " else self.space_token_index
+            for ngram in zip(*[sentence[i:] for i in range(2)])
+            if ngram[0] != " "
+        ] + [self.space_token_index]
 
-            return numericals, len(sentence) - sentence.count(" ") + obligatory_tokens
+        return numericals, len(sentence) - sentence.count(" ")
 
     def inp_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
-        """ Transform GT to numerical
+        """ Transform input sentence to numerical
 
-        :param sentence: Sequence of characters (can be a straight string)
+        :param sentence: Sequence of characters (can be a straight string) without spaces
         :return: List of character indexes
         """
-        obligatory_tokens = int(self.use_init) + int(self.use_eos)  # Tokens for init and end of string
-        init = [self.init_token_index] if self.use_init else []
-        eos = [self.eos_token_index] if self.use_eos else []
-
         return (
-            init + [self.stoi.get(char, self.unk_token_index) for char in sentence] + eos,
-            len(sentence) + obligatory_tokens
+            [self.stoi.get(char, self.unk_token_index) for char in sentence],
+            len(sentence)
         )
 
     def reverse_batch(
@@ -355,9 +326,8 @@ def reverse_batch(
         with torch.cuda.device_of(batch):
             batch = batch.tolist()
 
-        if self.masked is True and masked is not None:
+        if masked is not None:
             if not isinstance(masked, list):
-
                 with torch.cuda.device_of(masked):
                     masked = masked.tolist()
 
@@ -371,9 +341,11 @@ def reverse_batch(
             ]
         else:
             masked = [
-                [self.init_token_index] + list(sentence) + [self.eos_token_index]
+                list(sentence)
                 for sentence in masked
             ]
+            print(ignore)
+
         return [
             [
                 tok
@@ -405,8 +377,7 @@ def reverse_batch(
 
     def transcribe_batch(self, batch: List[List[str]]):
         for sentence in batch:
-            end = len(sentence) if self.eos_token not in sentence else sentence.index(self.eos_token)
-            yield "".join(sentence[1:end])  # Remove SOS
+            yield "".join(sentence).strip()  # Remove SOS
 
     def get_dataset(self, path, **kwargs):
         """
@@ -431,14 +402,11 @@ def dump(self) -> str:
             "itos": self.itos,
             "stoi": self.stoi,
             "params": {
-                "init_token": self.init_token,
-                "eos_token": self.eos_token,
                 "pad_token": self.pad_token,
                 "unk_token": self.unk_token,
                 "mask_token": self.mask_token,
                 "remove_diacriticals": self.remove_diacriticals,
-                "lower": self.lower,
-                "masked": self.masked
+                "lower": self.lower
             }
         })
 
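
To make the new gt_to_numerical concrete, here is a standalone sketch of the same bigram logic, with the indexes hard-coded to the new layout (an illustration, not the library API itself):

# Each non-space character is labelled SPACE when a space follows it (a word
# boundary) and MASK otherwise; the final character always receives SPACE.
MASK, SPACE = 0, 1

def gt_to_numerical(sentence: str):
    numericals = [
        MASK if ngram[1] != " " else SPACE
        for ngram in zip(*[sentence[i:] for i in range(2)])
        if ngram[0] != " "
    ] + [SPACE]
    return numericals, len(sentence) - sentence.count(" ")

print(gt_to_numerical("si vos"))  # ([0, 1, 0, 0, 1], 5)

Without SOS and EOS, the label sequence is exactly as long as the despaced input, which is what the second return value now reports.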

boudams/model/linear.py (+2, -4)

@@ -149,7 +149,7 @@ def __init__(
         self,
         encoder: CNNEncoder, decoder: LinearDecoder,
         device: str,
-        pad_idx: int, sos_idx: int, eos_idx: int,
+        pad_idx: int,
         pos: bool = False,
         **kwargs
     ):
@@ -160,8 +160,6 @@ def __init__(
         self.pos = pos
 
         self.pad_idx = pad_idx
-        self.sos_idx = sos_idx
-        self.eos_idx = eos_idx
         self.device = device
 
         # nll weight
@@ -211,7 +209,7 @@ def predict(self, src, src_len, label_encoder: "LabelEncoder",
         return label_encoder.reverse_batch(
             logits,
             masked=override_src or src,
-            ignore=(self.pad_idx, self.eos_idx, self.sos_idx)
+            ignore=(self.pad_idx, )
         )
 
     def gradient(
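
With SOS and EOS gone, only padding is left to filter out when reversing a batch. A toy illustration of what an ignore tuple does in that filtering step (assumed indexes, not the actual reverse_batch call):

PAD = 2
batch = [[0, 1, 0, PAD, PAD], [1, 0, 0, 0, PAD]]
ignore = (PAD,)
print([[tok for tok in sent if tok not in ignore] for sent in batch])
# [[0, 1, 0], [1, 0, 0, 0]]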

boudams/tagger.py (+1, -14)

@@ -47,7 +47,6 @@ def __init__(
 
         self.vocabulary: LabelEncoder = vocabulary
         self.vocabulary_dimension: int = len(self.vocabulary)
-        self.masked: bool = self.vocabulary.masked
 
         self.device: str = device
         self.enc_hid_dim = self.dec_hid_dim = self.hidden_size = hidden_size
@@ -64,16 +63,12 @@ def __init__(
         self.system: str = system
 
         # Based on self.masked, decoder dimension can be drastically different
-        self.dec_dim: int = self.vocabulary_dimension
-        if self.masked:
-            self.dec_dim = len(self.vocabulary.itom)
+        self.dec_dim = len(self.vocabulary.itom)
 
         self.mask_token = self.vocabulary.mask_token
 
         seq2seq_shared_params = {
             "pad_idx": self.padtoken,
-            "sos_idx": self.sostoken,
-            "eos_idx": self.eostoken,
             "device": self.device,
             "out_max_sentence_length": self.out_max_sentence_length
         }
@@ -128,14 +123,6 @@ def __init__(
     def padtoken(self):
         return self.vocabulary.pad_token_index
 
-    @property
-    def sostoken(self):
-        return self.vocabulary.init_token_index
-
-    @property
-    def eostoken(self):
-        return self.vocabulary.eos_token_index
-
     @property
     def settings(self):
         return {
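
Every model is now masked, so the decoder always classifies over the mask vocabulary rather than the full character vocabulary. A quick sketch under that assumption (the token strings here are placeholders):

itom = {0: "x", 1: " ", 2: "<PAD>"}  # mask, space, pad
dec_dim = len(itom)  # always 3, regardless of the size of the input charset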

boudams/trainer.py (+4, -3)

@@ -128,9 +128,10 @@ def register_batch(self, hypotheses, targets, src):
         src = src.tolist()
 
         for y_true, y_pred, x in zip(exp, out, src):
-            self.trues.append(y_true)
-            self.preds.append(y_pred)
-            self.srcs.append(x)
+            stop = x.index(self.tagger.padtoken) if self.tagger.padtoken in x else len(x)
+            self.trues.append(y_true[:stop])
+            self.preds.append(y_pred[:stop])
+            self.srcs.append(x[:stop])
 
 
 class LRScheduler(object):
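
The new truncation keeps padding out of the evaluation: each sequence is cut at the first pad token of its source before being registered. A minimal sketch, assuming padtoken == 2:

padtoken = 2
x = [0, 1, 0, 0, 1, 2, 2, 2]
stop = x.index(padtoken) if padtoken in x else len(x)
print(x[:stop])  # [0, 1, 0, 0, 1]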
