Skip to content

Commit bdee4be

Browse files
author
Cambio ML
authored
Merge pull request #233 from SayaZhang/main
Customized splitter config
2 parents (03b8f65 + a0977ba), commit bdee4be

14 files changed: +222 -172 lines changed

example/extract/extract_pdf_with_recursive_splitter.ipynb

+10 -13
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@
7070
"import pandas as pd\n",
7171
"import pprint\n",
7272
"from uniflow.flow.client import ExtractClient, TransformClient\n",
73-
"from uniflow.flow.config import TransformOpenAIConfig, ExtractPDFConfig\n",
73+
"from uniflow.flow.config import TransformOpenAIConfig, ExtractPDFConfig, SplitterConfig\n",
7474
"from uniflow.op.model.model_config import OpenAIModelConfig, NougatModelConfig\n",
7575
"from uniflow.op.prompt import PromptTemplate, Context\n",
7676
"from uniflow.op.extract.split.splitter_factory import SplitterOpsFactory\n",
@@ -136,26 +136,23 @@
136136
"cell_type": "code",
137137
"execution_count": 5,
138138
"metadata": {},
139-
"outputs": [
140-
{
141-
"name": "stderr",
142-
"output_type": "stream",
143-
"text": [
144-
"/home/ubuntu/anaconda3/envs/uniflow/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)\n",
145-
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n"
146-
]
147-
}
148-
],
139+
"outputs": [],
149140
"source": [
150141
"data = [\n",
151142
" {\"filename\": input_file},\n",
152143
"]\n",
153144
"\n",
145+
"splitter_config = SplitterConfig(\n",
146+
" max_chunk_size = 1024,\n",
147+
" splitter_func = RECURSIVE_CHARACTER_SPLITTER\n",
148+
" )\n",
149+
"splitter_config.chunk_overlap_size = 5\n",
150+
"\n",
154151
"config = ExtractPDFConfig(\n",
155152
" model_config=NougatModelConfig(\n",
156153
" batch_size = 1 # When batch_size>1, nougat will run on CUDA, otherwise it will run on CPU\n",
157154
" ),\n",
158-
" splitter=RECURSIVE_CHARACTER_SPLITTER,\n",
155+
" splitter_config=splitter_config,\n",
159156
")\n",
160157
"nougat_client = ExtractClient(config)"
161158
]
@@ -176,7 +173,7 @@
176173
"name": "stderr",
177174
"output_type": "stream",
178175
"text": [
179-
"100%|██████████| 1/1 [00:05<00:00, 5.07s/it]\n"
176+
"100%|██████████| 1/1 [00:03<00:00, 3.23s/it]\n"
180177
]
181178
}
182179
],

example/extract/extract_txt.ipynb

