@@ -22,105 +22,106 @@ function getRootPath() {
     return path.resolve(__dirname, "..");
 }
 
-/**
- * Send a message to the OpenAI API for embedding.
- * @param input The input to be embedded.
- * @returns The embedding of the input.
- */
-export async function embed(runtime: IAgentRuntime, input: string) {
-    // get the charcter, and handle by model type
-    const modelProvider = models[runtime.character.modelProvider];
-    const embeddingModel = modelProvider.model.embedding;
-
-    if (
-        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
-        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
-        !settings.USE_OPENAI_EMBEDDING
-    ) {
-
-        // make sure to trim tokens to 8192
-        const cacheDir = getRootPath() + "/cache/";
-
-        // if the cache directory doesn't exist, create it
-        if (!fs.existsSync(cacheDir)) {
-            fs.mkdirSync(cacheDir, { recursive: true });
-        }
-
-        const embeddingModel = await FlagEmbedding.init({
-            cacheDir: cacheDir
-        });
-
-        const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
-
-        const embedding: number[] = await embeddingModel.queryEmbed(trimmedInput);
-        console.log("Embedding dimensions: ", embedding.length);
-        return embedding;
-
-        // commented out the text generation service that uses llama
-        // const service = runtime.getService<ITextGenerationService>(
-        //     ServiceType.TEXT_GENERATION
-        // );
-
-        // const instance = service?.getInstance();
-
-        // if (instance) {
-        //     return await instance.getEmbeddingResponse(input);
-        // }
-    }
-
-    // TODO: Fix retrieveCachedEmbedding
-    // Check if we already have the embedding in the lore
-    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
-    if (cachedEmbedding) {
-        return cachedEmbedding;
-    }
+interface EmbeddingOptions {
+    model: string;
+    endpoint: string;
+    apiKey?: string;
+    length?: number;
+    isOllama?: boolean;
+}
 
+async function getRemoteEmbedding(input: string, options: EmbeddingOptions): Promise<number[]> {
     const requestOptions = {
         method: "POST",
         headers: {
             "Content-Type": "application/json",
-            // TODO: make this not hardcoded
-            // TODO: make this not hardcoded
-            ...((runtime.modelProvider !== ModelProviderName.OLLAMA || settings.USE_OPENAI_EMBEDDING) ? {
-                Authorization: `Bearer ${runtime.token}`,
-            } : {}),
+            ...(options.apiKey ? {
+                Authorization: `Bearer ${options.apiKey}`,
+            } : {}),
         },
         body: JSON.stringify({
             input,
-            model: embeddingModel,
-            length: 384, // we are squashing dimensions to 768 for openai, even thought the model supports 1536
-            // -- this is ok for matryoshka embeddings but longterm, we might want to support 1536
+            model: options.model,
+            length: options.length || 384,
         }),
     };
+
     try {
         const response = await fetch(
-            // TODO: make this not hardcoded
-            `${runtime.character.modelEndpointOverride || modelProvider.endpoint}${(runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING) ? "/v1" : ""}/embeddings`,
+            `${options.endpoint}${options.isOllama ? "/v1" : ""}/embeddings`,
             requestOptions
         );
 
         if (!response.ok) {
             throw new Error(
-                "OpenAI API Error: " +
-                    response.status +
-                    " " +
-                    response.statusText
+                "Embedding API Error: " +
+                    response.status +
+                    " " +
+                    response.statusText
             );
         }
 
-        interface OpenAIEmbeddingResponse {
+        interface EmbeddingResponse {
             data: Array<{ embedding: number[] }>;
         }
 
-        const data: OpenAIEmbeddingResponse = await response.json();
-
+        const data: EmbeddingResponse = await response.json();
         return data?.data?.[0].embedding;
     } catch (e) {
         console.error(e);
         throw e;
     }
 }
 
+/**
+ * Send a message to the OpenAI API for embedding.
+ * @param input The input to be embedded.
+ * @returns The embedding of the input.
+ */
+export async function embed(runtime: IAgentRuntime, input: string) {
+    const modelProvider = models[runtime.character.modelProvider];
+    const embeddingModel = modelProvider.model.embedding;
+
+    // Try local embedding first
+    if (
+        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
+        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
+        !settings.USE_OPENAI_EMBEDDING
+    ) {
+        return await getLocalEmbedding(input);
+    }
+
+    // Check cache
+    const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
+    if (cachedEmbedding) {
+        return cachedEmbedding;
+    }
+
+    // Get remote embedding
+    return await getRemoteEmbedding(input, {
+        model: embeddingModel,
+        endpoint: runtime.character.modelEndpointOverride || modelProvider.endpoint,
+        apiKey: runtime.token,
+        isOllama: runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING
+    });
+}
+
+async function getLocalEmbedding(input: string): Promise<number[]> {
+    const cacheDir = getRootPath() + "/cache/";
+    if (!fs.existsSync(cacheDir)) {
+        fs.mkdirSync(cacheDir, { recursive: true });
+    }
+
+    const embeddingModel = await FlagEmbedding.init({
+        cacheDir: cacheDir
+    });
+
+    const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
+    const embedding = await embeddingModel.queryEmbed(trimmedInput);
+    console.log("Embedding dimensions: ", embedding.length);
+    return embedding;
+}
+
 export async function retrieveCachedEmbedding(
     runtime: IAgentRuntime,
     input: string
@@ -129,11 +130,12 @@ export async function retrieveCachedEmbedding(
         console.log("No input to retrieve cached embedding for");
         return null;
     }
-
+
     const similaritySearchResult =
         await runtime.messageManager.getCachedEmbeddings(input);
     if (similaritySearchResult.length > 0) {
         return similaritySearchResult[0].embedding;
     }
     return null;
 }
+
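For reference, a minimal usage sketch of the new getRemoteEmbedding helper. The model name, endpoint, and environment variable below are illustrative assumptions rather than values from this commit, and the function is module-internal, so real callers would normally go through embed() instead:

// Hypothetical sketch only — the concrete values here are assumptions for illustration.
const vector = await getRemoteEmbedding("hello world", {
    model: "text-embedding-3-small",       // assumed OpenAI embedding model name
    endpoint: "https://api.openai.com/v1", // assumed endpoint; an Ollama host would pass isOllama: true instead
    apiKey: process.env.OPENAI_API_KEY,    // optional; omit when the endpoint needs no Authorization header
    length: 384,                           // same value the request body falls back to when unset
});
console.log("vector length:", vector.length);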