-
// TODO: Maybe create these functions to read from character settings or env
// import { getEmbeddingModelSettings, getEndpoint } from "./models.ts";
-import { type IAgentRuntime, ModelProviderName } from "./types.ts";
+import { type IAgentRuntime, ModelProviderName, ModelClass } from "./types.ts";
import settings from "./settings.ts";
import elizaLogger from "./logger.ts";
import LocalEmbeddingModelManager from "./localembeddingManager.ts";
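
The hunks below use EmbeddingProvider and the EmbeddingConfig type without showing their definitions. As a rough sketch only, the shapes this change appears to rely on could look like the following; the member names and the exact form (enum vs. const object) are assumptions, not taken from this commit.

// Illustrative sketch of assumed shapes; the real definitions live elsewhere
// in embedding.ts / types.ts and may differ.
export enum EmbeddingProvider {
    OpenAI = "OpenAI",
    Ollama = "Ollama",
    BGE = "BGE",
}

export type EmbeddingConfig = {
    readonly dimensions: number;
    readonly model: string;
    readonly provider: string;
};
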
@@ -30,15 +29,27 @@ export type EmbeddingConfig = {
    readonly provider: string;
};

-export const getEmbeddingConfig = (): EmbeddingConfig => ({
-    dimensions:
-        // TODO: get from env or character settings
-        384,
-    model:
-        // TODO: get from env or character settings
-        "BGE-small-en-v1.5",
-    provider: "BGE",
-});
+export const getEmbeddingConfig = (runtime?: IAgentRuntime): EmbeddingConfig => {
+    if (runtime) {
+        const modelProvider = runtime.getModelProvider();
+        const embeddingModel = modelProvider?.models?.[ModelClass.EMBEDDING];
+
+        if (embeddingModel?.name) {
+            return {
+                dimensions: embeddingModel.dimensions || 1536,
+                model: embeddingModel.name,
+                provider: modelProvider?.provider || EmbeddingProvider.OpenAI,
+            };
+        }
+    }
+
+    // Fallback to default config
+    return {
+        dimensions: 1536, // OpenAI's text-embedding-ada-002 dimension
+        model: "text-embedding-3-small", // Default to OpenAI's latest embedding model
+        provider: EmbeddingProvider.OpenAI
+    };
+};

async function getRemoteEmbedding(
    input: string,
@@ -52,6 +63,26 @@ async function getRemoteEmbedding(
    // Construct full URL
    const fullUrl = `${baseEndpoint}/embeddings`;

+    elizaLogger.info("Embedding request:", {
+        modelProvider: options.provider,
+        useOpenAI: options.provider === EmbeddingProvider.OpenAI,
+        input: `${input?.slice(0, 50)}...`,
+        inputType: typeof input,
+        inputLength: input?.length,
+        isString: typeof input === "string",
+        isEmpty: !input,
+    });
+
+    const requestBody: any = {
+        input,
+        model: options.model,
+    };
+
+    // Only include dimensions for non-OpenAI providers
+    if (options.provider !== EmbeddingProvider.OpenAI) {
+        requestBody.dimensions = options.dimensions || options.length || getEmbeddingConfig().dimensions;
+    }
+
    const requestOptions = {
        method: "POST",
        headers: {
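
To make the branch above concrete, here is an illustrative sketch of the two request bodies it produces; the values are examples only, not taken from this commit.

// Illustrative only. OpenAI: `dimensions` is omitted so the model's default
// (1536 for text-embedding-3-small) applies.
const openAiBody = { input: "hello world", model: "text-embedding-3-small" };
// Non-OpenAI provider (e.g. a BGE-compatible endpoint): dimensions are sent explicitly.
const bgeBody = { input: "hello world", model: "BGE-small-en-v1.5", dimensions: 384 };
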
@@ -62,19 +93,14 @@ async function getRemoteEmbedding(
                  }
                : {}),
        },
-        body: JSON.stringify({
-            input,
-            model: options.model,
-            dimensions:
-                options.dimensions ||
-                options.length ||
-                getEmbeddingConfig().dimensions, // Prefer dimensions, fallback to length
-        }),
+        body: JSON.stringify(requestBody),
    };

    try {
        const response = await fetch(fullUrl, requestOptions);

+        elizaLogger.info("Embedding response:", requestOptions);
+
        if (!response.ok) {
            elizaLogger.error("API Response:", await response.text()); // Debug log
            throw new Error(
@@ -158,28 +184,13 @@ export async function embed(runtime: IAgentRuntime, input: string) {
    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
    if (cachedEmbedding) return cachedEmbedding;

-    const config = getEmbeddingConfig();
    const isNode = typeof process !== "undefined" && process.versions?.node;

-    // use endpoint from model provider
-    const endpoint = runtime.getSetting("PROVIDER_ENDPOINT");
-    const apiKey = runtime.getSetting("PROVIDER_API_KEY");
-
-
-    // Determine which embedding settings to use
-    // TODO: enhance + verify logic to get from character settings or env
-    if (config.provider) {
-        return await getRemoteEmbedding(input, {
-            model: config.model,
-            endpoint: settings.PROVIDER_ENDPOINT || "https://api.openai.com/v1",
-            apiKey: settings.PROVIDER_API_KEY,
-            dimensions: config.dimensions,
-        });
-    }
-
+    // Get embedding configuration from runtime
+    const embeddingConfig = getEmbeddingConfig(runtime);

-    // BGE - try local first if in Node
-    if (isNode) {
+    // BGE - try local first if in Node and not using OpenAI
+    if (isNode && embeddingConfig.provider !== EmbeddingProvider.OpenAI) {
        try {
            return await getLocalEmbedding(input);
        } catch (error) {
@@ -190,14 +201,13 @@ export async function embed(runtime: IAgentRuntime, input: string) {
        }
    }

-    // Fallback to remote override
+    // Use remote embedding
    return await getRemoteEmbedding(input, {
-        model: config.model,
-        endpoint:
-            runtime.character.modelEndpointOverride ||
-            runtime.getSetting("PROVIDER_ENDPOINT"),
+        model: embeddingConfig.model,
+        endpoint: runtime.character.modelEndpointOverride || runtime.getSetting("PROVIDER_ENDPOINT") || "https://api.openai.com/v1",
        apiKey: runtime.getSetting("PROVIDER_API_KEY") || runtime.token,
-        dimensions: config.dimensions,
+        dimensions: embeddingConfig.dimensions,
+        provider: embeddingConfig.provider
    });

    async function getLocalEmbedding(input: string): Promise<number[]> {
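
For context, a minimal usage sketch of how the reworked configuration flows through embed(); the helper name embedText is hypothetical, and the import path assumes the usual ./embedding.ts layout.

// Hypothetical usage sketch; not part of this commit.
import { type IAgentRuntime } from "./types.ts";
import { embed, getEmbeddingConfig } from "./embedding.ts";

export async function embedText(runtime: IAgentRuntime, text: string): Promise<number[]> {
    const config = getEmbeddingConfig(runtime);

    // With an OpenAI provider, embed() skips the local BGE path and sends a
    // remote request whose body omits `dimensions`; other providers try the
    // local BGE model first when running under Node.
    console.log(`Embedding via ${config.provider}/${config.model} (${config.dimensions} dims)`);

    return embed(runtime, text);
}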