Commit acb4e86

Merge pull request #252 from o-on-x/main
use openai embeddings setting
2 parents 15f7ba8 + 432362b

File tree: 2 files changed (+10 -4)

.env.example (+3)

@@ -29,6 +29,9 @@ X_SERVER_URL=
 XAI_API_KEY=
 XAI_MODEL=
 
+#Leave blank to use local embeddings
+USE_OPENAI_EMBEDDING= #TRUE
+
 #OpenRouter (Use one model for everything or set individual for small, medium, large tasks)
 #leave blank to use defaults hermes 70b for small tasks & 405b for medium/large tasks
 OPENROUTER_MODEL=
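
The flag is read back through the package's settings module, which embedding.ts imports below. As a rough sketch of the intended semantics (assuming settings simply exposes the raw environment string), leaving the variable blank keeps local embeddings, while any non-empty value such as TRUE is truthy and switches embedding requests to OpenAI:

    // Sketch only: assumes settings.USE_OPENAI_EMBEDDING is the raw env string.
    // ""  (unset/blank)            -> falsy  -> local embeddings stay the default
    // "TRUE" (any non-empty value) -> truthy -> OpenAI embeddings are used
    const useOpenAIEmbedding = Boolean(settings.USE_OPENAI_EMBEDDING);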

packages/core/src/embedding.ts (+7 -4)

@@ -8,6 +8,7 @@ import {
 } from "./types.ts";
 import fs from "fs";
 import { trimTokens } from "./generation.ts";
+import settings from "./settings.ts";
 
 function getRootPath() {
     const __filename = fileURLToPath(import.meta.url);
@@ -33,7 +34,8 @@ export async function embed(runtime: IAgentRuntime, input: string) {
 
     if (
         runtime.character.modelProvider !== ModelProviderName.OPENAI &&
-        runtime.character.modelProvider !== ModelProviderName.OLLAMA
+        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
+        !settings.USE_OPENAI_EMBEDDING
     ) {
 
         // make sure to trim tokens to 8192
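
Read together with the new .env comment ("Leave blank to use local embeddings"), the widened guard keeps local embeddings as the default for any provider other than OpenAI or Ollama, but lets the new flag opt such a provider into the remote path instead. A restatement of the condition as a small sketch (identifiers as in the diff; the behaviour of settings is the assumption noted above):

    // Local embeddings are used only when all three conditions hold:
    const useLocalEmbedding =
        runtime.character.modelProvider !== ModelProviderName.OPENAI &&
        runtime.character.modelProvider !== ModelProviderName.OLLAMA &&
        !settings.USE_OPENAI_EMBEDDING;
    // Once USE_OPENAI_EMBEDDING is set, even those providers fall through to the
    // remote /embeddings request built further down in embed().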
@@ -78,9 +80,10 @@ export async function embed(runtime: IAgentRuntime, input: string) {
         headers: {
             "Content-Type": "application/json",
             // TODO: make this not hardcoded
-            ...(runtime.modelProvider !== ModelProviderName.OLLAMA && {
+            // TODO: make this not hardcoded
+            ...((runtime.modelProvider !== ModelProviderName.OLLAMA || settings.USE_OPENAI_EMBEDDING) ? {
                 Authorization: `Bearer ${runtime.token}`,
-            }),
+            } : {}),
         },
         body: JSON.stringify({
             input,
@@ -92,7 +95,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
     try {
         const response = await fetch(
             // TODO: make this not hardcoded
-            `${runtime.character.modelEndpointOverride || modelProvider.endpoint}${runtime.character.modelProvider === ModelProviderName.OLLAMA ? "/v1" : ""}/embeddings`,
+            `${runtime.character.modelEndpointOverride || modelProvider.endpoint}${(runtime.character.modelProvider === ModelProviderName.OLLAMA && !settings.USE_OPENAI_EMBEDDING) ? "/v1" : ""}/embeddings`,
             requestOptions
         );
 
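
Taken together, the last two hunks decide the URL suffix and the Authorization header of the remote request. A condensed sketch of that combined logic, using a hypothetical buildEmbeddingRequest helper that is not part of this commit (the helper name and the endpoint parameter are illustrative only):

    // Illustrative only: how provider + USE_OPENAI_EMBEDDING now combine
    // when embed() builds the remote embeddings request.
    function buildEmbeddingRequest(
        provider: ModelProviderName,
        endpoint: string,
        token: string,
        useOpenAIEmbedding: boolean // truthiness of settings.USE_OPENAI_EMBEDDING
    ) {
        // Ollama's "/v1" prefix is only appended while the OpenAI override is off.
        const url = `${endpoint}${provider === ModelProviderName.OLLAMA && !useOpenAIEmbedding ? "/v1" : ""}/embeddings`;

        // The bearer token is sent for every non-Ollama provider, and also for
        // Ollama once USE_OPENAI_EMBEDDING routes the call to the OpenAI API.
        const headers: Record<string, string> = { "Content-Type": "application/json" };
        if (provider !== ModelProviderName.OLLAMA || useOpenAIEmbedding) {
            headers.Authorization = `Bearer ${token}`;
        }
        return { url, headers };
    }

Note that the diff checks runtime.modelProvider for the header but runtime.character.modelProvider for the URL; the sketch assumes both resolve to the same provider.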
