Skip to content

Commit bfcb306

Browse files
authored
Merge pull request #198 from jli943/main
add test_extract_flow
2 parents 43d67db + 76a956a commit bfcb306

6 files changed

+302
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from uniflow.flow.extract.extract_html_flow import ExtractHTMLFlow
5+
from uniflow.node import Node
6+
from uniflow.op.extract.split.constants import PARAGRAPH_SPLITTER
7+
8+
9+
class TestExtractHTMLFlow(unittest.TestCase):
    """Unit tests for ``ExtractHTMLFlow`` with its collaborating ops mocked."""

    def _start_patch(self, target):
        """Start a patch on ``target`` and stop it when the test finishes.

        Args:
            target: Dotted path of the name to replace with a mock.

        Returns:
            The mock object installed in place of ``target``.
        """
        patcher = patch(target)
        mock_obj = patcher.start()
        # Registered right after start() so the patch is undone even if a
        # later step of setUp raises.
        self.addCleanup(patcher.stop)
        return mock_obj

    def setUp(self):
        """Build the flow under test with every op dependency mocked.

        The patches are started explicitly (instead of decorating ``setUp``)
        so they stay active for the whole test method rather than being
        undone as soon as ``setUp`` returns.
        """
        self.mock_splitter_ops_factory = self._start_patch(
            "uniflow.flow.extract.extract_html_flow.SplitterOpsFactory"
        )
        self.mock_process_html_op = self._start_patch(
            "uniflow.flow.extract.extract_html_flow.ProcessHTMLOp"
        )
        self.mock_extract_html_op = self._start_patch(
            "uniflow.flow.extract.extract_html_flow.ExtractHTMLOp"
        )
        self.extract_html_flow = ExtractHTMLFlow()

    def test_init(self):
        """Constructing the flow creates each op once with its expected name."""
        self.mock_extract_html_op.assert_called_once_with(name="extract_html_op")
        self.mock_process_html_op.assert_called_once_with(name="process_html_op")
        self.mock_splitter_ops_factory.get.assert_called_once_with(PARAGRAPH_SPLITTER)

    def test_run(self):
        """``run`` chains extract -> process -> split and returns the split output."""
        # arrange
        nodes = [
            Node(name="node1", value_dict={"filename": "filepath"}),
            Node(name="node2", value_dict={"filename": "filepath"}),
        ]

        self.mock_splitter_ops_factory.get.return_value.return_value = MagicMock()
        self.mock_process_html_op.return_value.return_value = MagicMock()
        self.mock_extract_html_op.return_value.return_value = MagicMock()

        # act
        result = self.extract_html_flow.run(nodes)

        # assert: each stage receives the previous stage's output
        self.mock_extract_html_op.return_value.assert_called_once_with(nodes)
        self.mock_process_html_op.return_value.assert_called_once_with(
            self.mock_extract_html_op.return_value.return_value
        )
        self.mock_splitter_ops_factory.get.return_value.assert_called_once_with(
            self.mock_process_html_op.return_value.return_value
        )
        self.assertEqual(
            result, self.mock_splitter_ops_factory.get.return_value.return_value
        )
