Commit 588e40d

Better scorer (issue #11)
1 parent 586b7a3 commit 588e40d

File tree: 2 files changed (+48 -77 lines)

boudams/cli.py (+2 -2)

@@ -174,7 +174,7 @@ def train(config_files, epochs, batch_size, device, debug):
         config["datasets"]["test"]

     vocabulary = LabelEncoder(
-        maximum_length=config["max_sentence_size"],
+        maximum_length=config.get("max_sentence_size", None),
         masked=masked,
         remove_diacriticals=config["label_encoder"].get("normalize", True),
         lower=config["label_encoder"].get("lower", True)
@@ -200,7 +200,7 @@ def train(config_files, epochs, batch_size, device, debug):

     tagger = BoudamsTagger(
         vocabulary,
-        device=device, system=config["model"], out_max_sentence_length=config["max_sentence_size"],
+        device=device, system=config["model"], out_max_sentence_length=config.get("max_sentence_size", None),
         **config["network"])
     trainer = Trainer(tagger, device=device)
     print(tagger.model)
boudams/trainer.py (+46 -75)
@@ -9,7 +9,7 @@
 import logging

 from collections import namedtuple
-from typing import Callable
+from typing import Callable, List, Tuple


 import torch
@@ -19,8 +19,7 @@
 import tqdm


-from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support
-from leven import levenshtein
+from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

 from boudams.tagger import BoudamsTagger, DEVICE
 from boudams.encoder import DatasetIterator
@@ -29,7 +28,7 @@

 INVALID = "<INVALID>"
 DEBUG = bool(os.getenv("DEBUG"))
-Score = namedtuple("Score", ["loss", "perplexity", "accuracy", "leven", "leven_per_char", "scorer"])
+Score = namedtuple("Score", ["loss", "accuracy", "precision", "recall", "fscore", "scorer"])


 class PlateauModes(enum.Enum):
@@ -62,8 +61,7 @@ def __init__(self, tagger: BoudamsTagger, masked: bool = False, record: bool = F
         self.preds = []
         self.srcs = []

-        self._score_tuple = namedtuple("scores", ["accuracy", "leven", "leven_per_char",
-                                                  "precision", "recall", "fscore"])
+        self._score_tuple = namedtuple("scores", ["accuracy", "precision", "recall", "fscore"])
         self.scores = None
         self.masked: bool = masked
@@ -86,14 +84,13 @@ def plot_confusion_matrix(self, path: str = "confusion-matrix.png"):
         )
         plt.savefig(path)

-    def compute(self):
-        levenshteins = []
-        leven_per_char = []
-
+    def compute(self) -> "Scorer":
         unrolled_trues = list([y_char for y_sent in self.trues for y_char in y_sent])
         unrolled_preds = list([y_char for y_sent in self.preds for y_char in y_sent])

-        matrix = confusion_matrix(unrolled_trues, unrolled_preds,
+        matrix = confusion_matrix(
+            unrolled_trues,
+            unrolled_preds,
             labels=[self.tagger.vocabulary.space_token_index, self.tagger.vocabulary.mask_token_index]
         )
         # Accuracy score takes into account PAD, EOS and SOS so we get the data from the confusion matrix
@@ -108,43 +105,15 @@ def compute(self):
             labels=[self.tagger.vocabulary.space_token_index, self.tagger.vocabulary.mask_token_index]
         )

-        for tr_true, tr_pred in zip(
-            self.tagger.vocabulary.transcribe_batch(
-                self.tagger.vocabulary.reverse_batch(self.trues, ignore=(self.tagger.vocabulary.pad_token_index, ),
-                                                     masked=self.srcs)
-            ),
-            self.tagger.vocabulary.transcribe_batch(
-                self.tagger.vocabulary.reverse_batch(self.preds, ignore=(self.tagger.vocabulary.pad_token_index, ),
-                                                     masked=self.srcs)
-            )
-        ):
-            levenshteins.append(levenshtein(tr_true, tr_pred))
-            leven_per_char.append(levenshteins[-1] / len(tr_true))
-            if DEBUG and random.random() < 0.05:
-                logging.debug("EXP:" + "".join(tr_true))
-                logging.debug("OUT:" + "".join(tr_pred))
-                logging.debug("---")
-
-        self.scores = self._score_tuple(accuracy,
-                                        statistics.mean(levenshteins),
-                                        statistics.mean(leven_per_char),
-                                        precision, recall, fscore)
+        self.scores = self._score_tuple(accuracy, precision, recall, fscore)
+
+        return self

     def get_accuracy(self) -> float:
         if not self.scores:
             self.compute()
         return self.scores.accuracy

-    def avg_levenshteins(self) -> float:
-        if not self.scores:
-            self.compute()
-        return self.scores.leven
-
-    def avg_levenshteins_per_char(self) -> float:
-        if not self.scores:
-            self.compute()
-        return self.scores.leven_per_char
-
     def register_batch(self, hypotheses, targets, src):
         """
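
With the Levenshtein pass gone, the reworked compute() gets all four figures from scikit-learn alone: accuracy straight off the confusion matrix (which, thanks to the labels argument, leaves PAD/EOS/SOS out), and precision/recall/F-score from precision_recall_fscore_support. A minimal standalone sketch of that flow, with made-up label indices standing in for space_token_index and mask_token_index, and assuming a macro average for the part of the precision_recall_fscore_support call that lies outside this hunk:

    from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

    SPACE, MASK = 0, 1          # stand-ins for the vocabulary's token indices
    trues = [SPACE, MASK, MASK, SPACE, MASK, MASK]
    preds = [SPACE, MASK, SPACE, SPACE, MASK, MASK]

    # Restricting `labels` keeps other token classes out of the matrix,
    # so the accuracy derived from it covers only the labels that matter.
    matrix = confusion_matrix(trues, preds, labels=[SPACE, MASK])
    accuracy = matrix.diagonal().sum() / matrix.sum()

    precision, recall, fscore, _ = precision_recall_fscore_support(
        trues, preds, labels=[SPACE, MASK], average="macro"
    )
    print(accuracy, precision, recall, fscore)
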
@@ -213,6 +182,13 @@ def _temp_save(self, file_path: str, best_score: float, current_score: Score) ->
             best_score = current_score.loss
         return best_score

+    @staticmethod
+    def print_score(key: str, score: Score) -> None:
+        print(f'\t{key} Loss: {score.loss:.3f} | FScore: {score.fscore:.3f} | '
+              f' Acc.: {score.accuracy:.3f} | '
+              f' Prec.: {score.precision:.3f} | '
+              f' Recl.: {score.recall:.3f}')
+
     def run(
             self, train_dataset: DatasetIterator, dev_dataset: DatasetIterator,
             lr: float = 1e-3, min_lr: float = 1e-6, lr_factor: int = 0.75, lr_patience: float = 10,
@@ -244,7 +220,7 @@ def run(
         fid = '/tmp/{}'.format(str(uuid.uuid1()))
         best_valid_loss = float("inf")
         # In case exception was run before eval
-        dev_score = Score(float("inf"), float("inf"), float("-inf"), float("inf"), float("inf"), None)
+        dev_score = Score(float("inf"), float("-inf"), float("-inf"), float("-inf"), float("-inf"), None)

         # Set up loss but ignore the loss when the token is <pad>
         # where <pad> is the token for filling the vector to get same-sized matrix
@@ -269,26 +245,20 @@ def run(
                 (
                     str(epoch),
                     # train
-                    str(train_score.loss), str(train_score.perplexity), str(train_score.accuracy),
-                    str(train_score.leven), str(train_score.leven_per_char),
+                    str(train_score.loss), str(train_score.accuracy), str(train_score.precision),
+                    str(train_score.recall), str(train_score.fscore),
                     # Dev
-                    str(dev_score.loss), str(dev_score.perplexity), str(dev_score.accuracy),
-                    str(dev_score.leven), str(dev_score.leven_per_char),
-                    "UNK", "UNK"
+                    str(dev_score.loss), str(dev_score.accuracy), str(dev_score.precision),
+                    str(dev_score.recall), str(dev_score.fscore),
+                    # Test
+                    "UNK", "UNK", "UNK", "UNK", "UNK"
                 )
             )

             # Run a check on saving the current model
             best_valid_loss = self._temp_save(fid, best_valid_loss, dev_score)
-            print(f'\tTrain Loss: {train_score.loss:.3f} | Perplexity: {train_score.perplexity:7.3f} | '
-                  f' Acc.: {train_score.accuracy:.3f} | '
-                  f' Lev.: {train_score.leven:.3f} | '
-                  f' Lev. / char: {train_score.leven_per_char:.3f}')
-
-            print(f'\t Val. Loss: {dev_score.loss:.3f} | Perplexity: {dev_score.perplexity:7.3f} | '
-                  f' Acc.: {dev_score.accuracy:.3f} | '
-                  f' Lev.: {dev_score.leven:.3f} | '
-                  f' Lev. / char: {dev_score.leven_per_char:.3f}')
+            self.print_score("Train", train_score)
+            self.print_score("Dev", dev_score)
             print(lr_scheduler)
             print()
@@ -355,22 +325,16 @@ def save(self, fpath="model.tar", csv_content=None):
         return fpath

     @staticmethod
-    def init_csv_content():
+    def init_csv_content() -> List[Tuple[str, str, str, str, str, str, str, str, str, str, str, str, str, str, str, str]]:
         return [
             (
                 "Epoch",
-                "Train Loss", "Train Perplexity", "Train Accuracy", "Train Avg Leven", "Train Avg Leven Per Char",
-                "Dev Loss", "Dev Perplexity", "Dev Accuracy", "Dev Avg Leven", "Dev Avg Leven Per Char",
-                "Test Loss", "Test Perplexity"
+                "Train Loss", "Train Accuracy", "Train Precision", "Train Recall", "Train F1",
+                "Dev Loss", "Dev Accuracy", "Dev Precision", "Dev Recall", "Dev F1",
+                "Test Loss", "Test Accuracy", "Test Precision", "Test Recall", "Test F1"
             )
         ]

-    def _get_perplexity(self, loss):
-        try:
-            return math.exp(loss)
-        except:
-            return float("inf")
-
     def _train_epoch(self, iterator: DatasetIterator, optimizer: optim.Optimizer, criterion: nn.CrossEntropyLoss,
                      clip: float, desc: str, batch_size: int = 32) -> Score:
         self.tagger.model.train()
@@ -404,8 +368,13 @@ def _train_epoch(self, iterator: DatasetIterator, optimizer: optim.Optimizer, cr
             epoch_loss += loss.item()

         loss = epoch_loss / iterator.batch_count
-        return Score(loss, self._get_perplexity(loss), scorer.get_accuracy(),
-                     scorer.avg_levenshteins(), scorer.avg_levenshteins_per_char(), scorer=scorer)
+        scorer.compute()
+        return Score(loss,
+                     accuracy=scorer.scores.accuracy,
+                     precision=scorer.scores.precision,
+                     recall=scorer.scores.recall,
+                     fscore=scorer.scores.fscore,
+                     scorer=scorer)

     def evaluate(self, iterator: DatasetIterator, criterion: nn.CrossEntropyLoss,
                  desc: str, batch_size: int, test_mode=False) -> Score:
@@ -435,8 +404,13 @@ def evaluate(self, iterator: DatasetIterator, criterion: nn.CrossEntropyLoss,

         loss = epoch_loss / iterator.batch_count

-        return Score(loss, self._get_perplexity(loss), scorer.get_accuracy(),
-                     scorer.avg_levenshteins(), scorer.avg_levenshteins_per_char(), scorer=scorer)
+        scorer.compute()
+        return Score(loss,
+                     accuracy=scorer.scores.accuracy,
+                     precision=scorer.scores.precision,
+                     recall=scorer.scores.recall,
+                     fscore=scorer.scores.fscore,
+                     scorer=scorer)

     def test(self, test_dataset: DatasetIterator, batch_size: int = 256, do_print=True):
         # Set up loss but ignore the loss when the token is <pad>
@@ -447,8 +421,5 @@ def test(self, test_dataset: DatasetIterator, batch_size: int = 256, do_print=Tr
         scorer: Scorer = score_object.scorer

         if do_print:
-            print(f' Test Loss: {score_object.loss:.3f} | Perplexity: {score_object.perplexity:7.3f} | '
-                  f'Acc.: {score_object.accuracy:.3f} | '
-                  f'Lev.: {score_object.scorer.avg_levenshteins():.3f} | '
-                  f'Lev. / char: {score_object.leven_per_char:.3f}')
+            self.print_score("Test", score_object)
         return scorer
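
After this commit, _train_epoch, evaluate and test all report through the shared print_score helper, so every split prints the same Loss/FScore/Acc./Prec./Recl. line. For illustration, here is the helper's format fed with a hand-built Score (the numbers are invented):

    from collections import namedtuple

    Score = namedtuple("Score", ["loss", "accuracy", "precision", "recall", "fscore", "scorer"])
    dev = Score(loss=0.412, accuracy=0.973, precision=0.951, recall=0.942, fscore=0.946, scorer=None)

    # Same f-string layout as Trainer.print_score("Dev", dev)
    print(f'\tDev Loss: {dev.loss:.3f} | FScore: {dev.fscore:.3f} | '
          f' Acc.: {dev.accuracy:.3f} | '
          f' Prec.: {dev.precision:.3f} | '
          f' Recl.: {dev.recall:.3f}')
    # ->    Dev Loss: 0.412 | FScore: 0.946 |  Acc.: 0.973 |  Prec.: 0.951 |  Recl.: 0.942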