+37 -22
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,8 @@
5555
" 'ExtractTxtFlow'],\n",
5656
" 'transform': ['TransformAzureOpenAIFlow',\n",
5757
" 'TransformCopyFlow',\n",
58+
" 'TransformGoogleFlow',\n",
59+
" 'TransformGoogleMultiModalModelFlow',\n",
5860
" 'TransformHuggingFaceFlow',\n",
5961
" 'TransformLMQGFlow',\n",
6062
" 'TransformOpenAIFlow'],\n",
@@ -68,9 +70,14 @@
6870
],
6971
"source": [
7072
"from uniflow.flow.client import ExtractClient\n",
71-
"from uniflow.flow.config import ExtractTxtConfig\n",
73+
"from uniflow.flow.config import ExtractTxtConfig, SplitterConfig\n",
7274
"from uniflow.viz import Viz\n",
7375
"from uniflow.flow.flow_factory import FlowFactory\n",
76+
"from uniflow.op.extract.split.constants import (\n",
77+
" MARKDOWN_HEADER_SPLITTER,\n",
78+
" PARAGRAPH_SPLITTER,\n",
79+
" RECURSIVE_CHARACTER_SPLITTER,\n",
80+
")\n",
7481
"\n",
7582
"FlowFactory.list()"
7683
]
@@ -104,7 +111,15 @@
104111
"metadata": {},
105112
"outputs": [],
106113
"source": [
107-
"client = ExtractClient(ExtractTxtConfig())"
114+
"client = ExtractClient(\n",
115+
" ExtractTxtConfig(\n",
116+
" splitter_config=SplitterConfig(\n",
117+
" min_chunk_size = 5, \n",
118+
" separators = \"\\n\\n|\\n\", \n",
119+
" splitter_func = PARAGRAPH_SPLITTER\n",
120+
" )\n",
121+
" )\n",
122+
")"
108123
]
109124
},
110125
{
@@ -116,7 +131,7 @@
116131
"name": "stderr",
117132
"output_type": "stream",
118133
"text": [
119-
"100%|██████████| 1/1 [00:00<00:00, 13066.37it/s]\n"
134+
"100%|██████████| 1/1 [00:00<00:00, 14217.98it/s]\n"
120135
]
121136
}
122137
],
@@ -158,7 +173,7 @@
158173
" 'benefit to humanity. In all of these, the rich get richer.',\n",
159174
" \"You can't understand the world without understanding the concept of \"\n",
160175
" \"superlinear returns. And if you're ambitious you definitely should, because \"\n",
161-
" 'this will be the wave you surf on.\\n']\n"
176+
" 'this will be the wave you surf on.']\n"
162177
]
163178
}
164179
],
@@ -189,46 +204,46 @@
189204
"<!-- Generated by graphviz version 2.43.0 (0)\n",
190205
" -->\n",
191206
"<!-- Title: %3 Pages: 1 -->\n",
192-
"<svg width=\"271pt\" height=\"188pt\"\n",
193-
" viewBox=\"0.00 0.00 270.58 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
207+
"<svg width=\"315pt\" height=\"188pt\"\n",
208+
" viewBox=\"0.00 0.00 314.77 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
194209
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
195210
"<title>%3</title>\n",
196-
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 266.58,-184 266.58,4 -4,4\"/>\n",
211+
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 310.77,-184 310.77,4 -4,4\"/>\n",
197212
"<!-- root -->\n",
198213
"<g id=\"node1\" class=\"node\">\n",
199214
"<title>root</title>\n",
200-
"<ellipse fill=\"none\" stroke=\"black\" cx=\"131.29\" cy=\"-162\" rx=\"29.8\" ry=\"18\"/>\n",
201-
"<text text-anchor=\"middle\" x=\"131.29\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">root</text>\n",
215+
"<ellipse fill=\"none\" stroke=\"black\" cx=\"153.39\" cy=\"-162\" rx=\"29.8\" ry=\"18\"/>\n",
216+
"<text text-anchor=\"middle\" x=\"153.39\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">root</text>\n",
202217
"</g>\n",
203218
"<!-- thread_0/extract_txt_op_1 -->\n",
204219
"<g id=\"node2\" class=\"node\">\n",
205220
"<title>thread_0/extract_txt_op_1</title>\n",
206-
"<ellipse fill=\"none\" stroke=\"black\" cx=\"131.29\" cy=\"-90\" rx=\"131.08\" ry=\"18\"/>\n",
207-
"<text text-anchor=\"middle\" x=\"131.29\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">thread_0/extract_txt_op_1</text>\n",
221+
"<ellipse fill=\"none\" stroke=\"black\" cx=\"153.39\" cy=\"-90\" rx=\"131.08\" ry=\"18\"/>\n",
222+
"<text text-anchor=\"middle\" x=\"153.39\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">thread_0/extract_txt_op_1</text>\n",
208223
"</g>\n",
209224
"<!-- root&#45;&gt;thread_0/extract_txt_op_1 -->\n",
210225
"<g id=\"edge1\" class=\"edge\">\n",
211226
"<title>root&#45;&gt;thread_0/extract_txt_op_1</title>\n",
212-
"<path fill=\"none\" stroke=\"black\" d=\"M131.29,-143.7C131.29,-135.98 131.29,-126.71 131.29,-118.11\"/>\n",
213-
"<polygon fill=\"black\" stroke=\"black\" points=\"134.79,-118.1 131.29,-108.1 127.79,-118.1 134.79,-118.1\"/>\n",
227+
"<path fill=\"none\" stroke=\"black\" d=\"M153.39,-143.7C153.39,-135.98 153.39,-126.71 153.39,-118.11\"/>\n",
228+
"<polygon fill=\"black\" stroke=\"black\" points=\"156.89,-118.1 153.39,-108.1 149.89,-118.1 156.89,-118.1\"/>\n",
214229
"</g>\n",
215-
"<!-- paragraph_split_op_1 -->\n",
230+
"<!-- thread_0/paragraph_split_op_1 -->\n",
216231
"<g id=\"node3\" class=\"node\">\n",
217-
"<title>paragraph_split_op_1</title>\n",
218-
"<ellipse fill=\"none\" stroke=\"black\" cx=\"131.29\" cy=\"-18\" rx=\"109.68\" ry=\"18\"/>\n",
219-
"<text text-anchor=\"middle\" x=\"131.29\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">paragraph_split_op_1</text>\n",
232+
"<title>thread_0/paragraph_split_op_1</title>\n",
233+
"<ellipse fill=\"none\" stroke=\"black\" cx=\"153.39\" cy=\"-18\" rx=\"153.27\" ry=\"18\"/>\n",
234+
"<text text-anchor=\"middle\" x=\"153.39\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">thread_0/paragraph_split_op_1</text>\n",
220235
"</g>\n",
221-
"<!-- thread_0/extract_txt_op_1&#45;&gt;paragraph_split_op_1 -->\n",
236+
"<!-- thread_0/extract_txt_op_1&#45;&gt;thread_0/paragraph_split_op_1 -->\n",
222237
"<g id=\"edge2\" class=\"edge\">\n",
223-
"<title>thread_0/extract_txt_op_1&#45;&gt;paragraph_split_op_1</title>\n",
224-
"<path fill=\"none\" stroke=\"black\" d=\"M131.29,-71.7C131.29,-63.98 131.29,-54.71 131.29,-46.11\"/>\n",
225-
"<polygon fill=\"black\" stroke=\"black\" points=\"134.79,-46.1 131.29,-36.1 127.79,-46.1 134.79,-46.1\"/>\n",
238+
"<title>thread_0/extract_txt_op_1&#45;&gt;thread_0/paragraph_split_op_1</title>\n",
239+
"<path fill=\"none\" stroke=\"black\" d=\"M153.39,-71.7C153.39,-63.98 153.39,-54.71 153.39,-46.11\"/>\n",
240+
"<polygon fill=\"black\" stroke=\"black\" points=\"156.89,-46.1 153.39,-36.1 149.89,-46.1 156.89,-46.1\"/>\n",
226241
"</g>\n",
227242
"</g>\n",
228243
"</svg>\n"
229244
],
230245
"text/plain": [
231-
"<graphviz.graphs.Digraph at 0x7fa3a0bbb340>"
246+
"<graphviz.graphs.Digraph at 0x7f09a99de1d0>"
232247
]
233248
},
234249
"metadata": {},