48+
49+
50+
if __name__ == "__main__":
51+
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from uniflow.flow.extract.extract_image_flow import ExtractImageFlow
5+
from uniflow.node import Node
6+
from uniflow.op.extract.split.constants import PARAGRAPH_SPLITTER
7+
8+
9+
class TestExtractImageFlow(unittest.TestCase):
    """Unit tests for ``ExtractImageFlow`` with its model and ops mocked."""

    def _start_patch(self, target):
        """Start a patch on ``target`` and stop it when the test finishes.

        Args:
            target: Dotted path of the name to replace with a mock.

        Returns:
            The mock object installed in place of ``target``.
        """
        patcher = patch(target)
        mock_obj = patcher.start()
        # Registered right after start() so the patch is undone even if a
        # later step of setUp raises.
        self.addCleanup(patcher.stop)
        return mock_obj

    def setUp(self):
        """Build the flow under test with the CV model and all ops mocked.

        The patches are started explicitly (instead of decorating ``setUp``)
        so they stay active for the whole test method rather than being
        undone as soon as ``setUp`` returns.
        """
        self.mock_extract_image_op = self._start_patch(
            "uniflow.flow.extract.extract_image_flow.ExtractImageOp"
        )
        self.mock_process_image_op = self._start_patch(
            "uniflow.flow.extract.extract_image_flow.ProcessImageOp"
        )
        self.mock_splitter_ops_factory = self._start_patch(
            "uniflow.flow.extract.extract_image_flow.SplitterOpsFactory"
        )
        self.mock_cv_model = self._start_patch(
            "uniflow.flow.extract.extract_image_flow.CvModel"
        )
        self.mock_cv_model.return_value = MagicMock()
        self.model_config = {"model_config": "model_config"}
        self.extract_image_flow = ExtractImageFlow(model_config=self.model_config)

    def test_init(self):
        """Constructing the flow builds the model and each op exactly once."""
        self.mock_extract_image_op.assert_called_once_with(
            name="extract_image_op", model=self.mock_cv_model.return_value
        )
        self.mock_cv_model.assert_called_once_with(model_config=self.model_config)
        self.mock_process_image_op.assert_called_once_with(name="process_image_op")
        self.mock_splitter_ops_factory.get.assert_called_once_with(PARAGRAPH_SPLITTER)

    def test_run(self):
        """``run`` chains extract -> process -> split and returns the split output."""
        # arrange
        nodes = [
            Node(name="node1", value_dict={"filename": "filepath"}),
            Node(name="node2", value_dict={"filename": "filepath"}),
        ]

        self.mock_extract_image_op.return_value.return_value = MagicMock()
        self.mock_process_image_op.return_value.return_value = MagicMock()
        self.mock_splitter_ops_factory.get.return_value.return_value = MagicMock()

        # act
        result = self.extract_image_flow.run(nodes)

        # assert: each stage receives the previous stage's output
        self.mock_extract_image_op.return_value.assert_called_once_with(nodes)
        self.mock_process_image_op.return_value.assert_called_once_with(
            self.mock_extract_image_op.return_value.return_value
        )
        self.mock_splitter_ops_factory.get.return_value.assert_called_once_with(
            self.mock_process_image_op.return_value.return_value
        )

        self.assertEqual(
            result, self.mock_splitter_ops_factory.get.return_value.return_value
        )
60+
61+
62+
if __name__ == "__main__":
63+
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from uniflow.flow.extract.extract_ipynb_flow import ExtractIpynbFlow
5+
from uniflow.node import Node
6+
7+
8+
class TestExtractIpynbFlow(unittest.TestCase):
    """Unit tests for ``ExtractIpynbFlow`` with its collaborating ops mocked."""

    def _start_patch(self, target):
        """Start a patch on ``target`` and stop it when the test finishes.

        Args:
            target: Dotted path of the name to replace with a mock.

        Returns:
            The mock object installed in place of ``target``.
        """
        patcher = patch(target)
        mock_obj = patcher.start()
        # Registered right after start() so the patch is undone even if a
        # later step of setUp raises.
        self.addCleanup(patcher.stop)
        return mock_obj

    def setUp(self):
        """Build the flow under test with both op dependencies mocked.

        The patches are started explicitly (instead of decorating ``setUp``)
        so they stay active for the whole test method rather than being
        undone as soon as ``setUp`` returns.
        """
        self.mock_process_ipynb_op = self._start_patch(
            "uniflow.flow.extract.extract_ipynb_flow.ProcessIpynbOp"
        )
        self.mock_extract_ipynb_op = self._start_patch(
            "uniflow.flow.extract.extract_ipynb_flow.ExtractIpynbOp"
        )
        self.extract_ipynb_flow = ExtractIpynbFlow()

    def test_init(self):
        """Constructing the flow creates each op once with its expected name."""
        self.mock_extract_ipynb_op.assert_called_once_with(name="extract_ipynb_op")
        self.mock_process_ipynb_op.assert_called_once_with(name="process_ipynb_op")

    def test_run(self):
        """``run`` chains extract -> process and returns the process output."""
        # arrange
        nodes = [
            Node(name="node1", value_dict={"filename": "filepath"}),
            Node(name="node2", value_dict={"filename": "filepath"}),
        ]

        self.mock_process_ipynb_op.return_value.return_value = MagicMock()
        self.mock_extract_ipynb_op.return_value.return_value = MagicMock()

        # act
        result = self.extract_ipynb_flow.run(nodes)

        # assert: the process op receives the extract op's output
        self.mock_extract_ipynb_op.return_value.assert_called_once_with(nodes)
        self.mock_process_ipynb_op.return_value.assert_called_once_with(
            self.mock_extract_ipynb_op.return_value.return_value
        )
        self.assertEqual(result, self.mock_process_ipynb_op.return_value.return_value)
