Skip to content

Commit 1caed0d

Browse files
committed
Add splitter config for recursive splitter
1 parent 3895073 commit 1caed0d

File tree

4 files changed

+36
-30
lines changed

4 files changed

+36
-30
lines changed

example/extract/extract_pdf_with_recursive_splitter.ipynb

+10-13
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
"import pandas as pd\n",
7171
"import pprint\n",
7272
"from uniflow.flow.client import ExtractClient, TransformClient\n",
73-
"from uniflow.flow.config import TransformOpenAIConfig, ExtractPDFConfig\n",
73+
"from uniflow.flow.config import TransformOpenAIConfig, ExtractPDFConfig, SplitterConfig\n",
7474
"from uniflow.op.model.model_config import OpenAIModelConfig, NougatModelConfig\n",
7575
"from uniflow.op.prompt import PromptTemplate, Context\n",
7676
"from uniflow.op.extract.split.splitter_factory import SplitterOpsFactory\n",
@@ -136,26 +136,23 @@
136136
"cell_type": "code",
137137
"execution_count": 5,
138138
"metadata": {},
139-
"outputs": [
140-
{
141-
"name": "stderr",
142-
"output_type": "stream",
143-
"text": [
144-
"/home/ubuntu/anaconda3/envs/uniflow/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3526.)\n",
145-
" return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n"
146-
]
147-
}
148-
],
139+
"outputs": [],
149140
"source": [
150141
"data = [\n",
151142
" {\"filename\": input_file},\n",
152143
"]\n",
153144
"\n",
145+
"splitter_config = SplitterConfig(\n",
146+
" max_chunk_size = 1024,\n",
147+
" splitter_func = RECURSIVE_CHARACTER_SPLITTER\n",
148+
" )\n",
149+
"splitter_config.chunk_overlap_size = 5\n",
150+
"\n",
154151
"config = ExtractPDFConfig(\n",
155152
" model_config=NougatModelConfig(\n",
156153
" batch_size = 1 # When batch_size>1, nougat will run on CUDA, otherwise it will run on CPU\n",
157154
" ),\n",
158-
" splitter=RECURSIVE_CHARACTER_SPLITTER,\n",
155+
" splitter_config=splitter_config,\n",
159156
")\n",
160157
"nougat_client = ExtractClient(config)"
161158
]
@@ -176,7 +173,7 @@
176173
"name": "stderr",
177174
"output_type": "stream",
178175
"text": [
179-
"100%|██████████| 1/1 [00:05<00:00, 5.07s/it]\n"
176+
"100%|██████████| 1/1 [00:03<00:00, 3.23s/it]\n"
180177
]
181178
}
182179
],

uniflow/flow/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class ExtractPDFConfig(ExtractConfig):
7272

7373
flow_name: str = "ExtractPDFFlow"
7474
model_config: ModelConfig = field(default_factory=NougatModelConfig)
75-
splitter: str = PARAGRAPH_SPLITTER
75+
splitter_config: SplitterConfig = field(default_factory=SplitterConfig)
7676

7777

7878
@dataclass

uniflow/flow/extract/extract_pdf_flow.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from uniflow.flow.flow import Flow
77
from uniflow.node import Node
88
from uniflow.op.extract.load.pdf_op import ExtractPDFOp, ProcessPDFOp
9-
from uniflow.op.extract.split.constants import PARAGRAPH_SPLITTER
109
from uniflow.op.extract.split.splitter_factory import SplitterOpsFactory
1110
from uniflow.op.model.cv.model import CvModel
1211

@@ -17,9 +16,7 @@ class ExtractPDFFlow(Flow):
1716
TAG = EXTRACT
1817

1918
def __init__(
20-
self,
21-
model_config: Dict[str, Any],
22-
splitter: str = PARAGRAPH_SPLITTER,
19+
self, model_config: Dict[str, Any], splitter_config: Dict[str, Any]
2320
) -> None:
2421
"""Extract PDF Flow Constructor.
2522
@@ -35,7 +32,7 @@ def __init__(
3532
),
3633
)
3734
self._process_pdf_op = ProcessPDFOp(name="process_pdf_op")
38-
self._split_op = SplitterOpsFactory.get(splitter)
35+
self._split_op = SplitterOpsFactory.get(splitter_config)
3936

4037
def run(self, nodes: Sequence[Node]) -> Sequence[Node]:
4138
"""Run Model Flow.

uniflow/op/extract/split/recursive_character_splitter.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import copy
44
import re
5-
from typing import Iterable, List, Optional, Sequence
5+
from typing import Iterable, List, Sequence
66

77
import tiktoken # Import necessary for token-based splitting
88

@@ -13,15 +13,15 @@
1313
class RecursiveCharacterSplitter(Op):
1414
"""Recursive character splitter class."""
1515

16-
default_separators = ["\n\n", "\n", ". ", " ", ""]
16+
default_chunk_size = 1024
17+
default_chunk_overlap_size = 32
18+
default_separators = "\n\n|\n|. |.|, | "
19+
default_splitting_mode = "char"
1720

1821
def __init__(
1922
self,
20-
name: str,
21-
chunk_size: int = 1024,
22-
chunk_overlap_size: int = 0,
23-
separators: Optional[List[str]] = None,
24-
splitting_mode: str = "char", # Added parameter for splitting mode
23+
splitterConfig: dict[str, any],
24+
name: str = "recursive_character_splitter_op",
2525
) -> None:
2626
"""Recursive Splitter Op Constructor
2727
@@ -36,10 +36,22 @@ def __init__(
3636
splitting_mode (str): "char" for character count, "token" for token count. Defaults to "char".
3737
"""
3838
super().__init__(name)
39-
self._chunk_size = chunk_size
40-
self._chunk_overlap_size = chunk_overlap_size
41-
self._separators = separators or self.default_separators
42-
self._splitting_mode = splitting_mode # Track splitting mode
39+
40+
# Set up the splitter configuration
41+
self._chunk_size = splitterConfig["max_chunk_size"] or self.default_chunk_size
42+
self._separators = (
43+
splitterConfig["separators"] or self.default_separators
44+
).split("|")
45+
46+
# Set up the splitter configuration for recursive splitting
47+
self._chunk_overlap_size = (
48+
"chunk_overlap_size" in splitterConfig
49+
and splitterConfig["chunk_overlap_size"]
50+
) or self.default_chunk_overlap_size
51+
self._splitting_mode = (
52+
"splitting_mode" in splitterConfig and splitterConfig["splitting_mode"]
53+
) or self.default_splitting_mode
54+
4355
self._encoder = tiktoken.encoding_for_model(
4456
"gpt-3.5"
4557
) # Setup encoder for token-based splitting

0 commit comments

Comments (0)