diff --git a/src/README.md b/src/README.md
index e88c2f784f..505dbfe2e9 100644
--- a/src/README.md
+++ b/src/README.md
@@ -196,6 +196,92 @@ int main(int argc, char* argv[]) {
 }
 ```
+
+A streamer can also be implemented in Python by deriving from `openvino_genai.StreamerBase`. The example below collects the decoded text chunks in a queue so they can be consumed through a blocking iterator:
+
+```py
+import queue
+import re
+import threading
+
+import openvino_genai as ov_genai
+
+class TextStreamerIterator(ov_genai.StreamerBase):
+    def __init__(self, tokenizer, stop_tokens=None):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.text_queue = queue.Queue()
+        self.stop_signal = None
+        self.tokens_cache = []
+        self.print_len = 0
+        self.stop_tokens = stop_tokens if stop_tokens else []
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        # Block until the next text chunk is available.
+        value = self.text_queue.get()
+        if value is self.stop_signal:
+            raise StopIteration()
+        return value
+
+    def put_word(self, word: str):
+        # Remove labels/special symbols.
+        word = re.sub("<.*?>", "", word)
+        self.text_queue.put(word)
+
+    def put(self, token_id: int) -> bool:
+        self.tokens_cache.append(token_id)
+        text = self.tokenizer.decode(self.tokens_cache)
+
+        if text and text[-1] == '\n' and len(text) > self.print_len:
+            # Flush the cache after the new line symbol.
+            self.put_word(text[self.print_len:])
+            self.tokens_cache = []
+            self.print_len = 0
+        elif text and text[-1] == chr(65533):
+            # The replacement character means the last token produced an
+            # incomplete byte sequence, don't print it yet.
+            pass
+        elif len(text) > self.print_len:
+            # It is possible to have a shorter text after adding a new token.
+            # Print to output only if text length is increased.
+            self.put_word(text[self.print_len:])
+            self.print_len = len(text)
+
+        if token_id in self.stop_tokens:
+            # When generation is stopped from the streamer, end() is not
+            # called, so it has to be called here manually.
+            self.end()
+            return True  # True means stop generation.
+        return False  # False means continue generation.
+
+    def end(self):
+        # Flush residual tokens from the buffer.
+        text = self.tokenizer.decode(self.tokens_cache)
+        if len(text) > self.print_len:
+            self.put_word(text[self.print_len:])
+        self.tokens_cache = []
+        self.print_len = 0
+        self.text_queue.put(self.stop_signal)
+
+pipe = ov_genai.LLMPipeline(model_path, "CPU")
+tokenizer = ov_genai.Tokenizer(model_path)
+text_streamer = TextStreamerIterator(tokenizer)
+
+# Iterating over the streamer blocks until the next chunk is available
+# and stops once end() queues the stop signal.
+def print_chunks():
+    for word in text_streamer:
+        print(word, end='', flush=True)
+
+printer_thread = threading.Thread(target=print_chunks, daemon=True)
+printer_thread.start()
+
+pipe.generate('The Sun is yellow because', streamer=text_streamer, max_new_tokens=100)
+printer_thread.join()
+```
+
 ### Performance Metrics
 
 `openvino_genai.PerfMetrics` (referred as `PerfMetrics` for simplicity) is a structure that holds performance metrics for each generate call. `PerfMetrics` holds fields with mean and standard deviations for the following metrics: