@@ -4,7 +4,8 @@ import { fileURLToPath } from "url";
 import models from "./models.ts";
 import {
     IAgentRuntime,
-    ModelProviderName
+    ModelProviderName,
+    ModelClass
 } from "./types.ts";
 import fs from "fs";
 import { trimTokens } from "./generation.ts";
@@ -18,7 +19,7 @@ function getRootPath() {
     if (rootPath.includes("/eliza/")) {
         return rootPath.split("/eliza/")[0] + "/eliza/";
     }
-
+
     return path.resolve(__dirname, "..");
 }

@@ -32,13 +33,13 @@ interface EmbeddingOptions {

 async function getRemoteEmbedding(input: string, options: EmbeddingOptions): Promise<number[]> {
     // Ensure endpoint ends with /v1 for OpenAI
-    const baseEndpoint = options.endpoint.endsWith('/v1') ?
-        options.endpoint :
+    const baseEndpoint = options.endpoint.endsWith('/v1') ?
+        options.endpoint :
         `${options.endpoint}${options.isOllama ? '/v1' : ''}`;
-
+
     // Construct full URL
     const fullUrl = `${baseEndpoint}/embeddings`;
-
+
     //console.log("Calling embedding API at:", fullUrl); // Debug log

     const requestOptions = {
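
The endpoint normalization in this hunk can be checked in isolation. A minimal sketch of the same rule, assuming only the two inputs the diff uses (the helper name is hypothetical, not part of the patch):

// Hypothetical standalone version of the baseEndpoint logic above.
function embeddingUrl(endpoint: string, isOllama: boolean): string {
    const base = endpoint.endsWith("/v1")
        ? endpoint // already a /v1 endpoint (e.g. OpenAI)
        : `${endpoint}${isOllama ? "/v1" : ""}`; // Ollama exposes an OpenAI-compatible /v1 API
    return `${base}/embeddings`;
}

// embeddingUrl("https://api.openai.com/v1", false) -> "https://api.openai.com/v1/embeddings"
// embeddingUrl("http://localhost:11434", true)     -> "http://localhost:11434/v1/embeddings"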
@@ -87,7 +88,18 @@ async function getRemoteEmbedding(input: string, options: EmbeddingOptions): Pro
 export async function embed(runtime: IAgentRuntime, input: string) {
     const modelProvider = models[runtime.character.modelProvider];
     // need an env override to select the embedding model when the provider doesn't supply one or OpenAI is forced
-    const embeddingModel = modelProvider.model.embedding;
+    const embeddingModel = (
+        settings.USE_OPENAI_EMBEDDING ? "text-embedding-3-small" : // Use OpenAI if specified
+        modelProvider.model?.[ModelClass.EMBEDDING] || // Use provider's embedding model if available
+        models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING] // Fallback to OpenAI
+    );
+
+    if (!embeddingModel) {
+        throw new Error('No embedding model configured');
+    }
+
+    console.log("embeddingModel", embeddingModel);
+

     // Try local embedding first
     if (
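
The new selection resolves in three steps: the USE_OPENAI_EMBEDDING override wins, then the provider's own embedding model, then the OpenAI default. A hedged sketch of that precedence on its own (the function and parameter names are placeholders for the expressions noted in the comments):

// Hypothetical standalone version of the precedence introduced above.
function resolveEmbeddingModel(
    useOpenAIEmbedding: boolean,           // settings.USE_OPENAI_EMBEDDING in the real code
    providerEmbedding: string | undefined, // modelProvider.model?.[ModelClass.EMBEDDING]
    openAIEmbedding: string                // models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING]
): string {
    if (useOpenAIEmbedding) return "text-embedding-3-small"; // explicit override
    return providerEmbedding || openAIEmbedding;             // provider model, else OpenAI fallback
}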
@@ -107,16 +119,17 @@ export async function embed(runtime: IAgentRuntime, input: string) {
     // Get remote embedding
     return await getRemoteEmbedding(input, {
         model: embeddingModel,
-        endpoint: settings.USE_OPENAI_EMBEDDING ?
+        endpoint: settings.USE_OPENAI_EMBEDDING ?
             'https://api.openai.com/v1' : // Always use OpenAI endpoint when USE_OPENAI_EMBEDDING is true
             (runtime.character.modelEndpointOverride || modelProvider.endpoint),
-        apiKey: settings.USE_OPENAI_EMBEDDING ?
+        apiKey: settings.USE_OPENAI_EMBEDDING ?
             settings.OPENAI_API_KEY : // Use OpenAI key from settings when USE_OPENAI_EMBEDDING is true
             runtime.token, // Use runtime token for other providers
         isOllama: runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING
     });
 }

+
 async function getLocalEmbedding(input: string): Promise<number[]> {
     const cacheDir = getRootPath() + "/cache/";
     if (!fs.existsSync(cacheDir)) {
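
Because endpoint, apiKey, and isOllama are all derived from the same flag, the call can only take one of two coherent shapes. Illustrative values only (the Ollama model name and local port are assumptions, not taken from this diff):

// Shape 1: USE_OPENAI_EMBEDDING is set.
const withOpenAI = {
    model: "text-embedding-3-small",
    endpoint: "https://api.openai.com/v1",
    apiKey: settings.OPENAI_API_KEY,
    isOllama: false, // the "&& !settings.USE_OPENAI_EMBEDDING" term forces this off
};

// Shape 2: flag unset, character runs on Ollama (assumed local defaults).
const withOllama = {
    model: "mxbai-embed-large", // whatever the provider's EMBEDDING entry resolves to
    endpoint: "http://localhost:11434",
    apiKey: runtime.token,
    isOllama: true, // getRemoteEmbedding will append /v1
};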
@@ -137,13 +150,13 @@ export async function retrieveCachedEmbedding(
     runtime: IAgentRuntime,
     input: string
 ) {
-    if (!input) {
+    if (!input) {
         console.log("No input to retrieve cached embedding for");
         return null;
     }
-
-    const similaritySearchResult = [];
-    // await runtime.messageManager.getCachedEmbeddings(input);
+
+    const similaritySearchResult =
+        await runtime.messageManager.getCachedEmbeddings(input);
     if (similaritySearchResult.length > 0) {
         return similaritySearchResult[0].embedding;
     }
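
With the stubbed empty array replaced by a real getCachedEmbeddings call, a cache hit now short-circuits before any embedding is computed. A sketch of the assumed calling pattern (the wrapper is hypothetical; embed in this file plays a similar role):

// Hypothetical caller: reuse a cached vector before computing a new one.
async function embedWithCache(runtime: IAgentRuntime, input: string): Promise<number[]> {
    const cached = await retrieveCachedEmbedding(runtime, input);
    if (cached) return cached;          // cache hit: skip local/remote embedding entirely
    return await embed(runtime, input); // cache miss: fall through to the normal path
}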