@@ -1,51 +1,40 @@
 from flask import Blueprint, request, Flask, Response, stream_with_context
+from werkzeug.exceptions import BadRequest
 from typing import Iterator
-from pie.tagger import Tagger
-from pie.utils import chunks, model_spec

-from .testing import FakeTagger
-from .utils import DataIterator, Tokenizer, Formatter
+from pie_extended.tagger import ExtensibleTagger
+from pie_extended.pipeline.postprocessor.proto import ProcessorPrototype
+from pie_extended.pipeline.iterators.proto import DataIterator
+from pie_extended.pipeline.formatters.proto import Formatter
+
+
+from typing import Callable, Tuple, Type


 class PieController(object):
     def __init__(self,
-                 path: str = "/api", name: str = "nlp_pie", iterator: DataIterator = None, device: str = None,
-                 batch_size: int = None, model_file: str = None, formatter_class: Formatter = None,
-                 headers=None, force_lower=False, disambiguation=None, get_iterator_and_formatter=None):
+                 tagger: ExtensibleTagger,
+                 get_iterator_and_processor: Callable[[], Tuple[DataIterator, ProcessorPrototype]],
+                 path: str = "/api",
+                 name: str = "nlp_pie",
+                 batch_size: int = None,
+                 formatter_class: Type[Formatter] = Formatter,
+                 headers=None, force_lower=False):

         self._bp: Blueprint = Blueprint(name, import_name=name, url_prefix=path)
-        self.tokenizer: Tokenizer = None
         self.force_lower = force_lower
-        self.formatter_class = formatter_class or Formatter
+        self.tagger = tagger
+        self.get_iterator_and_processor = get_iterator_and_processor
         self.batch_size = batch_size
-        self.model_file = model_file
+        self.tagger.batch_size = batch_size or 8
+        self.formatter = formatter_class
         self.headers = {
             'Content-Type': 'text/plain; charset=utf-8',
             'Access-Control-Allow-Origin': "*"
         }
         if isinstance(headers, dict):
             self.headers.update(headers)

-        if isinstance(model_file, FakeTagger):
-            self.tagger = model_file
-        else:
-            self.tagger = Tagger(device=device, batch_size=batch_size)
-
-            for model, tasks in model_spec(model_file):
-                self.tagger.add_model(model, *tasks)
-
-        self.iterator = iterator
-        if not iterator:
-            self.iterator = DataIterator()
-
-        self._get_iterator_and_formatter = get_iterator_and_formatter
-        self.disambiguation = disambiguation
-
-    def get_iterator_and_formatter(self):
-        if self._get_iterator_and_formatter:
-            return self._get_iterator_and_formatter()
-        return self.iterator, self.formatter_class
-
     def init_app(self, app: Flask):
         self._bp.add_url_rule("/", view_func=self.route, endpoint="main", methods=["GET", "POST", "OPTIONS"])
         app.register_blueprint(self._bp)
@@ -74,115 +63,16 @@ def csv_stream(self) -> Iterator[str]:
         else:
             data = request.form.get("data")

+        if lower:
+            data = data.lower()
+
         if not data:
-            yield ""
+            raise BadRequest()
         else:
-            iter_fn, formatter = self.get_iterator_and_formatter()
-            yield from self.build_response(
-                data,
-                lower=lower,
+            iter_fn, proc = self.get_iterator_and_processor()
+            yield from self.tagger.iter_tag(
+                data=data,
+                formatter_class=self.formatter,
                 iterator=iter_fn,
-                batch_size=self.batch_size,
-                tagger=self.tagger,
-                formatter_class=formatter
-            )
-
-    def reinsert_full(self, formatter, sent_reinsertion, tasks):
-        yield formatter.write_sentence_beginning()
-        # If a sentence is empty, it's most likely because everything is in sent_reinsertions
-        for reinsertion in sorted(list(sent_reinsertion.keys())):
-            yield formatter.write_line(
-                formatter.format_line(
-                    token=sent_reinsertion[reinsertion],
-                    tags=[""] * len(tasks)
-                )
+                processor=proc
             )
-        yield formatter.write_sentence_end()
-
-    def build_response(self, data, iterator, lower, batch_size, tagger, formatter_class):
-        header = False
-        formatter = None
-        for chunk in chunks(iterator(data, lower=lower), size=batch_size):
-            # Unzip the batch into the sentences, their sizes and the dictionaries of things that needs
-            # to be reinserted
-            sents, lengths, needs_reinsertion = zip(*chunk)
-            # Removing punctuation might create empty sentences !
-            # Which would crash Torch
-            empty_sents_indexes = {
-                index: []
-                for index, sent in enumerate(sents)
-                if len(sent) == 0
-            }
-            tagged, tasks = tagger.tag(sents=[sent for sent in sents if len(sent)], lengths=lengths)
-            formatter = formatter_class(tasks)
-
-            # We keep a real sentence index
-            real_sentence_index = 0
-            for sent in tagged:
-                if not sent:
-                    continue
-                # Gets things that needs to be reinserted
-                sent_reinsertion = needs_reinsertion[real_sentence_index]
-
-                # If the header has not yet be written, write it
-                if not header:
-                    yield formatter.write_headers()
-                    header = True
-
-                # Some sentences can be empty and would have been removed from tagging
-                # we check and until we get to a non empty sentence
-                # we increment the real_sentence_index to keep in check with the reinsertion map
-                while real_sentence_index in empty_sents_indexes:
-                    yield from self.reinsert_full(
-                        formatter,
-                        needs_reinsertion[real_sentence_index],
-                        tasks
-                    )
-                    real_sentence_index += 1
-
-                yield formatter.write_sentence_beginning()
-
-                # If we have a disambiguator, we run the results into it
-                if self.disambiguation:
-                    sent = self.disambiguation(sent, tasks)
-
-                reinsertion_index = 0
-                index = 0
-
-                for index, (token, tags) in enumerate(sent):
-                    while reinsertion_index + index in sent_reinsertion:
-                        yield formatter.write_line(
-                            formatter.format_line(
-                                token=sent_reinsertion[reinsertion_index + index],
-                                tags=[""] * len(tasks)
-                            )
-                        )
-                        del sent_reinsertion[reinsertion_index + index]
-                        reinsertion_index += 1
-
-                    yield formatter.write_line(
-                        formatter.format_line(token, tags)
-                    )
-
-                for reinsertion in sorted(list(sent_reinsertion.keys())):
-                    yield formatter.write_line(
-                        formatter.format_line(
-                            token=sent_reinsertion[reinsertion],
-                            tags=[""] * len(tasks)
-                        )
-                    )
-
-                yield formatter.write_sentence_end()
-
-                real_sentence_index += 1
-
-            while real_sentence_index in empty_sents_indexes:
-                yield from self.reinsert_full(
-                    formatter,
-                    needs_reinsertion[real_sentence_index],
-                    tasks
-                )
-                real_sentence_index += 1
-
-        if formatter:
-            yield formatter.write_footer()
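For context, a minimal wiring sketch of the new constructor (not part of this commit). It assumes PieController is importable from the package's top level, that ExtensibleTagger can be built with its default arguments and have a model attached separately, and that the caller supplies a get_iterator_and_processor-style factory for the chosen model; those import paths and constructor arguments are assumptions, not this repository's documented API.

from flask import Flask
from pie_extended.tagger import ExtensibleTagger

from flask_pie import PieController  # assumed import path for the controller defined above


def get_iterator_and_processor():
    # Hypothetical factory: return the (DataIterator, ProcessorPrototype) pair
    # matching the model loaded into the tagger; pie_extended's language modules
    # typically provide such a helper.
    raise NotImplementedError


app = Flask(__name__)
tagger = ExtensibleTagger()  # assumption: default constructor; model loading omitted here
controller = PieController(
    tagger=tagger,
    get_iterator_and_processor=get_iterator_and_processor,
    path="/api",
    name="nlp_pie",
    batch_size=16,
)
controller.init_app(app)  # registers the blueprint and its "/" route on the Flask app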