
Commit cac3912

add note to context for local generation (#2604)
1 parent 4c8a60a commit cac3912

1 file changed, 47 insertions(+), 41 deletions(-)

packages/plugin-node/src/services/llama.ts (+47 −41)
@@ -189,7 +189,7 @@ export class LlamaService extends Service {
         const modelName = "model.gguf";
         this.modelPath = path.join(
             process.env.LLAMALOCAL_PATH?.trim() ?? "./",
-            modelName
+            modelName,
         );
         this.ollamaModel = process.env.OLLAMA_MODEL;
     }
@@ -202,7 +202,7 @@ export class LlamaService extends Service {
     private async ensureInitialized() {
         if (!this.modelInitialized) {
             elizaLogger.info(
-                "Model not initialized, starting initialization..."
+                "Model not initialized, starting initialization...",
             );
             await this.initializeModel();
         } else {
@@ -217,16 +217,16 @@ export class LlamaService extends Service {
 
             const systemInfo = await si.graphics();
             const hasCUDA = systemInfo.controllers.some((controller) =>
-                controller.vendor.toLowerCase().includes("nvidia")
+                controller.vendor.toLowerCase().includes("nvidia"),
             );
 
             if (hasCUDA) {
                 elizaLogger.info(
-                    "LlamaService: CUDA detected, using GPU acceleration"
+                    "LlamaService: CUDA detected, using GPU acceleration",
                 );
             } else {
                 elizaLogger.warn(
-                    "LlamaService: No CUDA detected - local response will be slow"
+                    "LlamaService: No CUDA detected - local response will be slow",
                 );
             }
 
@@ -238,7 +238,7 @@ export class LlamaService extends Service {
             elizaLogger.info("Creating JSON schema grammar...");
             const grammar = new LlamaJsonSchemaGrammar(
                 this.llama,
-                jsonSchemaGrammar as GbnfJsonSchema
+                jsonSchemaGrammar as GbnfJsonSchema,
             );
             this.grammar = grammar;
 
@@ -257,21 +257,21 @@ export class LlamaService extends Service {
         } catch (error) {
             elizaLogger.error(
                 "Model initialization failed. Deleting model and retrying:",
-                error
+                error,
             );
             try {
                 elizaLogger.info(
-                    "Attempting to delete and re-download model..."
+                    "Attempting to delete and re-download model...",
                 );
                 await this.deleteModel();
                 await this.initializeModel();
             } catch (retryError) {
                 elizaLogger.error(
                     "Model re-initialization failed:",
-                    retryError
+                    retryError,
                 );
                 throw new Error(
-                    `Model initialization failed after retry: ${retryError.message}`
+                    `Model initialization failed after retry: ${retryError.message}`,
                 );
             }
         }
@@ -294,7 +294,7 @@ export class LlamaService extends Service {
                         response.headers.location
                     ) {
                         elizaLogger.info(
-                            `Following redirect to: ${response.headers.location}`
+                            `Following redirect to: ${response.headers.location}`,
                         );
                         downloadModel(response.headers.location);
                         return;
@@ -303,24 +303,24 @@ export class LlamaService extends Service {
                     if (response.statusCode !== 200) {
                         reject(
                             new Error(
-                                `Failed to download model: HTTP ${response.statusCode}`
-                            )
+                                `Failed to download model: HTTP ${response.statusCode}`,
+                            ),
                         );
                         return;
                     }
 
                     totalSize = Number.parseInt(
                         response.headers["content-length"] || "0",
-                        10
+                        10,
                     );
                     elizaLogger.info(
-                        `Downloading model: Hermes-3-Llama-3.1-8B.Q8_0.gguf`
+                        `Downloading model: Hermes-3-Llama-3.1-8B.Q8_0.gguf`,
                     );
                     elizaLogger.info(
-                        `Download location: ${this.modelPath}`
+                        `Download location: ${this.modelPath}`,
                     );
                     elizaLogger.info(
-                        `Total size: ${(totalSize / 1024 / 1024).toFixed(2)} MB`
+                        `Total size: ${(totalSize / 1024 / 1024).toFixed(2)} MB`,
                     );
 
                     response.pipe(file);
@@ -336,7 +336,7 @@ export class LlamaService extends Service {
                               ).toFixed(1)
                             : "0.0";
                         const dots = ".".repeat(
-                            Math.floor(Number(progress) / 5)
+                            Math.floor(Number(progress) / 5),
                         );
                         progressString = `Downloading model: [${dots.padEnd(20, " ")}] ${progress}%`;
                         elizaLogger.progress(progressString);
@@ -353,17 +353,17 @@ export class LlamaService extends Service {
                         fs.unlink(this.modelPath, () => {});
                         reject(
                             new Error(
-                                `Model download failed: ${error.message}`
-                            )
+                                `Model download failed: ${error.message}`,
+                            ),
                         );
                     });
                 })
                 .on("error", (error) => {
                     fs.unlink(this.modelPath, () => {});
                     reject(
                         new Error(
-                            `Model download request failed: ${error.message}`
-                        )
+                            `Model download request failed: ${error.message}`,
+                        ),
                     );
                 });
             };
@@ -393,7 +393,7 @@ export class LlamaService extends Service {
         stop: string[],
         frequency_penalty: number,
         presence_penalty: number,
-        max_tokens: number
+        max_tokens: number,
     ): Promise<any> {
         await this.ensureInitialized();
         return new Promise((resolve, reject) => {
@@ -418,7 +418,7 @@ export class LlamaService extends Service {
         stop: string[],
         frequency_penalty: number,
         presence_penalty: number,
-        max_tokens: number
+        max_tokens: number,
     ): Promise<string> {
         await this.ensureInitialized();
 
@@ -460,7 +460,7 @@ export class LlamaService extends Service {
                         message.frequency_penalty,
                         message.presence_penalty,
                         message.max_tokens,
-                        message.useGrammar
+                        message.useGrammar,
                     );
                     message.resolve(response);
                 } catch (error) {
@@ -509,14 +509,17 @@ export class LlamaService extends Service {
         frequency_penalty: number,
         presence_penalty: number,
         max_tokens: number,
-        useGrammar: boolean
+        useGrammar: boolean,
     ): Promise<any | string> {
+        context = context +=
+            "\nIMPORTANT: Escape any quotes in any string fields with a backslash so the JSON is valid.";
+
         const ollamaModel = process.env.OLLAMA_MODEL;
         if (ollamaModel) {
             const ollamaUrl =
                 process.env.OLLAMA_SERVER_URL || "http://localhost:11434";
             elizaLogger.info(
-                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`
+                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`,
             );
 
             const response = await fetch(`${ollamaUrl}/api/generate`, {
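
Note: the hunk above contains the commit's only behavioral change, the "IMPORTANT: Escape any quotes..." line appended to context before local generation so the model keeps its JSON output parseable. A minimal TypeScript sketch of that step follows; the note text is copied from the diff, while the helper name appendJsonEscapeNote is hypothetical and used only for illustration.

// Sketch (assumed helper name): what the added lines do to the prompt context.
const JSON_ESCAPE_NOTE =
    "\nIMPORTANT: Escape any quotes in any string fields with a backslash so the JSON is valid.";

function appendJsonEscapeNote(context: string): string {
    // The diff writes `context = context +=`, which assigns twice but is
    // equivalent to a single `context += JSON_ESCAPE_NOTE`.
    return context + JSON_ESCAPE_NOTE;
}

// Example: the note travels with the context into node-llama-cpp or Ollama.
const prompt = appendJsonEscapeNote('Reply with a JSON object like {"text": "..."}');
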
@@ -538,7 +541,7 @@ export class LlamaService extends Service {
 
             if (!response.ok) {
                 throw new Error(
-                    `Ollama request failed: ${response.statusText}`
+                    `Ollama request failed: ${response.statusText}`,
                 );
             }
 
@@ -552,11 +555,12 @@ export class LlamaService extends Service {
         }
 
         const session = new LlamaChatSession({
-            contextSequence: this.sequence
+            contextSequence: this.sequence,
         });
 
-        const wordsToPunishTokens = wordsToPunish
-            .flatMap((word) => this.model!.tokenize(word));
+        const wordsToPunishTokens = wordsToPunish.flatMap((word) =>
+            this.model!.tokenize(word),
+        );
 
         const repeatPenalty: LlamaChatSessionRepeatPenalty = {
             punishTokensFilter: () => wordsToPunishTokens,
@@ -566,11 +570,12 @@ export class LlamaService extends Service {
         };
 
         const response = await session.prompt(context, {
-            onTextChunk(chunk) { // stream the response to the console as it's being generated
+            onTextChunk(chunk) {
+                // stream the response to the console as it's being generated
                 process.stdout.write(chunk);
             },
             temperature: Number(temperature),
-            repeatPenalty: repeatPenalty
+            repeatPenalty: repeatPenalty,
         });
 
         if (!response) {
@@ -612,7 +617,7 @@ export class LlamaService extends Service {
             const embeddingModel =
                 process.env.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large";
             elizaLogger.info(
-                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`
+                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`,
             );
 
             const response = await fetch(`${ollamaUrl}/api/embeddings`, {
@@ -626,7 +631,7 @@ export class LlamaService extends Service {
 
             if (!response.ok) {
                 throw new Error(
-                    `Ollama embeddings request failed: ${response.statusText}`
+                    `Ollama embeddings request failed: ${response.statusText}`,
                 );
             }
 
@@ -644,7 +649,7 @@ export class LlamaService extends Service {
             const embeddingModel =
                 process.env.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large";
             elizaLogger.info(
-                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${this.ollamaModel})`
+                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${this.ollamaModel})`,
             );
 
             const response = await fetch(`${ollamaUrl}/api/embeddings`, {
@@ -671,7 +676,7 @@ export class LlamaService extends Service {
             const ollamaUrl =
                 process.env.OLLAMA_SERVER_URL || "http://localhost:11434";
             elizaLogger.info(
-                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`
+                `Using Ollama API at ${ollamaUrl} with model ${ollamaModel}`,
             );
 
             const response = await fetch(`${ollamaUrl}/api/generate`, {
@@ -706,7 +711,7 @@ export class LlamaService extends Service {
             const embeddingModel =
                 process.env.OLLAMA_EMBEDDING_MODEL || "mxbai-embed-large";
             elizaLogger.info(
-                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`
+                `Using Ollama API for embeddings with model ${embeddingModel} (base: ${ollamaModel})`,
            );
 
             const response = await fetch(`${ollamaUrl}/api/embeddings`, {
@@ -720,7 +725,7 @@ export class LlamaService extends Service {
 
             if (!response.ok) {
                 throw new Error(
-                    `Ollama embeddings request failed: ${response.statusText}`
+                    `Ollama embeddings request failed: ${response.statusText}`,
                 );
             }
 
@@ -736,8 +741,9 @@ export class LlamaService extends Service {
         const tokens = this.model!.tokenize(prompt);
 
         // tokenize the words to punish
-        const wordsToPunishTokens = wordsToPunish
-            .flatMap((word) => this.model!.tokenize(word));
+        const wordsToPunishTokens = wordsToPunish.flatMap((word) =>
+            this.model!.tokenize(word),
+        );
 
         const repeatPenalty: LlamaContextSequenceRepeatPenalty = {
             punishTokens: () => wordsToPunishTokens,
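
The remaining hunks are formatting-only: trailing commas are added to multi-line argument lists and the two wordsToPunish token maps are re-wrapped without changing behavior. A small sketch of that unchanged behavior, with a stand-in tokenize in place of this.model!.tokenize and sample words rather than the real list:

// Sketch: flatMap over the punished words yields one flat token array,
// which is what repeatPenalty.punishTokensFilter / punishTokens consume.
type Token = number;
const tokenize = (word: string): Token[] =>
    Array.from(word, (c) => c.charCodeAt(0)); // stand-in for this.model!.tokenize

const wordsToPunish = [" please", " feel", " free"]; // illustrative entries only
const wordsToPunishTokens = wordsToPunish.flatMap((word) => tokenize(word));
// Both the old two-line form and the re-wrapped form in the diff produce this same array.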