run_tests.sh

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1,2 +1,2 @@
11
#!/bin/sh
2-
python -m unittest discover tests
2+
python3 -m unittest discover tests

tests/op/extract/split/test_pattern_splitter_op.py

+8 -2
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,12 @@
66

77
class TestPatternSplitter(unittest.TestCase):
88
def setUp(self):
9-
self.splitter = PatternSplitter("test_splitter")
9+
self.splitter = PatternSplitter({}, "test_splitter")
1010

1111
def test_special_function_call(self):
12+
"""
13+
Test special function call.
14+
"""
1215
node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
1316

1417
output_nodes = self.splitter([node])
@@ -17,7 +20,10 @@ def test_special_function_call(self):
1720
self.assertEqual(output_nodes[0].value_dict["text"], ["Hello", "World"])
1821

1922
def test_special_function_call_with_custom_splitter(self):
20-
splitter = PatternSplitter("test_splitter", splitter=" ")
23+
splitter = PatternSplitter(
24+
{"separators": " "},
25+
"test_splitter",
26+
)
2127
node = Node(name="node1", value_dict={"text": "Hello World"})
2228

2329
output_nodes = splitter([node])

tests/op/extract/split/test_recursive_character_splitter.py

+48 -51
Original file line number | Diff line number | Diff line change
@@ -8,101 +8,98 @@
88

