|
8 | 8 |
|
9 | 9 | class TestRecursiveCharacterSplitter(unittest.TestCase):
|
10 | 10 | def setUp(self):
|
11 |
| - self.splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=10) |
12 |
| - self.default_separators = ["\n\n", "\n", " ", ""] |
| 11 | + self.default_separators = ["\n\n", "\n"] |
13 | 12 |
|
14 | 13 | def test_recursive_splitter(self):
|
| 14 | + splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter") |
15 | 15 | text = "Hello\n\nWorld."
|
16 | 16 |
|
17 |
| - chunks = self.splitter._recursive_splitter(text, self.default_separators) |
| 17 | + chunks = splitter._recursive_splitter(text, splitter.default_separators) |
18 | 18 |
|
19 | 19 | self.assertEqual(chunks, ["Hello", "World."])
|
20 | 20 |
|
21 |
| - def test_recursive_splitter_with_merge_chunk(self): |
22 |
| - splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=100) |
23 |
| - text = "Hello\n\nWorld" |
| 21 | + def test_merge_splits(self): |
| 22 | + splits = ["Hello", "World"] |
| 23 | + splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter") |
24 | 24 |
|
25 |
| - chunks = splitter._recursive_splitter(text, self.default_separators) |
| 25 | + merged = splitter._merge_splits(splits, "\n") |
26 | 26 |
|
27 |
| - self.assertEqual(chunks, ["HelloWorld"]) |
| 27 | + self.assertEqual(merged, ["Hello\nWorld"]) |
28 | 28 |
|
29 |
| - def test_recursive_splitter_with_small_chunk_size(self): |
30 |
| - splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=1) |
31 |
| - text = "Hello\n\nWorld" |
32 |
| - expected_chunks = ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"] |
| 29 | + def test_recursive_splitter_with_merge_chunk(self): |
| 30 | + splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter") |
| 31 | + node = Node(name="node1", value_dict={"text": "Hello World"}) |
33 | 32 |
|
34 |
| - chunks = splitter._recursive_splitter(text, self.default_separators) |
| 33 | + output_nodes = splitter([node]) |
35 | 34 |
|
36 |
| - self.assertEqual(chunks, expected_chunks) |
| 35 | + self.assertEqual(len(output_nodes), 1) |
| 36 | + self.assertEqual(output_nodes[0].value_dict["text"], ["Hello\\ World"]) |
37 | 37 |
|
38 |
| - def test_recursive_splitter_with_zero_chunk_size(self): |
39 |
| - splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=0) |
40 |
| - text = "Hello\n\nWorld" |
| 38 | + def test_recursive_splitter_with_small_chunk_size(self): |
| 39 | + splitter = RecursiveCharacterSplitter( |
| 40 | + {"max_chunk_size": 1, "chunk_overlap_size": 0}, "test_splitter" |
| 41 | + ) |
| 42 | + node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"}) |
41 | 43 | expected_chunks = ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"]
|
42 | 44 |
|
43 |
| - chunks = splitter._recursive_splitter(text, self.default_separators) |
44 |
| - |
45 |
| - self.assertEqual(chunks, expected_chunks) |
46 |
| - |
47 |
| - def test_recursive_splitter_with_no_separators(self): |
48 |
| - text = "Hello\n\nWorld" |
49 |
| - separators = [] |
| 45 | + output_nodes = splitter([node]) |
50 | 46 |
|
51 |
| - chunks = self.splitter._recursive_splitter(text, separators) |
52 |
| - |
53 |
| - self.assertEqual(chunks, []) |
| 47 | + self.assertEqual(len(output_nodes), 1) |
| 48 | + self.assertEqual(output_nodes[0].value_dict["text"], expected_chunks) |
54 | 49 |
|
55 | 50 | def test_recursive_splitter_with_no_split(self):
|
56 |
| - text = "HelloWorld" |
| 51 | + splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter") |
| 52 | + node = Node(name="node1", value_dict={"text": "HelloWorld"}) |
57 | 53 |
|
58 |
| - chunks = self.splitter._recursive_splitter(text, self.default_separators) |
| 54 | + output_nodes = splitter([node]) |
59 | 55 |
|
60 |
| - self.assertEqual(chunks, ["HelloWorld"]) |
| 56 | + self.assertEqual(len(output_nodes), 1) |
| 57 | + self.assertEqual(output_nodes[0].value_dict["text"], ["HelloWorld"]) |
61 | 58 |
|
62 | 59 | def test_recursive_splitter_with_custom_separators(self):
|
63 |
| - text = "Hello--World." |
64 |
| - separators = ["-", " "] |
| 60 | + splitter = RecursiveCharacterSplitter( |
| 61 | + {"max_chunk_size": 10, "separators": "--"}, "test_splitter" |
| 62 | + ) |
| 63 | + node = Node(name="node1", value_dict={"text": "Hello--World"}) |
65 | 64 |
|
66 |
| - chunks = self.splitter._recursive_splitter(text, separators) |
| 65 | + output_nodes = splitter([node]) |
67 | 66 |
|
68 |
| - self.assertEqual(chunks, ["Hello", "World."]) |
| 67 | + self.assertEqual(len(output_nodes), 1) |
| 68 | + self.assertEqual(output_nodes[0].value_dict["text"], ["Hello", "World"]) |
69 | 69 |
|
70 | 70 | def test_recursive_splitter_with_large_text_default_chunk(self):
|
71 |
| - text = "Hello\n\nWorld\n\n" * 100 |
| 71 | + splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter") |
| 72 | + node = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 100}) |
72 | 73 |
|
73 |
| - chunks = self.splitter._recursive_splitter(text, self.default_separators) |
| 74 | + output_nodes = splitter([node]) |
74 | 75 |
|
75 |
| - self.assertEqual(len(chunks), 100) |
| 76 | + self.assertEqual(len(output_nodes), 1) |
| 77 | + self.assertEqual(len(output_nodes[0].value_dict["text"]), 100) |
76 | 78 |
|
77 | 79 | def test_recursive_splitter_with_large_text_large_chunk(self):
|
78 |
| - splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=9999) |
79 |
| - text = "Hello\n\nWorld\n\n" * 100 |
| 80 | + splitter = RecursiveCharacterSplitter({"max_chunk_size": 9999}, "test_splitter") |
| 81 | + node = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 100}) |
80 | 82 |
|
81 |
| - chunks = splitter._recursive_splitter(text, self.default_separators) |
82 |
| - |
83 |
| - self.assertEqual(len(chunks), 1) |
84 |
| - self.assertEqual(chunks, ["HelloWorld" * 100]) |
85 |
| - |
86 |
| - def test_special_function_call(self): |
87 |
| - node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"}) |
88 |
| - output_nodes = self.splitter([node]) |
| 83 | + output_nodes = splitter([node]) |
89 | 84 |
|
90 | 85 | self.assertEqual(len(output_nodes), 1)
|
91 |
| - self.assertEqual(output_nodes[0].value_dict["text"], ["HelloWorld"]) |
| 86 | + self.assertEqual(len(output_nodes[0].value_dict["text"]), 1) |
92 | 87 |
|
93 | 88 | def test_special_function_call_with_multiple_nodes(self):
|
| 89 | + splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter") |
| 90 | + |
94 | 91 | node0 = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
|
95 | 92 | node1 = Node(name="node1", value_dict={"text": "Hello\n\nWorld."})
|
96 | 93 | node2 = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 10})
|
97 | 94 | node3 = Node(name="node1", value_dict={"text": "Hello\n\nWorld.\n\n" * 2})
|
98 | 95 | expected_texts = [
|
99 |
| - ["HelloWorld"], |
| 96 | + ["Hello", "World"], |
100 | 97 | ["Hello", "World."],
|
101 |
| - ["HelloWorld"] * 10, |
| 98 | + ["Hello", "World"] * 10, |
102 | 99 | ["Hello", "World.", "Hello", "World."],
|
103 | 100 | ]
|
104 | 101 |
|
105 |
| - output_nodes = self.splitter([node0, node1, node2, node3]) |
| 102 | + output_nodes = splitter([node0, node1, node2, node3]) |
106 | 103 | output_texts = [node.value_dict["text"] for node in output_nodes]
|
107 | 104 |
|
108 | 105 | self.assertEqual(output_texts, expected_texts)
|
0 commit comments