
Commit 420399e

Merge pull request #254 from o-on-x/main
refactor embeddings
2 parents acb4e86 + 7aad2f7

File tree

2 files changed: +81 −73 lines

packages/adapter-postgres/seed.sql (+9 −3)

@@ -1,3 +1,9 @@
-INSERT INTO public.accounts (id, name, email, avatarUrl, details) VALUES ('00000000-0000-0000-0000-000000000000', 'Default Agent', 'default@agent.com', '', '{}');
-INSERT INTO public.rooms (id) VALUES ('00000000-0000-0000-0000-000000000000');
-INSERT INTO public.participants (userId, roomId) VALUES ('00000000-0000-0000-0000-000000000000', '00000000-0000-0000-0000-000000000000');
+
+INSERT INTO public.accounts (id, name, email, "avatarUrl", details)
+VALUES ('00000000-0000-0000-0000-000000000000', 'Default Agent', 'default@agent.com', '', '{}'::jsonb);
+
+INSERT INTO public.rooms (id)
+VALUES ('00000000-0000-0000-0000-000000000000');
+
+INSERT INTO public.participants (id, "userId", "roomId")
+VALUES ('00000000-0000-0000-0000-000000000001', '00000000-0000-0000-0000-000000000000', '00000000-0000-0000-0000-000000000000');

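A note on why the quoting change matters (context, not part of the diff): PostgreSQL folds unquoted identifiers to lower case, so the old unquoted avatarUrl, userId, and roomId would resolve to avatarurl, userid, and roomid and fail against a schema that declares the columns with quoted camelCase names. A minimal illustration, assuming an accounts table defined with a quoted "avatarUrl" column:

    SELECT avatarUrl FROM public.accounts;   -- parsed as avatarurl; errors if the column is "avatarUrl"
    SELECT "avatarUrl" FROM public.accounts; -- matches the camelCase column exactly

The '{}'::jsonb cast likewise makes the type of the details literal explicit, and the participants row now supplies an explicit id rather than relying on a default.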
packages/core/src/embedding.ts (+72 −70)

@@ -22,105 +22,106 @@ function getRootPath() {
     return path.resolve(__dirname, "..");
 }
 
-/**
- * Send a message to the OpenAI API for embedding.
- * @param input The input to be embedded.
- * @returns The embedding of the input.
- */
-export async function embed(runtime: IAgentRuntime, input: string) {
-    // get the charcter, and handle by model type
-    const modelProvider = models[runtime.character.modelProvider];
-    const embeddingModel = modelProvider.model.embedding;
-
-    if (
-        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
-        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
-        !settings.USE_OPENAI_EMBEDDING
-    ) {
-
-        // make sure to trim tokens to 8192
-        const cacheDir = getRootPath() + "/cache/";
-
-        // if the cache directory doesn't exist, create it
-        if (!fs.existsSync(cacheDir)) {
-            fs.mkdirSync(cacheDir, { recursive: true });
-        }
-
-        const embeddingModel = await FlagEmbedding.init({
-            cacheDir: cacheDir
-        });
-
-        const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
-
-        const embedding: number[] = await embeddingModel.queryEmbed(trimmedInput);
-        console.log("Embedding dimensions: ", embedding.length);
-        return embedding;
-
-        // commented out the text generation service that uses llama
-        // const service = runtime.getService<ITextGenerationService>(
-        //     ServiceType.TEXT_GENERATION
-        // );
-
-        // const instance = service?.getInstance();
-
-        // if (instance) {
-        //     return await instance.getEmbeddingResponse(input);
-        // }
-    }
-
-    // TODO: Fix retrieveCachedEmbedding
-    // Check if we already have the embedding in the lore
-    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
-    if (cachedEmbedding) {
-        return cachedEmbedding;
-    }
+interface EmbeddingOptions {
+    model: string;
+    endpoint: string;
+    apiKey?: string;
+    length?: number;
+    isOllama?: boolean;
+}
 
+async function getRemoteEmbedding(input: string, options: EmbeddingOptions): Promise<number[]> {
     const requestOptions = {
         method: "POST",
         headers: {
             "Content-Type": "application/json",
-            // TODO: make this not hardcoded
-            // TODO: make this not hardcoded
-            ...((runtime.modelProvider !== ModelProviderName.OLLAMA || settings.USE_OPENAI_EMBEDDING) ? {
-                Authorization: `Bearer ${runtime.token}`,
-            } : {}),
+            ...(options.apiKey ? {
+                Authorization: `Bearer ${options.apiKey}`,
+            } : {}),
         },
         body: JSON.stringify({
             input,
-            model: embeddingModel,
-            length: 384, // we are squashing dimensions to 768 for openai, even thought the model supports 1536
-            // -- this is ok for matryoshka embeddings but longterm, we might want to support 1536
+            model: options.model,
+            length: options.length || 384,
         }),
     };
+
     try {
         const response = await fetch(
-            // TODO: make this not hardcoded
-            `${runtime.character.modelEndpointOverride || modelProvider.endpoint}${(runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING) ? "/v1" : ""}/embeddings`,
+            `${options.endpoint}${options.isOllama ? "/v1" : ""}/embeddings`,
             requestOptions
         );
 
         if (!response.ok) {
             throw new Error(
-                "OpenAI API Error: " +
-                    response.status +
-                    " " +
-                    response.statusText
+                "Embedding API Error: " +
+                response.status +
+                " " +
+                response.statusText
             );
         }
 
-        interface OpenAIEmbeddingResponse {
+        interface EmbeddingResponse {
             data: Array<{ embedding: number[] }>;
         }
 
-        const data: OpenAIEmbeddingResponse = await response.json();
-
+        const data: EmbeddingResponse = await response.json();
         return data?.data?.[0].embedding;
     } catch (e) {
        console.error(e);
         throw e;
     }
 }
 
+/**
+ * Send a message to the OpenAI API for embedding.
+ * @param input The input to be embedded.
+ * @returns The embedding of the input.
+ */
+export async function embed(runtime: IAgentRuntime, input: string) {
+    const modelProvider = models[runtime.character.modelProvider];
+    const embeddingModel = modelProvider.model.embedding;
+
+    // Try local embedding first
+    if (
+        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
+        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
+        !settings.USE_OPENAI_EMBEDDING
+    ) {
+        return await getLocalEmbedding(input);
+    }
+
+    // Check cache
+    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
+    if (cachedEmbedding) {
+        return cachedEmbedding;
+    }
+
+    // Get remote embedding
+    return await getRemoteEmbedding(input, {
+        model: embeddingModel,
+        endpoint: runtime.character.modelEndpointOverride || modelProvider.endpoint,
+        apiKey: runtime.token,
+        isOllama: runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING
+    });
+}
+
+async function getLocalEmbedding(input: string): Promise<number[]> {
+    const cacheDir = getRootPath() + "/cache/";
+    if (!fs.existsSync(cacheDir)) {
+        fs.mkdirSync(cacheDir, { recursive: true });
+    }
+
+    const embeddingModel = await FlagEmbedding.init({
+        cacheDir: cacheDir
+    });
+
+    const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
+    const embedding = await embeddingModel.queryEmbed(trimmedInput);
+    console.log("Embedding dimensions: ", embedding.length);
+    return embedding;
+}
+
 export async function retrieveCachedEmbedding(
     runtime: IAgentRuntime,
     input: string
@@ -129,11 +130,12 @@ export async function retrieveCachedEmbedding(
         console.log("No input to retrieve cached embedding for");
         return null;
     }
-
+
     const similaritySearchResult =
         await runtime.messageManager.getCachedEmbeddings(input);
     if (similaritySearchResult.length > 0) {
         return similaritySearchResult[0].embedding;
     }
     return null;
 }
+

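For orientation, a sketch of how the extracted helper is exercised after this refactor. This is illustrative only: getRemoteEmbedding is module-private in this commit, and the model name and endpoint below are hypothetical stand-ins, not values from the repository.

    // Assumption: helper exported for the sake of the example (it is not in this commit).
    import { getRemoteEmbedding } from "./embedding";

    // Hypothetical local Ollama setup: isOllama appends "/v1", so the request
    // goes to http://localhost:11434/v1/embeddings; with no apiKey set, no
    // Authorization header is attached.
    const embedding = await getRemoteEmbedding("The quick brown fox", {
        model: "mxbai-embed-large",         // hypothetical embedding model name
        endpoint: "http://localhost:11434", // hypothetical endpoint; embed() normally passes
                                            // runtime.character.modelEndpointOverride || modelProvider.endpoint
        isOllama: true,
    });
    console.log("dimensions:", embedding.length);

Callers keep using embed(runtime, input): it routes to getLocalEmbedding (FlagEmbedding) when the provider is neither OpenAI nor Ollama and USE_OPENAI_EMBEDDING is unset; otherwise it checks the embedding cache and falls through to getRemoteEmbedding with options derived from the runtime.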