Skip to content

Commit bcd5aea

Browse files
committed
check test path to file
1 parent 835deb3 commit bcd5aea

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

main.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,13 @@ def main(path_to_config: str):
5252
verbose=config["prepare_data"]["val_data"]["verbose"],
5353
)
5454

55-
test_token_seq, test_label_seq = prepare_conll_data_format(
56-
path=config["prepare_data"]["test_data"]["path"],
57-
sep=config["prepare_data"]["test_data"]["sep"],
58-
lower=config["prepare_data"]["test_data"]["lower"],
59-
verbose=config["prepare_data"]["test_data"]["verbose"],
60-
)
55+
if "test_data" in config["prepare_data"]:
56+
test_token_seq, test_label_seq = prepare_conll_data_format(
57+
path=config["prepare_data"]["test_data"]["path"],
58+
sep=config["prepare_data"]["test_data"]["sep"],
59+
lower=config["prepare_data"]["test_data"]["lower"],
60+
verbose=config["prepare_data"]["test_data"]["verbose"],
61+
)
6162

6263
# token2idx / label2idx
6364

@@ -91,13 +92,14 @@ def main(path_to_config: str):
9192
preprocess=config["dataloader"]["preprocess"],
9293
)
9394

94-
testset = NERDataset(
95-
token_seq=test_token_seq,
96-
label_seq=test_label_seq,
97-
token2idx=token2idx,
98-
label2idx=label2idx,
99-
preprocess=config["dataloader"]["preprocess"],
100-
)
95+
if "test_data" in config["prepare_data"]:
96+
testset = NERDataset(
97+
token_seq=test_token_seq,
98+
label_seq=test_label_seq,
99+
token2idx=token2idx,
100+
label2idx=label2idx,
101+
preprocess=config["dataloader"]["preprocess"],
102+
)
101103

102104
# collators
103105

@@ -113,11 +115,12 @@ def main(path_to_config: str):
113115
percentile=100, # hardcoded
114116
)
115117

116-
test_collator = NERCollator(
117-
token_padding_value=token2idx[config["dataloader"]["token_padding"]],
118-
label_padding_value=label2idx[config["dataloader"]["label_padding"]],
119-
percentile=100, # hardcoded
120-
)
118+
if "test_data" in config["prepare_data"]:
119+
test_collator = NERCollator(
120+
token_padding_value=token2idx[config["dataloader"]["token_padding"]],
121+
label_padding_value=label2idx[config["dataloader"]["label_padding"]],
122+
percentile=100, # hardcoded
123+
)
121124

122125
# dataloaders
123126

@@ -136,12 +139,13 @@ def main(path_to_config: str):
136139
collate_fn=val_collator,
137140
)
138141

139-
testloader = DataLoader(
140-
dataset=testset,
141-
batch_size=1, # hardcoded
142-
shuffle=False, # hardcoded
143-
collate_fn=test_collator,
144-
)
142+
if "test_data" in config["prepare_data"]:
143+
testloader = DataLoader(
144+
dataset=testset,
145+
batch_size=1, # hardcoded
146+
shuffle=False, # hardcoded
147+
collate_fn=test_collator,
148+
)
145149

146150
# INIT MODEL
147151

@@ -208,7 +212,7 @@ def main(path_to_config: str):
208212
model=model,
209213
trainloader=trainloader,
210214
valloader=valloader,
211-
testloader=testloader,
215+
testloader=testloader if "test_data" in config["prepare_data"] else None,
212216
criterion=criterion,
213217
optimizer=optimizer,
214218
device=device,

0 commit comments

Comments
 (0)