@@ -13,25 +13,28 @@ import {
13
13
*/
14
14
export async function embed ( runtime : IAgentRuntime , input : string ) {
15
15
// get the charcter, and handle by model type
16
- const model = models [ runtime . character . settings . model ] ;
16
+ const modelProvider = models [ runtime . character . modelProvider ] ;
17
+ const embeddingModel = modelProvider . model . embedding ;
17
18
18
19
if (
19
- model !== ModelProviderName . OPENAI &&
20
- model !== ModelProviderName . OLLAMA
20
+ runtime . character . modelProvider !== ModelProviderName . OPENAI &&
21
+ runtime . character . modelProvider !== ModelProviderName . OLLAMA
21
22
) {
22
23
const service = runtime . getService < ITextGenerationService > (
23
24
ServiceType . TEXT_GENERATION
24
25
) ;
25
- return await service . getInstance ( ) . getEmbeddingResponse ( input ) ;
26
- }
27
-
28
- const embeddingModel = models [ runtime . modelProvider ] . model . embedding ;
26
+
27
+ const instance = service ?. getInstance ( ) ;
29
28
30
- // Check if we already have the embedding in the lore
31
- const cachedEmbedding = await retrieveCachedEmbedding ( runtime , input ) ;
32
- if ( cachedEmbedding ) {
33
- return cachedEmbedding ;
29
+ if ( instance ) {
30
+ return await instance . getEmbeddingResponse ( input ) ;
31
+ }
34
32
}
33
+ // Check if we already have the embedding in the lore
34
+ // const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
35
+ // if (cachedEmbedding) {
36
+ // return cachedEmbedding;
37
+ // }
35
38
36
39
const requestOptions = {
37
40
method : "POST" ,
@@ -51,7 +54,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
51
54
try {
52
55
const response = await fetch (
53
56
// TODO: make this not hardcoded
54
- `${ runtime . serverUrl } ${ runtime . modelProvider === ModelProviderName . OLLAMA ? "/v1" : "" } /embeddings` ,
57
+ `${ runtime . character . modelEndpointOverride || modelProvider . endpoint } ${ runtime . character . modelProvider === ModelProviderName . OLLAMA ? "/v1" : "" } /embeddings` ,
55
58
requestOptions
56
59
) ;
57
60
@@ -81,6 +84,11 @@ export async function retrieveCachedEmbedding(
81
84
runtime : IAgentRuntime ,
82
85
input : string
83
86
) {
87
+ if ( ! input ) {
88
+ console . log ( "No input to retrieve cached embedding for" ) ;
89
+ return null ;
90
+ }
91
+
84
92
const similaritySearchResult =
85
93
await runtime . messageManager . getCachedEmbeddings ( input ) ;
86
94
if ( similaritySearchResult . length > 0 ) {
0 commit comments