embedding.ts
import { FlagEmbedding } from "fastembed";
import path from "path";
import { fileURLToPath } from "url";
import { models } from "./models.ts";
import { IAgentRuntime, ModelProviderName, ModelClass } from "./types.ts";
import fs from "fs";
import { trimTokens } from "./generation.ts";
import settings from "./settings.ts";

// Resolve the project root (the enclosing eliza checkout when present,
// otherwise the package root) so local caches have a stable location.
function getRootPath() {
    const __filename = fileURLToPath(import.meta.url);
    const __dirname = path.dirname(__filename);
    const rootPath = path.resolve(__dirname, "..");
    if (rootPath.includes("/eliza/")) {
        return rootPath.split("/eliza/")[0] + "/eliza/";
    }
    return path.resolve(__dirname, "..");
}

interface EmbeddingOptions {
    model: string;
    endpoint: string;
    apiKey?: string;
    length?: number;
    isOllama?: boolean;
}

async function getRemoteEmbedding(
    input: string,
    options: EmbeddingOptions
): Promise<number[]> {
    // Ensure the endpoint ends with /v1: Ollama exposes an OpenAI-compatible
    // API under /v1, while OpenAI-style endpoints are expected to include it already
    const baseEndpoint = options.endpoint.endsWith("/v1")
        ? options.endpoint
        : `${options.endpoint}${options.isOllama ? "/v1" : ""}`;

    // Construct full URL
    const fullUrl = `${baseEndpoint}/embeddings`;
    //console.log("Calling embedding API at:", fullUrl); // Debug log

    const requestOptions = {
        method: "POST",
        headers: {
            "Content-Type": "application/json",
            ...(options.apiKey
                ? {
                      Authorization: `Bearer ${options.apiKey}`,
                  }
                : {}),
        },
        body: JSON.stringify({
            input,
            model: options.model,
            length: options.length || 384,
        }),
    };

    try {
        const response = await fetch(fullUrl, requestOptions);

        if (!response.ok) {
            console.error("API Response:", await response.text()); // Debug log
            throw new Error(
                `Embedding API Error: ${response.status} ${response.statusText}`
            );
        }

        interface EmbeddingResponse {
            data: Array<{ embedding: number[] }>;
        }

        const data: EmbeddingResponse = await response.json();
        return data?.data?.[0].embedding;
    } catch (e) {
        console.error("Full error details:", e);
        throw e;
    }
}
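
// Illustrative usage sketch (not part of the original file): the values below
// are placeholders. Against a local Ollama server the "/v1" suffix is appended
// automatically, so "http://localhost:11434" resolves to
// "http://localhost:11434/v1/embeddings".
//
// const vector = await getRemoteEmbedding("hello world", {
//     model: "mxbai-embed-large", // hypothetical model name pulled into Ollama
//     endpoint: "http://localhost:11434",
//     isOllama: true,
// });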

/**
 * Generate an embedding for the given input, using the local model, the
 * provider's embedding endpoint, or OpenAI depending on configuration.
 * @param runtime The agent runtime, used to resolve the model provider and credentials.
 * @param input The input to be embedded.
 * @returns The embedding vector for the input.
 */
export async function embed(runtime: IAgentRuntime, input: string) {
    const modelProvider = models[runtime.character.modelProvider];

    // Allow the USE_OPENAI_EMBEDDING setting to override which model is used,
    // e.g. when the provider does not supply an embedding model of its own.
    const embeddingModel = settings.USE_OPENAI_EMBEDDING
        ? "text-embedding-3-small" // Use OpenAI if specified
        : modelProvider.model?.[ModelClass.EMBEDDING] || // Use provider's embedding model if available
          models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING]; // Fall back to OpenAI

    if (!embeddingModel) {
        throw new Error("No embedding model configured");
    }

    // Try local embedding first
    if (
        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
        !settings.USE_OPENAI_EMBEDDING
    ) {
        return await getLocalEmbedding(input);
    }

    // Check cache
    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
    if (cachedEmbedding) {
        return cachedEmbedding;
    }

    // Get remote embedding
    return await getRemoteEmbedding(input, {
        model: embeddingModel,
        endpoint: settings.USE_OPENAI_EMBEDDING
            ? "https://api.openai.com/v1" // Always use the OpenAI endpoint when USE_OPENAI_EMBEDDING is set
            : runtime.character.modelEndpointOverride || modelProvider.endpoint,
        apiKey: settings.USE_OPENAI_EMBEDDING
            ? settings.OPENAI_API_KEY // Use the OpenAI key from settings when USE_OPENAI_EMBEDDING is set
            : runtime.token, // Use the runtime token for other providers
        isOllama:
            runtime.character.modelProvider === ModelProviderName.OLLAMA &&
            !settings.USE_OPENAI_EMBEDDING,
    });
}
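
// Minimal usage sketch (assumptions, not from this file): `runtime` is an
// existing IAgentRuntime instance; the vector length depends on the selected
// model (e.g. 1536 for OpenAI's text-embedding-3-small, typically 384 for the
// default local fastembed model).
//
// const vector = await embed(runtime, "What did the user ask about pricing?");
// console.log(vector.length);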

async function getLocalEmbedding(input: string): Promise<number[]> {
    const cacheDir = getRootPath() + "/cache/";
    if (!fs.existsSync(cacheDir)) {
        fs.mkdirSync(cacheDir, { recursive: true });
    }

    const embeddingModel = await FlagEmbedding.init({
        cacheDir: cacheDir,
    });

    const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
    const embedding = await embeddingModel.queryEmbed(trimmedInput);
    //console.log("Embedding dimensions: ", embedding.length);
    return embedding;
}

export async function retrieveCachedEmbedding(
    runtime: IAgentRuntime,
    input: string
) {
    if (!input) {
        console.log("No input to retrieve cached embedding for");
        return null;
    }

    const similaritySearchResult =
        await runtime.messageManager.getCachedEmbeddings(input);
    if (similaritySearchResult.length > 0) {
        return similaritySearchResult[0].embedding;
    }

    return null;
}
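
// Usage note (sketch): getCachedEmbeddings() is assumed to return rows shaped
// like { embedding: number[], ... }, ordered by similarity, so only the top
// hit's embedding is reused.
//
// const cached = await retrieveCachedEmbedding(runtime, "hello world");
// const vector = cached ?? (await embed(runtime, "hello world"));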