Skip to content

Commit

Permalink
Merge pull request #61 from mideind/update-domain-translation-task
Browse files
Fixing the domain translation model to use our BartTranslation task w…
  • Loading branch information
peturorri authored Sep 18, 2023
2 parents bc0701b + 26c2299 commit 6b69267
Showing 1 changed file with 23 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from fairseq.data import BaseWrapperDataset, Dictionary, LanguagePairDataset, data_utils
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
from fairseq.tasks.translation_from_pretrained_bart import TranslationFromPretrainedBARTTask

from .translation_from_pretrained_bart import TranslationFromPretrainedBARTTask

_EES_DEFAULT_DOMAIN = "ees_ótiltekið"

Expand Down Expand Up @@ -95,10 +96,11 @@ def add_args(parser):
'line has some domain from the domain_dict.txt')
# fmt: on

def __init__(self, args, src_dict, tgt_dict):
super().__init__(args, src_dict, tgt_dict)
self.langs = args.langs.split(",")
self.load_domains(args)
def __init__(self, cfg, src_dict, tgt_dict):
super().__init__(cfg, src_dict, tgt_dict)
self.args = cfg
self.langs = cfg.langs.split(",")
self.load_domains(cfg)
for dict_ in [src_dict, tgt_dict]:
for lang in self.langs:
dict_.add_symbol("[{}]".format(lang))
Expand All @@ -109,28 +111,28 @@ def __init__(self, args, src_dict, tgt_dict):
symbol = self.domain_dict.symbols[idx]
dict_.add_symbol(f"<{symbol}>")

def load_domains(self, args):
self.domain_dict = self.load_dictionary(args.domain_dict)
def load_domains(self, cfg):
self.domain_dict = self.load_dictionary(cfg.domain_dict)

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = utils.split_paths(self.args.data)
paths = utils.split_paths(self.cfg.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]

# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
src, tgt = self.cfg.source_lang, self.cfg.target_lang

if "train" in split:
domain_path = self.args.train_domains
domain_path = self.cfg.train_domains
elif "valid" in split:
domain_path = self.args.valid_domains
domain_path = self.cfg.valid_domains
elif "test" in split:
domain_path = self.args.test_domains
domain_path = self.cfg.test_domains

else:
raise ValueError(f"Unknown split {split}")
Expand All @@ -145,20 +147,20 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
tgt,
self.tgt_dict,
combine=combine,
dataset_impl=self.args.dataset_impl,
upsample_primary=self.args.upsample_primary,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=getattr(self.args, "max_source_positions", 512),
max_target_positions=getattr(self.args, "max_target_positions", 512),
load_alignments=self.args.load_alignments,
prepend_bos=getattr(self.args, "prepend_bos", False),
dataset_impl=self.cfg.dataset_impl,
upsample_primary=self.cfg.upsample_primary,
left_pad_source=self.cfg.left_pad_source,
left_pad_target=self.cfg.left_pad_target,
max_source_positions=getattr(self.cfg, "max_source_positions", 512),
max_target_positions=getattr(self.cfg, "max_target_positions", 512),
load_alignments=self.cfg.load_alignments,
prepend_bos=getattr(self.cfg, "prepend_bos", False),
append_source_id=True,
)
self.datasets[split] = DomainPrefixingDataset(
langpair_dataset=self.datasets[split],
domain_per_example=domain_per_example,
src_dict=self.src_dict,
domain_dict=self.domain_dict,
seed=self.args.seed,
seed=self.cfg.seed,
)

0 comments on commit 6b69267

Please sign in to comment.