Skip to content

Commit 3062cc8

Browse files
authored
Merge pull request #262 from dorianjanezic/main
cachedEmbeddings fix
2 parents d273fdd + 854365f commit 3062cc8

File tree

4 files changed

+707
-923
lines changed

4 files changed

+707
-923
lines changed

packages/adapter-sqlite/src/index.ts

+20-25
Original file line numberDiff line numberDiff line change
@@ -336,34 +336,29 @@ export class SqliteDatabaseAdapter extends DatabaseAdapter {
336336
query_field_name: string;
337337
query_field_sub_name: string;
338338
query_match_count: number;
339-
}): Promise<
340-
{
341-
embedding: number[];
342-
levenshtein_score: number;
343-
}[]
344-
> {
339+
}): Promise<{ embedding: number[]; levenshtein_score: number }[]> {
345340
const sql = `
346-
SELECT *
347-
FROM memories
348-
WHERE type = ?
349-
AND vec_distance_L2(${opts.query_field_name}, ?) <= ?
350-
ORDER BY vec_distance_L2(${opts.query_field_name}, ?) ASC
351-
LIMIT ?
352-
`;
353-
console.log("sql", sql)
354-
console.log("opts.query_input", opts.query_input)
355-
const memories = this.db.prepare(sql).all(
341+
SELECT
342+
embedding,
343+
0 as levenshtein_score -- Using 0 as placeholder score
344+
FROM memories
345+
WHERE type = ?
346+
AND json_extract(content, '$.' || ? || '.' || ?) IS NOT NULL
347+
LIMIT ?
348+
`;
349+
350+
const params = [
356351
opts.query_table_name,
357-
new Float32Array(opts.query_input.split(",").map(Number)), // Convert string to Float32Array
358-
opts.query_input,
359-
new Float32Array(opts.query_input.split(",").map(Number))
360-
) as Memory[];
352+
opts.query_field_name,
353+
opts.query_field_sub_name,
354+
opts.query_match_count
355+
];
361356

362-
return memories.map((memory) => ({
363-
embedding: Array.from(
364-
new Float32Array(memory.embedding as unknown as Buffer)
365-
), // Convert Buffer to number[]
366-
levenshtein_score: 0,
357+
const rows = this.db.prepare(sql).all(...params);
358+
359+
return rows.map((row) => ({
360+
embedding: row.embedding,
361+
levenshtein_score: 0
367362
}));
368363
}
369364

packages/core/src/defaultCharacter.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ export const defaultCharacter: Character = {
44
name: "Eliza",
55
plugins: [],
66
clients: [],
7-
modelProvider: ModelProviderName.LLAMALOCAL,
7+
modelProvider: ModelProviderName.OPENAI,
88
settings: {
9-
secrets: {},
9+
secrets: {
10+
},
1011
voice: {
1112
model: "en_US-hfc_female-medium",
1213
},

packages/core/src/embedding.ts

+26-13
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import { fileURLToPath } from "url";
44
import models from "./models.ts";
55
import {
66
IAgentRuntime,
7-
ModelProviderName
7+
ModelProviderName,
8+
ModelClass
89
} from "./types.ts";
910
import fs from "fs";
1011
import { trimTokens } from "./generation.ts";
@@ -18,7 +19,7 @@ function getRootPath() {
1819
if (rootPath.includes("/eliza/")) {
1920
return rootPath.split("/eliza/")[0] + "/eliza/";
2021
}
21-
22+
2223
return path.resolve(__dirname, "..");
2324
}
2425

@@ -32,13 +33,13 @@ interface EmbeddingOptions {
3233

3334
async function getRemoteEmbedding(input: string, options: EmbeddingOptions): Promise<number[]> {
3435
// Ensure endpoint ends with /v1 for OpenAI
35-
const baseEndpoint = options.endpoint.endsWith('/v1') ?
36-
options.endpoint :
36+
const baseEndpoint = options.endpoint.endsWith('/v1') ?
37+
options.endpoint :
3738
`${options.endpoint}${options.isOllama ? '/v1' : ''}`;
38-
39+
3940
// Construct full URL
4041
const fullUrl = `${baseEndpoint}/embeddings`;
41-
42+
4243
//console.log("Calling embedding API at:", fullUrl); // Debug log
4344

4445
const requestOptions = {
@@ -87,7 +88,18 @@ async function getRemoteEmbedding(input: string, options: EmbeddingOptions): Pro
8788
export async function embed(runtime: IAgentRuntime, input: string) {
8889
const modelProvider = models[runtime.character.modelProvider];
8990
//need to have env override for this to select what to use for embedding if provider doesnt provide or using openai
90-
const embeddingModel = modelProvider.model.embedding;
91+
const embeddingModel = (
92+
settings.USE_OPENAI_EMBEDDING ? "text-embedding-3-small" : // Use OpenAI if specified
93+
modelProvider.model?.[ModelClass.EMBEDDING] || // Use provider's embedding model if available
94+
models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING] // Fallback to OpenAI
95+
);
96+
97+
if (!embeddingModel) {
98+
throw new Error('No embedding model configured');
99+
}
100+
101+
console.log("embeddingModel", embeddingModel);
102+
91103

92104
// Try local embedding first
93105
if (
@@ -107,16 +119,17 @@ export async function embed(runtime: IAgentRuntime, input: string) {
107119
// Get remote embedding
108120
return await getRemoteEmbedding(input, {
109121
model: embeddingModel,
110-
endpoint: settings.USE_OPENAI_EMBEDDING ?
122+
endpoint: settings.USE_OPENAI_EMBEDDING ?
111123
'https://api.openai.com/v1' : // Always use OpenAI endpoint when USE_OPENAI_EMBEDDING is true
112124
(runtime.character.modelEndpointOverride || modelProvider.endpoint),
113-
apiKey: settings.USE_OPENAI_EMBEDDING ?
125+
apiKey: settings.USE_OPENAI_EMBEDDING ?
114126
settings.OPENAI_API_KEY : // Use OpenAI key from settings when USE_OPENAI_EMBEDDING is true
115127
runtime.token, // Use runtime token for other providers
116128
isOllama: runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING
117129
});
118130
}
119131

132+
120133
async function getLocalEmbedding(input: string): Promise<number[]> {
121134
const cacheDir = getRootPath() + "/cache/";
122135
if (!fs.existsSync(cacheDir)) {
@@ -137,13 +150,13 @@ export async function retrieveCachedEmbedding(
137150
runtime: IAgentRuntime,
138151
input: string
139152
) {
140-
if(!input) {
153+
if (!input) {
141154
console.log("No input to retrieve cached embedding for");
142155
return null;
143156
}
144-
145-
const similaritySearchResult = [];
146-
// await runtime.messageManager.getCachedEmbeddings(input);
157+
158+
const similaritySearchResult =
159+
await runtime.messageManager.getCachedEmbeddings(input);
147160
if (similaritySearchResult.length > 0) {
148161
return similaritySearchResult[0].embedding;
149162
}

0 commit comments

Comments
 (0)