-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtesting.py
128 lines (105 loc) · 3.71 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""#%%
from datasets import load_from_disk
from transformers import GPT2Tokenizer
train_data = load_from_disk("data/wmt17_de_en_train")
print(train_data)
item = train_data['translation']
print(item[0])
# %%
item[3]["de"]
# %%
tokenizer = GPT2Tokenizer.from_pretrained("./gpt2_from_bpe")
encoded_source = tokenizer.encode(source)
print(encoded_source)
print(type(encoded_source))
# %%
bos_token_id = tokenizer.convert_tokens_to_ids("[BOS]")
print(type(bos_token_id))
print(type(encoded_source))
encoded_source = [bos_token_id] + encoded_source
# %%
pad_length = 64  # fixed length every encoded sequence is padded/truncated to


def add_padding_or_truncate(tokenized_text, max_length=None, pad_token_id=None):
    '''Return a new list of token ids padded or truncated to a fixed length.

    Args:
        tokenized_text: sequence of token ids (e.g. from ``tokenizer.encode``).
        max_length: target length; defaults to the module-level ``pad_length``.
        pad_token_id: id appended as padding; defaults to the id of "[PAD]" in
            the module-level ``tokenizer`` (looked up only when padding is
            actually needed).

    Returns:
        A new list of exactly ``max_length`` ids. Longer inputs keep only their
        first ``max_length`` ids. ``tokenized_text`` itself is never modified.
    '''
    if max_length is None:
        max_length = pad_length
    if len(tokenized_text) >= max_length:
        return list(tokenized_text[:max_length])
    if pad_token_id is None:
        pad_token_id = tokenizer.convert_tokens_to_ids("[PAD]")
    # Concatenate into a fresh list instead of using += so the caller's list is
    # not extended in place as a side effect.
    return list(tokenized_text) + [pad_token_id] * (max_length - len(tokenized_text))
# Pad/truncate the BOS-prefixed example to pad_length ids.
padded_ids = add_padding_or_truncate(encoded_source)
encoded_source = padded_ids
print(encoded_source)
# %%
import torch

# Convert the id list to an int64 tensor.
encoded_source = torch.tensor(padded_ids, dtype=torch.long)
print(encoded_source)
# %%
sequence_length = len(encoded_source)
print(sequence_length)
# %%
from datasets import load_from_disk
from transformers import GPT2Tokenizer
import torch
from dataset import MyDataset
from torch.utils.data import DataLoader
train_data = load_from_disk("./data/wmt17_de_en_train")
tokenizer = GPT2Tokenizer.from_pretrained("./gpt2_from_bpe")
train_dataset = MyDataset(train_data, tokenizer=tokenizer)
# %%
print(train_dataset[0])
# %%
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
# %%
# get first batch
first_batch = next(iter(train_loader))
print(first_batch)
# %%
import torch
from datasets import load_from_disk
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("./gpt2_from_bpe")
# Load the raw WMT17 de-en splits in train, test, val order.
train_data, test_data, val_data = (
    load_from_disk(path)
    for path in (
        "data/wmt17_de_en_train",
        "data/wmt17_de_en_test",
        "data/wmt17_de_en_val",
    )
)
# Fixed length every encoded sequence is padded/truncated to.
pad_length = 64
def add_padding_or_truncate(tokenized_text, max_length=None, pad_token_id=None):
    '''Return a new list of token ids padded or truncated to a fixed length.

    Args:
        tokenized_text: sequence of token ids (e.g. from ``tokenizer.encode``).
        max_length: target length; defaults to the module-level ``pad_length``.
        pad_token_id: id appended as padding; defaults to the id of "[PAD]" in
            the module-level ``tokenizer`` (looked up only when padding is
            actually needed).

    Returns:
        A new list of exactly ``max_length`` ids. Longer inputs keep only their
        first ``max_length`` ids. ``tokenized_text`` itself is never modified.
    '''
    if max_length is None:
        max_length = pad_length
    if len(tokenized_text) >= max_length:
        return list(tokenized_text[:max_length])
    if pad_token_id is None:
        pad_token_id = tokenizer.convert_tokens_to_ids("[PAD]")
    # Concatenate into a fresh list instead of using += so the caller's list is
    # not extended in place as a side effect.
    return list(tokenized_text) + [pad_token_id] * (max_length - len(tokenized_text))
def preprocess_function(examples):
    '''Turn one translation row into three fixed-length id lists.

    ``examples`` is a single row as passed by ``Dataset.map`` without
    ``batched=True``; its "translation" entry maps "de" to the German source
    sentence and "en" to the English target sentence.

    Returns a dict with keys:
        source        -- encoded German sentence.
        target_input  -- [BOS] followed by the encoded English sentence.
        target_output -- the encoded English sentence followed by [EOS].
    Each list is passed through ``add_padding_or_truncate``.
    '''
    pair = examples["translation"]
    bos_id, eos_id = (
        tokenizer.convert_tokens_to_ids(special) for special in ("[BOS]", "[EOS]")
    )
    source_ids = tokenizer.encode(pair["de"])
    target_ids = tokenizer.encode(pair["en"])
    # NOTE(review): when the English sentence encodes to >= pad_length ids, the
    # truncation below cuts off the trailing [EOS] in target_output; confirm
    # that is intended.
    return {
        "source": add_padding_or_truncate(source_ids),
        "target_input": add_padding_or_truncate([bos_id, *target_ids]),
        "target_output": add_padding_or_truncate([*target_ids, eos_id]),
    }
# Tokenize every split; each row gains "source", "target_input" and
# "target_output" columns next to the raw "translation" column.
train_data = train_data.map(preprocess_function)
test_data = test_data.map(preprocess_function)
val_data = val_data.map(preprocess_function)
# %%
# Drop the raw text column from the splits mapped above and persist them.
# Operate on the in-memory mapped datasets: reloading data/*_dataset.pt here
# would discard the preprocessing just done, and fails outright before those
# files have ever been written.
train_data = train_data.remove_columns(["translation"])
val_data = val_data.remove_columns(["translation"])
# NOTE(review): test_data is tokenized above but not persisted here.
# Save as torch-pickled datasets (read back with torch.load).
torch.save(train_data, "data/train_dataset.pt")
torch.save(val_data, "data/val_dataset.pt")
# %%
import torch

# The file holds a pickled datasets.Dataset rather than plain tensors, so it
# must be loaded with weights_only=False (PyTorch >= 2.6 defaults it to True
# and refuses such objects). This unpickles arbitrary objects: only load files
# you produced yourself.
train_data = torch.load("data/train_dataset.pt", weights_only=False)
print(train_data[0])
#train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)"""
# %%