diff --git a/packages/plugin-node/src/services/llama.ts b/packages/plugin-node/src/services/llama.ts index 3f2d62183b0..3bfbaafd91a 100644 --- a/packages/plugin-node/src/services/llama.ts +++ b/packages/plugin-node/src/services/llama.ts @@ -11,6 +11,8 @@ import { GbnfJsonSchema, getLlama, Llama, + LlamaChatSession, + LlamaChatSessionRepeatPenalty, LlamaContext, LlamaContextSequence, LlamaContextSequenceRepeatPenalty, @@ -549,49 +551,28 @@ export class LlamaService extends Service { throw new Error("Model not initialized."); } - const tokens = this.model!.tokenize(context); + const session = new LlamaChatSession({ + contextSequence: this.sequence + }); - // tokenize the words to punish const wordsToPunishTokens = wordsToPunish .map((word) => this.model!.tokenize(word)) .flat(); - const repeatPenalty: LlamaContextSequenceRepeatPenalty = { - punishTokens: () => wordsToPunishTokens, + const repeatPenalty: LlamaChatSessionRepeatPenalty = { + punishTokensFilter: () => wordsToPunishTokens, penalty: 1.2, frequencyPenalty: frequency_penalty, presencePenalty: presence_penalty, }; - const responseTokens: Token[] = []; - - for await (const token of this.sequence.evaluate(tokens, { + const response = await session.prompt(context, { + onTextChunk(chunk) { // stream the response to the console as it's being generated + process.stdout.write(chunk); + }, temperature: Number(temperature), - repeatPenalty: repeatPenalty, - grammarEvaluationState: useGrammar ? this.grammar : undefined, - yieldEogToken: false, - })) { - const current = this.model.detokenize([...responseTokens, token]); - if ([...stop].some((s) => current.includes(s))) { - elizaLogger.info("Stop sequence found"); - break; - } - - responseTokens.push(token); - process.stdout.write(this.model!.detokenize([token])); - if (useGrammar) { - if (current.replaceAll("\n", "").includes("}```")) { - elizaLogger.info("JSON block found"); - break; - } - } - if (responseTokens.length > max_tokens) { - elizaLogger.info("Max tokens reached"); - break; - } - } - - const response = this.model!.detokenize(responseTokens); + repeatPenalty: repeatPenalty + }); if (!response) { throw new Error("Response is undefined");