99
class TestRecursiveCharacterSplitter(unittest.TestCase):
1010
def setUp(self):
11-
self.splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=10)
12-
self.default_separators = ["\n\n", "\n", " ", ""]
11+
self.default_separators = ["\n\n", "\n"]
1312

1413
def test_recursive_splitter(self):
14+
splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter")
1515
text = "Hello\n\nWorld."
1616

17-
chunks = self.splitter._recursive_splitter(text, self.default_separators)
17+
chunks = splitter._recursive_splitter(text, splitter.default_separators)
1818

1919
self.assertEqual(chunks, ["Hello", "World."])
2020

21-
def test_recursive_splitter_with_merge_chunk(self):
22-
splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=100)
23-
text = "Hello\n\nWorld"
21+
def test_merge_splits(self):
22+
splits = ["Hello", "World"]
23+
splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter")
2424

25-
chunks = splitter._recursive_splitter(text, self.default_separators)
25+
merged = splitter._merge_splits(splits, "\n")
2626

27-
self.assertEqual(chunks, ["HelloWorld"])
27+
self.assertEqual(merged, ["Hello\nWorld"])
2828

29-
def test_recursive_splitter_with_small_chunk_size(self):
30-
splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=1)
31-
text = "Hello\n\nWorld"
32-
expected_chunks = ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"]
29+
def test_recursive_splitter_with_merge_chunk(self):
30+
splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter")
31+
node = Node(name="node1", value_dict={"text": "Hello World"})
3332

34-
chunks = splitter._recursive_splitter(text, self.default_separators)
33+
output_nodes = splitter([node])
3534

36-
self.assertEqual(chunks, expected_chunks)
35+
self.assertEqual(len(output_nodes), 1)
36+
self.assertEqual(output_nodes[0].value_dict["text"], ["Hello\\ World"])
3737

38-
def test_recursive_splitter_with_zero_chunk_size(self):
39-
splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=0)
40-
text = "Hello\n\nWorld"
38+
def test_recursive_splitter_with_small_chunk_size(self):
39+
splitter = RecursiveCharacterSplitter(
40+
{"max_chunk_size": 1, "chunk_overlap_size": 0}, "test_splitter"
41+
)
42+
node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
4143
expected_chunks = ["H", "e", "l", "l", "o", "W", "o", "r", "l", "d"]
4244

43-
chunks = splitter._recursive_splitter(text, self.default_separators)
44-
45-
self.assertEqual(chunks, expected_chunks)
46-
47-
def test_recursive_splitter_with_no_separators(self):
48-
text = "Hello\n\nWorld"
49-
separators = []
45+
output_nodes = splitter([node])
5046

51-
chunks = self.splitter._recursive_splitter(text, separators)
52-
53-
self.assertEqual(chunks, [])
47+
self.assertEqual(len(output_nodes), 1)
48+
self.assertEqual(output_nodes[0].value_dict["text"], expected_chunks)
5449

5550
def test_recursive_splitter_with_no_split(self):
56-
text = "HelloWorld"
51+
splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter")
52+
node = Node(name="node1", value_dict={"text": "HelloWorld"})
5753