36+
37+
38+
if __name__ == "__main__":
39+
unittest.main()
+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from uniflow.flow.extract.extract_md_flow import ExtractMarkdownFlow
5+
from uniflow.node import Node
6+
from uniflow.op.extract.split.constants import MARKDOWN_HEADER_SPLITTER
7+
8+
9+
class TestExtractMarkdownFlow(unittest.TestCase):
    """Unit tests for ``ExtractMarkdownFlow`` with its collaborating ops mocked."""

    def _start_patch(self, target):
        """Start a patch on ``target`` and stop it when the test finishes.

        Args:
            target: Dotted path of the name to replace with a mock.

        Returns:
            The mock object installed in place of ``target``.
        """
        patcher = patch(target)
        mock_obj = patcher.start()
        # Registered right after start() so the patch is undone even if a
        # later step of setUp raises.
        self.addCleanup(patcher.stop)
        return mock_obj

    def setUp(self):
        """Build the flow under test with the extract op and splitter mocked.

        The patches are started explicitly (instead of decorating ``setUp``)
        so they stay active for the whole test method rather than being
        undone as soon as ``setUp`` returns.
        """
        self.mock_splitter_ops_factory = self._start_patch(
            "uniflow.flow.extract.extract_md_flow.SplitterOpsFactory"
        )
        # The markdown flow module is patched at its ``ExtractTxtOp`` name.
        self.mock_extract_md_op = self._start_patch(
            "uniflow.flow.extract.extract_md_flow.ExtractTxtOp"
        )
        self.extract_md_flow = ExtractMarkdownFlow()

    def test_init(self):
        """Constructing the flow creates the op and requests the header splitter."""
        self.mock_extract_md_op.assert_called_once_with(name="extract_md_op")
        self.mock_splitter_ops_factory.get.assert_called_once_with(
            MARKDOWN_HEADER_SPLITTER
        )

    def test_run(self):
        """``run`` chains extract -> split and returns the split output."""
        # arrange
        nodes = [
            Node(name="node1", value_dict={"filename": "filepath"}),
            Node(name="node2", value_dict={"filename": "filepath"}),
        ]

        self.mock_splitter_ops_factory.get.return_value.return_value = MagicMock()
        self.mock_extract_md_op.return_value.return_value = MagicMock()

        # act
        result = self.extract_md_flow.run(nodes)

        # assert: the splitter receives the extract op's output
        self.mock_extract_md_op.return_value.assert_called_once_with(nodes)
        self.mock_splitter_ops_factory.get.return_value.assert_called_once_with(
            self.mock_extract_md_op.return_value.return_value
        )
        self.assertEqual(
            result, self.mock_splitter_ops_factory.get.return_value.return_value
        )
41+
42+
43+
if __name__ == "__main__":
44+
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from uniflow.flow.extract.extract_pdf_flow import ExtractPDFFlow
5+
from uniflow.node import Node
6+
from uniflow.op.extract.split.constants import PARAGRAPH_SPLITTER
7+
8+
9+
class TestExtractPDFFlow(unittest.TestCase):
    """Unit tests for ``ExtractPDFFlow`` with its model and ops mocked."""

    def _start_patch(self, target):
        """Start a patch on ``target`` and stop it when the test finishes.

        Args:
            target: Dotted path of the name to replace with a mock.

        Returns:
            The mock object installed in place of ``target``.
        """
        patcher = patch(target)
        mock_obj = patcher.start()
        # Registered right after start() so the patch is undone even if a
        # later step of setUp raises.
        self.addCleanup(patcher.stop)
        return mock_obj

    def setUp(self):
        """Build the flow under test with the CV model and all ops mocked.

        The patches are started explicitly (instead of decorating ``setUp``)
        so they stay active for the whole test method rather than being
        undone as soon as ``setUp`` returns.
        """
        self.mock_extract_pdf_op = self._start_patch(
            "uniflow.flow.extract.extract_pdf_flow.ExtractPDFOp"
        )
        self.mock_process_pdf_op = self._start_patch(
            "uniflow.flow.extract.extract_pdf_flow.ProcessPDFOp"
        )
        self.mock_splitter_ops_factory = self._start_patch(
            "uniflow.flow.extract.extract_pdf_flow.SplitterOpsFactory"
        )
        self.mock_cv_model = self._start_patch(
            "uniflow.flow.extract.extract_pdf_flow.CvModel"
        )
        self.mock_cv_model.return_value = MagicMock()
        self.model_config = {"model_config": "model_config"}
        self.extract_pdf_flow = ExtractPDFFlow(model_config=self.model_config)

    def test_init(self):
        """Constructing the flow builds the model and each op exactly once."""
        self.mock_extract_pdf_op.assert_called_once_with(
            name="extract_pdf_op", model=self.mock_cv_model.return_value
        )
        self.mock_cv_model.assert_called_once_with(model_config=self.model_config)
        self.mock_process_pdf_op.assert_called_once_with(name="process_pdf_op")
        self.mock_splitter_ops_factory.get.assert_called_once_with(PARAGRAPH_SPLITTER)

    def test_run(self):
        """``run`` chains extract -> process -> split and returns the split output."""
        # arrange
        nodes = [
            Node(name="node1", value_dict={"filename": "filepath"}),
            Node(name="node2", value_dict={"filename": "filepath"}),
        ]

        self.mock_extract_pdf_op.return_value.return_value = MagicMock()
        self.mock_process_pdf_op.return_value.return_value = MagicMock()
        self.mock_splitter_ops_factory.get.return_value.return_value = MagicMock()

        # act
        result = self.extract_pdf_flow.run(nodes)

        # assert: each stage receives the previous stage's output
        self.mock_extract_pdf_op.return_value.assert_called_once_with(nodes)
        self.mock_process_pdf_op.return_value.assert_called_once_with(
            self.mock_extract_pdf_op.return_value.return_value
        )
        self.mock_splitter_ops_factory.get.return_value.assert_called_once_with(
            self.mock_process_pdf_op.return_value.return_value
        )

        self.assertEqual(
            result, self.mock_splitter_ops_factory.get.return_value.return_value
        )
