
Commit 50bd427

Update splitter unit test and fix issue #227
1 parent 1caed0d commit 50bd427

6 files changed: +105 -97 lines changed


run_tests.sh (+1 -1)

@@ -1,2 +1,2 @@
 #!/bin/sh
-python -m unittest discover tests
+python3 -m unittest discover tests

tests/op/extract/split/test_pattern_splitter_op.py (+8 -2)

@@ -6,9 +6,12 @@

 class TestPatternSplitter(unittest.TestCase):
     def setUp(self):
-        self.splitter = PatternSplitter("test_splitter")
+        self.splitter = PatternSplitter({}, "test_splitter")

     def test_special_function_call(self):
+        """
+        Test special function call.
+        """
         node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})

         output_nodes = self.splitter([node])
@@ -17,7 +20,10 @@ def test_special_function_call(self):
         self.assertEqual(output_nodes[0].value_dict["text"], ["Hello", "World"])

     def test_special_function_call_with_custom_splitter(self):
-        splitter = PatternSplitter("test_splitter", splitter=" ")
+        splitter = PatternSplitter(
+            {"separators": " "},
+            "test_splitter",
+        )
         node = Node(name="node1", value_dict={"text": "Hello World"})

         output_nodes = splitter([node])
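Per the updated tests, PatternSplitter now takes the splitter config dict first and the op name second. A minimal sketch of the new call pattern, assuming only what the tests above and the imports elsewhere in this commit show (the expected lists mirror the assertions above):

    from uniflow.node import Node
    from uniflow.op.extract.split.pattern_splitter_op import PatternSplitter

    # An empty config falls back to the class defaults ("\n\n|\n" separators, min chunk size 1).
    splitter = PatternSplitter({}, "test_splitter")
    node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
    print(splitter([node])[0].value_dict["text"])  # ["Hello", "World"]

    # A custom separator pattern is now passed through the config dict.
    space_splitter = PatternSplitter({"separators": " "}, "test_splitter")
    space_node = Node(name="node1", value_dict={"text": "Hello World"})
    print(space_splitter([space_node])[0].value_dict["text"])  # ["Hello", "World"]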

tests/op/extract/split/test_recursive_character_splitter.py (+48 -51)

@@ -8,101 +8,98 @@

 class TestRecursiveCharacterSplitter(unittest.TestCase):
     def setUp(self):
-        self.splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=10)
-        self.default_separators = ["\n\n", "\n", " ", ""]
+        self.default_separators = ["\n\n", "\n"]

     def test_recursive_splitter(self):
+        splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter")
         text = "Hello\n\nWorld."

-        chunks = self.splitter._recursive_splitter(text, self.default_separators)
+        chunks = splitter._recursive_splitter(text, splitter.default_separators)

         self.assertEqual(chunks, ["Hello", "World."])

-    def test_recursive_splitter_with_merge_chunk(self):
-        splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=100)
-        text = "Hello\n\nWorld"
+    def test_merge_splits(self):
+        splits = ["Hello", "World"]
+        splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter")

-        chunks = splitter._recursive_splitter(text, self.default_separators)
+        merged = splitter._merge_splits(splits, "\n")

-        self.assertEqual(chunks, ["HelloWorld"])
+        self.assertEqual(merged, ["Hello\nWorld"])

-    def test_recursive_splitter_with_small_chunk_size(self):
-        splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=1)
-        text = "Hello\n\nWorld"
-        expected_chunks = ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"]
+    def test_recursive_splitter_with_merge_chunk(self):
+        splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter")
+        node = Node(name="node1", value_dict={"text": "Hello World"})

-        chunks = splitter._recursive_splitter(text, self.default_separators)
+        output_nodes = splitter([node])

-        self.assertEqual(chunks, expected_chunks)
+        self.assertEqual(len(output_nodes), 1)
+        self.assertEqual(output_nodes[0].value_dict["text"], ["Hello\\ World"])

-    def test_recursive_splitter_with_zero_chunk_size(self):
-        splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=0)
-        text = "Hello\n\nWorld"
+    def test_recursive_splitter_with_small_chunk_size(self):
+        splitter = RecursiveCharacterSplitter(
+            {"max_chunk_size": 1, "chunk_overlap_size": 0}, "test_splitter"
+        )
+        node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
         expected_chunks = ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"]

-        chunks = splitter._recursive_splitter(text, self.default_separators)
-
-        self.assertEqual(chunks, expected_chunks)
-
-    def test_recursive_splitter_with_no_separators(self):
-        text = "Hello\n\nWorld"
-        separators = []
+        output_nodes = splitter([node])

-        chunks = self.splitter._recursive_splitter(text, separators)
-
-        self.assertEqual(chunks, [])
+        self.assertEqual(len(output_nodes), 1)
+        self.assertEqual(output_nodes[0].value_dict["text"], expected_chunks)

     def test_recursive_splitter_with_no_split(self):
-        text = "HelloWorld"
+        splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter")
+        node = Node(name="node1", value_dict={"text": "HelloWorld"})

-        chunks = self.splitter._recursive_splitter(text, self.default_separators)
+        output_nodes = splitter([node])

-        self.assertEqual(chunks, ["HelloWorld"])
+        self.assertEqual(len(output_nodes), 1)
+        self.assertEqual(output_nodes[0].value_dict["text"], ["HelloWorld"])

     def test_recursive_splitter_with_custom_separators(self):
-        text = "Hello--World."
-        separators = ["-", " "]
+        splitter = RecursiveCharacterSplitter(
+            {"max_chunk_size": 10, "separators": "--"}, "test_splitter"
+        )
+        node = Node(name="node1", value_dict={"text": "Hello--World"})

-        chunks = splitter._recursive_splitter(text, separators)
+        output_nodes = splitter([node])

-        self.assertEqual(chunks, ["Hello", "World."])
+        self.assertEqual(len(output_nodes), 1)
+        self.assertEqual(output_nodes[0].value_dict["text"], ["Hello", "World"])

     def test_recursive_splitter_with_large_text_default_chunk(self):
-        text = "Hello\n\nWorld\n\n" * 100
+        splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter")
+        node = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 100})

-        chunks = self.splitter._recursive_splitter(text, self.default_separators)
+        output_nodes = splitter([node])

-        self.assertEqual(len(chunks), 100)
+        self.assertEqual(len(output_nodes), 1)
+        self.assertEqual(len(output_nodes[0].value_dict["text"]), 100)

     def test_recursive_splitter_with_large_text_large_chunk(self):
-        splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=9999)
-        text = "Hello\n\nWorld\n\n" * 100
+        splitter = RecursiveCharacterSplitter({"max_chunk_size": 9999}, "test_splitter")
+        node = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 100})

-        chunks = splitter._recursive_splitter(text, self.default_separators)
-
-        self.assertEqual(len(chunks), 1)
-        self.assertEqual(chunks, ["HelloWorld" * 100])
-
-    def test_special_function_call(self):
-        node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
-        output_nodes = self.splitter([node])
+        output_nodes = splitter([node])

         self.assertEqual(len(output_nodes), 1)
-        self.assertEqual(output_nodes[0].value_dict["text"], ["HelloWorld"])
+        self.assertEqual(len(output_nodes[0].value_dict["text"]), 1)

     def test_special_function_call_with_multiple_nodes(self):
+        splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter")
+
         node0 = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
         node1 = Node(name="node1", value_dict={"text": "Hello\n\nWorld."})
         node2 = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 10})
         node3 = Node(name="node1", value_dict={"text": "Hello\n\nWorld.\n\n" * 2})
         expected_texts = [
-            ["HelloWorld"],
+            ["Hello", "World"],
             ["Hello", "World."],
-            ["HelloWorld"] * 10,
+            ["Hello", "World"] * 10,
             ["Hello", "World.", "Hello", "World."],
         ]

-        output_nodes = self.splitter([node0, node1, node2, node3])
+        output_nodes = splitter([node0, node1, node2, node3])
         output_texts = [node.value_dict["text"] for node in output_nodes]

         self.assertEqual(output_texts, expected_texts)
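The rewritten tests above mostly drive the splitter through its public __call__ path rather than the private helpers. A short sketch of that pattern, using only config keys and expected chunks that appear in the tests:

    from uniflow.node import Node
    from uniflow.op.extract.split.recursive_character_splitter import RecursiveCharacterSplitter

    # Config dict first, op name second; "max_chunk_size" is required, the other keys are optional.
    splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter")
    node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})

    output_nodes = splitter([node])            # one output node per input node
    print(output_nodes[0].value_dict["text"])  # ["Hello", "World"], as asserted above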
[splitter factory unit test file] (+12 -34)
@@ -1,44 +1,22 @@
 import unittest

-from uniflow.op.extract.split.constants import (
-    MARKDOWN_HEADER_SPLITTER,
-    PARAGRAPH_SPLITTER,
-    RECURSIVE_CHARACTER_SPLITTER,
-)
-from uniflow.op.extract.split.markdown_header_splitter import MarkdownHeaderSplitter
+from uniflow.op.extract.split.constants import PARAGRAPH_SPLITTER
 from uniflow.op.extract.split.pattern_splitter_op import PatternSplitter
-from uniflow.op.extract.split.recursive_character_splitter import (
-    RecursiveCharacterSplitter,
-)
 from uniflow.op.extract.split.splitter_factory import SplitterOpsFactory


 class TestSplitterOpsFactory(unittest.TestCase):
-    def setUp(self):
-        self.paragraph_splitter = SplitterOpsFactory.get(PARAGRAPH_SPLITTER)
-        self.markdown_header_splitter = SplitterOpsFactory.get(MARKDOWN_HEADER_SPLITTER)
-        self.recursive_character_splitter = SplitterOpsFactory.get(
-            RECURSIVE_CHARACTER_SPLITTER
-        )
+    def test_get_with_valid_config(self):
+        config = {"splitter_func": PARAGRAPH_SPLITTER}
+        splitter = SplitterOpsFactory.get(config)
+        self.assertIsInstance(splitter, PatternSplitter)

-    def test_get(self):
-        self.assertTrue(isinstance(self.paragraph_splitter, PatternSplitter))
-        self.assertTrue(
-            isinstance(self.markdown_header_splitter, MarkdownHeaderSplitter)
-        )
-        self.assertTrue(
-            isinstance(self.recursive_character_splitter, RecursiveCharacterSplitter)
-        )
-
-    def test_get_with_invalid_name(self):
+    def test_get_with_invalid_config(self):
+        config = {"splitter_func": "invalid"}
         with self.assertRaises(ValueError):
-            SplitterOpsFactory.get("")
-
-    def test_list(self):
-        excepted_splitters = [
-            PARAGRAPH_SPLITTER,
-            MARKDOWN_HEADER_SPLITTER,
-            RECURSIVE_CHARACTER_SPLITTER,
-        ]
+            SplitterOpsFactory.get(config)

-        self.assertEqual(SplitterOpsFactory.list(), excepted_splitters)
+    def test_get_with_empty_config(self):
+        config = {}
+        with self.assertRaises(KeyError):
+            SplitterOpsFactory.get(config)
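The factory tests now pass the whole splitter config dict to SplitterOpsFactory.get, which dispatches on its "splitter_func" key. A minimal sketch of the lookup, using only names that appear in the test above:

    from uniflow.op.extract.split.constants import PARAGRAPH_SPLITTER
    from uniflow.op.extract.split.splitter_factory import SplitterOpsFactory

    config = {"splitter_func": PARAGRAPH_SPLITTER}
    splitter = SplitterOpsFactory.get(config)  # returns a PatternSplitter per the test above

    # Per the tests: an unknown "splitter_func" value raises ValueError,
    # and a config without the key raises KeyError.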

uniflow/op/extract/split/pattern_splitter_op.py (+11 -2)

@@ -11,6 +11,9 @@
 class PatternSplitter(Op):
     """Pattern Splitter Op Class"""

+    default_separators = "\n\n|\n"
+    default_min_chunk_size = 1
+
     def __init__(
         self, splitterConfig: dict[str, any], name: str = "paragraph_split_op"
     ) -> None:
@@ -22,6 +25,12 @@ def __init__(
         """
         super().__init__(name)
         self._splitter_config = splitterConfig
+        self._separators = (
+            "separators" in splitterConfig and splitterConfig["separators"]
+        ) or self.default_separators
+        self._min_chunk_size = (
+            "min_chunk_size" in splitterConfig and splitterConfig["min_chunk_size"]
+        ) or self.default_min_chunk_size

     def __call__(
         self,
@@ -39,8 +48,8 @@ def __call__(
         for node in nodes:
             value_dict = copy.deepcopy(node.value_dict)
             text = value_dict["text"]
-            text = re.split(self._splitter_config["separators"], text)
-            text = [p for p in text if len(p) > self._splitter_config["min_chunk_size"]]
+            text = re.split(self._separators, text)
+            text = [p for p in text if len(p) > self._min_chunk_size]
             output_nodes.append(
                 Node(
                     name=self.unique_name(),
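The new constructor resolves separators and min_chunk_size with a presence-and-truthiness check, so a missing key, or a falsy value such as 0 or "", falls back to the class default. A small standalone sketch of that resolution idiom (illustrative only, not the class itself):

    DEFAULT_SEPARATORS = "\n\n|\n"
    DEFAULT_MIN_CHUNK_SIZE = 1

    def resolve(config: dict):
        # Mirrors the added code: key present *and* truthy, otherwise the default.
        separators = ("separators" in config and config["separators"]) or DEFAULT_SEPARATORS
        min_chunk_size = ("min_chunk_size" in config and config["min_chunk_size"]) or DEFAULT_MIN_CHUNK_SIZE
        return separators, min_chunk_size

    print(resolve({}))                     # ('\n\n|\n', 1)
    print(resolve({"separators": " "}))    # (' ', 1)
    print(resolve({"min_chunk_size": 0}))  # ('\n\n|\n', 1), a falsy 0 still falls back to the default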

uniflow/op/extract/split/recursive_character_splitter.py (+25 -7)

@@ -4,7 +4,7 @@
 import re
 from typing import Iterable, List, Sequence

-import tiktoken  # Import necessary for token-based splitting
+import tiktoken

 from uniflow.node import Node
 from uniflow.op.op import Op
@@ -15,7 +15,7 @@ class RecursiveCharacterSplitter(Op):

     default_chunk_size = 1024
     default_chunk_overlap_size = 32
-    default_separators = "\n\n|\n|. |.|, | "
+    default_separators = "\n\n|\n|. |.|, | |"
     default_splitting_mode = "char"

     def __init__(
@@ -34,14 +34,18 @@ def __init__(
             chunk_overlap_size (int): Overlap in characters between chunks.
             separators (List[str]): Separators to use.
             splitting_mode (str): "char" for character count, "token" for token count. Defaults to "char".
+            keep_separator (bool): Whether to keep the separator. Defaults to True.
+            is_separator_regex (bool): Whether the separator is a regex. Defaults to False.
         """
         super().__init__(name)

         # Set up the splitter configuration
         self._chunk_size = splitterConfig["max_chunk_size"] or self.default_chunk_size
         self._separators = (
-            splitterConfig["separators"] or self.default_separators
+            ("separators" in splitterConfig and splitterConfig["separators"])
+            or self.default_separators
         ).split("|")
+        print(f"Separators: {self._separators}")

         # Set up the splitter configuration for recursive splitting
         self._chunk_overlap_size = (
@@ -51,6 +55,16 @@ def __init__(
         self._splitting_mode = (
             "splitting_mode" in splitterConfig and splitterConfig["splitting_mode"]
         ) or self.default_splitting_mode
+        self._keep_separator = (
+            True
+            and ("keep_separator" in splitterConfig)
+            and splitterConfig["keep_separator"]
+        )
+        self._is_separator_regex = (
+            ("is_separator_regex" in splitterConfig)
+            and splitterConfig["is_separator_regex"]
+            or False
+        )

         self._encoder = tiktoken.encoding_for_model(
             "gpt-3.5"
@@ -118,18 +132,21 @@ def _recursive_splitter(self, text: str, separators: List[str]) -> List[str]:
                break

        # Splited by current separator firstly
-        cur_separator = re.escape(cur_separator)
+        cur_separator = (
+            cur_separator if self._is_separator_regex else re.escape(cur_separator)
+        )
         splits = [s for s in re.split(cur_separator, text) if s != ""]

        # Then go merging things, recursively splitting longer texts.
-        _tmp_splits, _separator = [], ""
+        _tmp_splits = []
+        merge_separator = "" if self._keep_separator else _separator
        for s in splits:
            if self._get_length(s) <= self._chunk_size:
                _tmp_splits.append(s)
            else:
                # merge splitted texts into a chunk
                if _tmp_splits:
-                    merged_text = self._merge_splits(_tmp_splits, _separator)
+                    merged_text = self._merge_splits(_tmp_splits, merge_separator)
                    final_chunks.extend(merged_text)
                    # reset tmp_splits
                    _tmp_splits = []
@@ -142,7 +159,7 @@ def _recursive_splitter(self, text: str, separators: List[str]) -> List[str]:
                final_chunks.extend(other_info)

        if _tmp_splits:
-            merged_text = self._merge_splits(_tmp_splits, _separator)
+            merged_text = self._merge_splits(_tmp_splits, merge_separator)
            final_chunks.extend(merged_text)

        return final_chunks
@@ -177,6 +194,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
                doc = separator.join(current_doc).strip()
                if doc is not None:
                    docs.append(doc)
+
                # Keep on popping if:
                # - we have a larger chunk than in the chunk overlap
                # - or if we still have any chunks and the length is long
0 commit comments