58-
chunks = self.splitter._recursive_splitter(text, self.default_separators)
54+
output_nodes = splitter([node])
5955

60-
self.assertEqual(chunks, ["HelloWorld"])
56+
self.assertEqual(len(output_nodes), 1)
57+
self.assertEqual(output_nodes[0].value_dict["text"], ["HelloWorld"])
6158

6259
def test_recursive_splitter_with_custom_separators(self):
63-
text = "Hello--World."
64-
separators = ["-", " "]
60+
splitter = RecursiveCharacterSplitter(
61+
{"max_chunk_size": 10, "separators": "--"}, "test_splitter"
62+
)
63+
node = Node(name="node1", value_dict={"text": "Hello--World"})
6564

66-
chunks = self.splitter._recursive_splitter(text, separators)
65+
output_nodes = splitter([node])
6766

68-
self.assertEqual(chunks, ["Hello", "World."])
67+
self.assertEqual(len(output_nodes), 1)
68+
self.assertEqual(output_nodes[0].value_dict["text"], ["Hello", "World"])
6969

7070
def test_recursive_splitter_with_large_text_default_chunk(self):
71-
text = "Hello\n\nWorld\n\n" * 100
71+
splitter = RecursiveCharacterSplitter({"max_chunk_size": 20}, "test_splitter")
72+
node = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 100})
7273

73-
chunks = self.splitter._recursive_splitter(text, self.default_separators)
74+
output_nodes = splitter([node])
7475

75-
self.assertEqual(len(chunks), 100)
76+
self.assertEqual(len(output_nodes), 1)
77+
self.assertEqual(len(output_nodes[0].value_dict["text"]), 100)
7678

7779
def test_recursive_splitter_with_large_text_large_chunk(self):
78-
splitter = RecursiveCharacterSplitter("test_splitter", chunk_size=9999)
79-
text = "Hello\n\nWorld\n\n" * 100
80+
splitter = RecursiveCharacterSplitter({"max_chunk_size": 9999}, "test_splitter")
81+
node = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 100})
8082

81-
chunks = splitter._recursive_splitter(text, self.default_separators)
82-
83-
self.assertEqual(len(chunks), 1)
84-
self.assertEqual(chunks, ["HelloWorld" * 100])
85-
86-
def test_special_function_call(self):
87-
node = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
88-
output_nodes = self.splitter([node])
83+
output_nodes = splitter([node])
8984

9085
self.assertEqual(len(output_nodes), 1)
91-
self.assertEqual(output_nodes[0].value_dict["text"], ["HelloWorld"])
86+
self.assertEqual(len(output_nodes[0].value_dict["text"]), 1)
9287

9388
def test_special_function_call_with_multiple_nodes(self):
89+
splitter = RecursiveCharacterSplitter({"max_chunk_size": 10}, "test_splitter")
90+
9491
node0 = Node(name="node1", value_dict={"text": "Hello\n\nWorld"})
9592
node1 = Node(name="node1", value_dict={"text": "Hello\n\nWorld."})
9693
node2 = Node(name="node1", value_dict={"text": "Hello\n\nWorld\n\n" * 10})
9794
node3 = Node(name="node1", value_dict={"text": "Hello\n\nWorld.\n\n" * 2})
9895
expected_texts = [
99-
["HelloWorld"],
96+
["Hello", "World"],
10097
["Hello", "World."],
101-
["HelloWorld"] * 10,
98+
["Hello", "World"] * 10,
10299
["Hello", "World.", "Hello", "World."],
103100
]
104101

105-
output_nodes = self.splitter([node0, node1, node2, node3])
102+
output_nodes = splitter([node0, node1, node2, node3])
106103
output_texts = [node.value_dict["text"] for node in output_nodes]
107104

108105
self.assertEqual(output_texts, expected_texts)

0 commit comments

Comments
 (0)