60+
61+
62+
if __name__ == "__main__":
63+
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import unittest
2+
from unittest.mock import MagicMock, patch
3+
4+
from uniflow.flow.extract.extract_txt_flow import ExtractTxtFlow
5+
from uniflow.node import Node
6+
from uniflow.op.extract.split.constants import PARAGRAPH_SPLITTER
7+
8+
9+
class TestExtractTxtFlow(unittest.TestCase):
    """Unit tests for ``ExtractTxtFlow`` with its collaborating ops mocked."""

    def _start_patch(self, target):
        """Start a patch on ``target`` and stop it when the test finishes.

        Args:
            target: Dotted path of the name to replace with a mock.

        Returns:
            The mock object installed in place of ``target``.
        """
        patcher = patch(target)
        mock_obj = patcher.start()
        # Registered right after start() so the patch is undone even if a
        # later step of setUp raises.
        self.addCleanup(patcher.stop)
        return mock_obj

    def setUp(self):
        """Build the flow under test with the extract op and splitter mocked.

        The patches are started explicitly (instead of decorating ``setUp``)
        so they stay active for the whole test method rather than being
        undone as soon as ``setUp`` returns.
        """
        self.mock_splitter_ops_factory = self._start_patch(
            "uniflow.flow.extract.extract_txt_flow.SplitterOpsFactory"
        )
        self.mock_extract_txt_op = self._start_patch(
            "uniflow.flow.extract.extract_txt_flow.ExtractTxtOp"
        )
        self.extract_txt_flow = ExtractTxtFlow()

    def test_init(self):
        """Constructing the flow creates the op and requests the paragraph splitter."""
        self.mock_extract_txt_op.assert_called_once_with(name="extract_txt_op")
        self.mock_splitter_ops_factory.get.assert_called_once_with(PARAGRAPH_SPLITTER)

    def test_run(self):
        """``run`` chains extract -> split and returns the split output."""
        # arrange
        nodes = [
            Node(name="node1", value_dict={"filename": "filepath"}),
            Node(name="node2", value_dict={"filename": "filepath"}),
        ]

        self.mock_splitter_ops_factory.get.return_value.return_value = MagicMock()
        self.mock_extract_txt_op.return_value.return_value = MagicMock()

        # act
        result = self.extract_txt_flow.run(nodes)

        # assert: the splitter receives the extract op's output
        self.mock_extract_txt_op.return_value.assert_called_once_with(nodes)
        self.mock_splitter_ops_factory.get.return_value.assert_called_once_with(
            self.mock_extract_txt_op.return_value.return_value
        )
        self.assertEqual(
            result, self.mock_splitter_ops_factory.get.return_value.return_value
        )
39+
40+
41+
if __name__ == "__main__":
42+
unittest.main()

0 commit comments

Comments
 (0)