Skip to content

Commit

Permalink
Fix duplicate file names.
Browse files: browse the repository at this point in the history
  • Loading branch information
GeorgiosSmyrnis committed Mar 10, 2024
1 parent 71c3369 commit 0f77f45
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
2 changes: 1 addition & 1 deletion open_lm/datapreprocess/ray/tokenize_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _flush_buffer(self, folder, counter):
tokens = [int(x) for x in self.buffer[i]["tokens"]]
token_count += len(tokens)
json_string = json.dumps(tokens)
uid = hashlib.md5(json_string.encode()).hexdigest()
uid = f"{tar_index_str}_{i:0{digits}}"
sample = {"__key__": uid, "json.gz": json_string}
sink.write(sample)
bio.seek(0)
Expand Down
17 changes: 5 additions & 12 deletions tests/test_tokenize_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,20 @@ def test_tokenize_shuffle_with_pretokenized():
)
assert exit_value_1 == 0

os.system("mkdir test_input_2a")
os.system("mkdir test_input_2b")
os.system("cp -r ./test_output/00000001.tar ./test_input_2a/")
os.system("cp -r ./test_output/00000002.tar ./test_input_2b/")
os.system("mkdir test_output_2")
os.system("cp -r ./test_output ./test_input/2a/")
os.system("cp -r ./test_output ./test_input/2b/")

exit_value_2 = os.system(
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input_2a,./test_input_2b --content_key json.gz --seqlen {content_len} --output ./test_output_2 --pretok_tars --suffixes .tar"
f"python open_lm/datapreprocess/ray/tokenize_shuffle.py --input ./test_input/2a,./test_input/2b --content_key json.gz --seqlen {content_len} --output ./test_output/2 --pretok_tars --suffixes .tar"
)
assert exit_value_2 == 0

tars = [os.path.join("test_output_2", fname) for fname in os.listdir("test_output_2") if fname.endswith(".tar")]
tars = [os.path.join("test_output/2", fname) for fname in os.listdir("test_output/2") if fname.endswith(".tar")]
total = 0
for tar in tars:
ds = wds.WebDataset(tar).decode()
for x in ds:
assert len(x["json.gz"]) == content_len + 1
total += len(x["json.gz"])

os.system("rm -rf test_input_2a")
os.system("rm -rf test_input_2b")
os.system("rm -rf test_output_2")

assert total == NUM_TOKENS
assert total == 2 * NUM_TOKENS

0 comments on commit 0f77f45

Please sign in to comment.