@@ -52,12 +52,13 @@ def main(path_to_config: str):
52
52
verbose = config ["prepare_data" ]["val_data" ]["verbose" ],
53
53
)
54
54
55
- test_token_seq , test_label_seq = prepare_conll_data_format (
56
- path = config ["prepare_data" ]["test_data" ]["path" ],
57
- sep = config ["prepare_data" ]["test_data" ]["sep" ],
58
- lower = config ["prepare_data" ]["test_data" ]["lower" ],
59
- verbose = config ["prepare_data" ]["test_data" ]["verbose" ],
60
- )
55
+ if "test_data" in config ["prepare_data" ]:
56
+ test_token_seq , test_label_seq = prepare_conll_data_format (
57
+ path = config ["prepare_data" ]["test_data" ]["path" ],
58
+ sep = config ["prepare_data" ]["test_data" ]["sep" ],
59
+ lower = config ["prepare_data" ]["test_data" ]["lower" ],
60
+ verbose = config ["prepare_data" ]["test_data" ]["verbose" ],
61
+ )
61
62
62
63
# token2idx / label2idx
63
64
@@ -91,13 +92,14 @@ def main(path_to_config: str):
91
92
preprocess = config ["dataloader" ]["preprocess" ],
92
93
)
93
94
94
- testset = NERDataset (
95
- token_seq = test_token_seq ,
96
- label_seq = test_label_seq ,
97
- token2idx = token2idx ,
98
- label2idx = label2idx ,
99
- preprocess = config ["dataloader" ]["preprocess" ],
100
- )
95
+ if "test_data" in config ["prepare_data" ]:
96
+ testset = NERDataset (
97
+ token_seq = test_token_seq ,
98
+ label_seq = test_label_seq ,
99
+ token2idx = token2idx ,
100
+ label2idx = label2idx ,
101
+ preprocess = config ["dataloader" ]["preprocess" ],
102
+ )
101
103
102
104
# collators
103
105
@@ -113,11 +115,12 @@ def main(path_to_config: str):
113
115
percentile = 100 , # hardcoded
114
116
)
115
117
116
- test_collator = NERCollator (
117
- token_padding_value = token2idx [config ["dataloader" ]["token_padding" ]],
118
- label_padding_value = label2idx [config ["dataloader" ]["label_padding" ]],
119
- percentile = 100 , # hardcoded
120
- )
118
+ if "test_data" in config ["prepare_data" ]:
119
+ test_collator = NERCollator (
120
+ token_padding_value = token2idx [config ["dataloader" ]["token_padding" ]],
121
+ label_padding_value = label2idx [config ["dataloader" ]["label_padding" ]],
122
+ percentile = 100 , # hardcoded
123
+ )
121
124
122
125
# dataloaders
123
126
@@ -136,12 +139,13 @@ def main(path_to_config: str):
136
139
collate_fn = val_collator ,
137
140
)
138
141
139
- testloader = DataLoader (
140
- dataset = testset ,
141
- batch_size = 1 , # hardcoded
142
- shuffle = False , # hardcoded
143
- collate_fn = test_collator ,
144
- )
142
+ if "test_data" in config ["prepare_data" ]:
143
+ testloader = DataLoader (
144
+ dataset = testset ,
145
+ batch_size = 1 , # hardcoded
146
+ shuffle = False , # hardcoded
147
+ collate_fn = test_collator ,
148
+ )
145
149
146
150
# INIT MODEL
147
151
@@ -208,7 +212,7 @@ def main(path_to_config: str):
208
212
model = model ,
209
213
trainloader = trainloader ,
210
214
valloader = valloader ,
211
- testloader = testloader ,
215
+ testloader = testloader if "test_data" in config [ "prepare_data" ] else None ,
212
216
criterion = criterion ,
213
217
optimizer = optimizer ,
214
218
device = device ,
0 commit comments