Skip to content

Commit 8499d96

Browse files
committed
wip: save
1 parent a40b7fd commit 8499d96

File tree

6 files changed

+68917
-339
lines changed

6 files changed

+68917
-339
lines changed

packages/core/__tests__/embedding.test.ts

+241-31
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {
55
getEmbeddingType,
66
getEmbeddingZeroVector,
77
} from "../src/embedding.ts";
8-
import { type IAgentRuntime, ModelProviderName } from "../types.ts";
8+
import { type IAgentRuntime, ModelProviderName } from "../src/types.ts";
99
import settings from "../src/settings.ts";
1010

1111
// Mock environment-related settings
@@ -23,18 +23,26 @@ vi.mock("../settings", () => ({
2323
}));
2424

2525
// Mock fastembed module for local embeddings
26-
vi.mock("fastembed", () => ({
27-
FlagEmbedding: {
28-
init: vi.fn().mockResolvedValue({
29-
queryEmbed: vi
30-
.fn()
31-
.mockResolvedValue(new Float32Array(384).fill(0.1)),
32-
}),
33-
},
34-
EmbeddingModel: {
35-
BGESmallENV15: "BGE-small-en-v1.5",
36-
},
37-
}));
26+
vi.mock("fastembed", () => {
27+
class MockFlagEmbedding {
28+
constructor() {}
29+
30+
static async init() {
31+
return new MockFlagEmbedding();
32+
}
33+
34+
async queryEmbed(text: string | string[]) {
35+
return [new Float32Array(384).fill(0.1)];
36+
}
37+
}
38+
39+
return {
40+
FlagEmbedding: MockFlagEmbedding,
41+
EmbeddingModel: {
42+
BGESmallENV15: "BGE-small-en-v1.5",
43+
},
44+
};
45+
});
3846

3947
// Mock global fetch for remote embedding requests
4048
const mockFetch = vi.fn();
@@ -44,20 +52,223 @@ describe("Embedding Module", () => {
4452
let mockRuntime: IAgentRuntime;
4553

4654
beforeEach(() => {
55+
// Reset all mocks
56+
vi.clearAllMocks();
57+
4758
// Prepare a mock runtime
4859
mockRuntime = {
60+
agentId: "00000000-0000-0000-0000-000000000000" as `${string}-${string}-${string}-${string}-${string}`,
61+
serverUrl: "http://test-server",
62+
databaseAdapter: {
63+
init: vi.fn(),
64+
close: vi.fn(),
65+
getMemories: vi.fn(),
66+
createMemory: vi.fn(),
67+
removeMemory: vi.fn(),
68+
searchMemories: vi.fn(),
69+
searchMemoriesByEmbedding: vi.fn(),
70+
getGoals: vi.fn(),
71+
createGoal: vi.fn(),
72+
updateGoal: vi.fn(),
73+
removeGoal: vi.fn(),
74+
getRoom: vi.fn(),
75+
createRoom: vi.fn(),
76+
removeRoom: vi.fn(),
77+
addParticipant: vi.fn(),
78+
removeParticipant: vi.fn(),
79+
getParticipantsForRoom: vi.fn(),
80+
getParticipantUserState: vi.fn(),
81+
setParticipantUserState: vi.fn(),
82+
createRelationship: vi.fn(),
83+
getRelationship: vi.fn(),
84+
getRelationships: vi.fn(),
85+
getKnowledge: vi.fn(),
86+
searchKnowledge: vi.fn(),
87+
createKnowledge: vi.fn(),
88+
removeKnowledge: vi.fn(),
89+
clearKnowledge: vi.fn(),
90+
},
91+
token: "test-token",
92+
modelProvider: ModelProviderName.OPENAI,
93+
imageModelProvider: ModelProviderName.OPENAI,
94+
imageVisionModelProvider: ModelProviderName.OPENAI,
95+
providers: [],
96+
actions: [],
97+
evaluators: [],
98+
plugins: [],
4999
character: {
50-
modelProvider: ModelProviderName.OLLAMA,
51-
modelEndpointOverride: null,
100+
modelProvider: ModelProviderName.OPENAI,
101+
name: "Test Character",
102+
username: "test",
103+
bio: ["Test bio"],
104+
lore: ["Test lore"],
105+
messageExamples: [],
106+
postExamples: [],
107+
topics: [],
108+
adjectives: [],
109+
style: {
110+
all: [],
111+
chat: [],
112+
post: []
113+
},
114+
clients: [],
115+
plugins: [],
116+
},
117+
getModelProvider: () => ({
118+
apiKey: "test-key",
119+
endpoint: "test-endpoint",
120+
provider: ModelProviderName.OPENAI,
121+
models: {
122+
default: { name: "test-model", maxInputTokens: 4096, maxOutputTokens: 4096, stop: [], temperature: 0.7 },
123+
},
124+
}),
125+
getSetting: (key: string) => {
126+
const settings = {
127+
USE_OPENAI_EMBEDDING: "false",
128+
USE_OLLAMA_EMBEDDING: "false",
129+
USE_GAIANET_EMBEDDING: "false",
130+
OPENAI_API_KEY: "mock-openai-key",
131+
OPENAI_API_URL: "https://api.openai.com/v1",
132+
GAIANET_API_KEY: "mock-gaianet-key",
133+
OLLAMA_EMBEDDING_MODEL: "mxbai-embed-large",
134+
GAIANET_EMBEDDING_MODEL: "nomic-embed",
135+
};
136+
return settings[key as keyof typeof settings] || "";
137+
},
138+
knowledgeManager: {
139+
init: vi.fn(),
140+
close: vi.fn(),
141+
addKnowledge: vi.fn(),
142+
removeKnowledge: vi.fn(),
143+
searchKnowledge: vi.fn(),
144+
clearKnowledge: vi.fn(),
145+
},
146+
memoryManager: {
147+
init: vi.fn(),
148+
close: vi.fn(),
149+
addMemory: vi.fn(),
150+
removeMemory: vi.fn(),
151+
searchMemories: vi.fn(),
152+
searchMemoriesByEmbedding: vi.fn(),
153+
clearMemories: vi.fn(),
154+
},
155+
goalManager: {
156+
init: vi.fn(),
157+
close: vi.fn(),
158+
addGoal: vi.fn(),
159+
updateGoal: vi.fn(),
160+
removeGoal: vi.fn(),
161+
getGoals: vi.fn(),
162+
clearGoals: vi.fn(),
163+
},
164+
relationshipManager: {
165+
init: vi.fn(),
166+
close: vi.fn(),
167+
addRelationship: vi.fn(),
168+
getRelationship: vi.fn(),
169+
getRelationships: vi.fn(),
52170
},
53-
token: "mock-token",
171+
cacheManager: {
172+
get: vi.fn(),
173+
set: vi.fn(),
174+
delete: vi.fn(),
175+
},
176+
services: new Map(),
177+
clients: {},
54178
messageManager: {
55-
getCachedEmbeddings: vi.fn().mockResolvedValue([]),
179+
runtime: {} as IAgentRuntime,
180+
tableName: "messages",
181+
addEmbeddingToMemory: vi.fn(),
182+
getMemories: vi.fn(),
183+
getCachedEmbeddings: vi.fn(),
184+
getMemoryById: vi.fn(),
185+
getMemoriesByRoomIds: vi.fn(),
186+
searchMemoriesByEmbedding: vi.fn(),
187+
createMemory: vi.fn(),
188+
removeMemory: vi.fn(),
189+
removeAllMemories: vi.fn(),
190+
countMemories: vi.fn(),
191+
},
192+
descriptionManager: {
193+
runtime: {} as IAgentRuntime,
194+
tableName: "descriptions",
195+
addEmbeddingToMemory: vi.fn(),
196+
getMemories: vi.fn(),
197+
getCachedEmbeddings: vi.fn(),
198+
getMemoryById: vi.fn(),
199+
getMemoriesByRoomIds: vi.fn(),
200+
searchMemoriesByEmbedding: vi.fn(),
201+
createMemory: vi.fn(),
202+
removeMemory: vi.fn(),
203+
removeAllMemories: vi.fn(),
204+
countMemories: vi.fn(),
205+
},
206+
documentsManager: {
207+
runtime: {} as IAgentRuntime,
208+
tableName: "documents",
209+
addEmbeddingToMemory: vi.fn(),
210+
getMemories: vi.fn(),
211+
getCachedEmbeddings: vi.fn(),
212+
getMemoryById: vi.fn(),
213+
getMemoriesByRoomIds: vi.fn(),
214+
searchMemoriesByEmbedding: vi.fn(),
215+
createMemory: vi.fn(),
216+
removeMemory: vi.fn(),
217+
removeAllMemories: vi.fn(),
218+
countMemories: vi.fn(),
219+
},
220+
loreManager: {
221+
runtime: {} as IAgentRuntime,
222+
tableName: "lore",
223+
addEmbeddingToMemory: vi.fn(),
224+
getMemories: vi.fn(),
225+
getCachedEmbeddings: vi.fn(),
226+
getMemoryById: vi.fn(),
227+
getMemoriesByRoomIds: vi.fn(),
228+
searchMemoriesByEmbedding: vi.fn(),
229+
createMemory: vi.fn(),
230+
removeMemory: vi.fn(),
231+
removeAllMemories: vi.fn(),
232+
countMemories: vi.fn(),
56233
},
234+
ragKnowledgeManager: {
235+
runtime: {} as IAgentRuntime,
236+
tableName: "rag_knowledge",
237+
getKnowledge: vi.fn(),
238+
createKnowledge: vi.fn(),
239+
removeKnowledge: vi.fn(),
240+
searchKnowledge: vi.fn(),
241+
clearKnowledge: vi.fn(),
242+
processFile: vi.fn(),
243+
cleanupDeletedKnowledgeFiles: vi.fn(),
244+
generateScopedId: vi.fn(),
245+
},
246+
initialize: vi.fn(),
247+
registerMemoryManager: vi.fn(),
248+
getMemoryManager: vi.fn(),
249+
getService: vi.fn(),
250+
registerService: vi.fn(),
251+
composeState: vi.fn(),
252+
processActions: vi.fn(),
253+
evaluate: vi.fn(),
254+
ensureParticipantExists: vi.fn(),
255+
ensureUserExists: vi.fn(),
256+
ensureConnection: vi.fn(),
257+
ensureParticipantInRoom: vi.fn(),
258+
ensureRoomExists: vi.fn(),
259+
updateRecentMessageState: vi.fn(),
260+
getConversationLength: vi.fn(),
261+
registerAction: vi.fn(),
57262
} as unknown as IAgentRuntime;
58263

59-
vi.clearAllMocks();
264+
// Reset fetch mock
60265
mockFetch.mockReset();
266+
mockFetch.mockResolvedValue({
267+
ok: true,
268+
json: async () => ({
269+
data: [new Array(384).fill(0.1)],
270+
}),
271+
});
61272
});
62273

63274
describe("getEmbeddingConfig", () => {
@@ -67,25 +278,24 @@ describe("Embedding Module", () => {
67278
expect(config.model).toBe("BGE-small-en-v1.5");
68279
expect(config.provider).toBe("BGE");
69280
});
70-
71-
test("should return OpenAI config when USE_OPENAI_EMBEDDING is true", () => {
72-
vi.mocked(settings).USE_OPENAI_EMBEDDING = "true";
73-
const config = getEmbeddingConfig();
74-
expect(config.dimensions).toBe(1536);
75-
expect(config.model).toBe("text-embedding-3-small");
76-
expect(config.provider).toBe("OpenAI");
77-
});
78281
});
79282

80283
describe("getEmbeddingType", () => {
81-
test("should return 'remote' for Ollama provider", () => {
284+
test("should return 'local' by default", () => {
82285
const type = getEmbeddingType(mockRuntime);
83-
expect(type).toBe("remote");
286+
expect(type).toBe("local");
84287
});
85288

86-
test("should return 'remote' for OpenAI provider", () => {
87-
mockRuntime.character.modelProvider = ModelProviderName.OPENAI;
88-
const type = getEmbeddingType(mockRuntime);
289+
test("should return 'remote' when using OpenAI", () => {
290+
const runtimeWithOpenAI = {
291+
...mockRuntime,
292+
getSetting: (key: string) => {
293+
if (key === "USE_OPENAI_EMBEDDING") return "true";
294+
return mockRuntime.getSetting(key);
295+
},
296+
} as IAgentRuntime;
297+
298+
const type = getEmbeddingType(runtimeWithOpenAI);
89299
expect(type).toBe("remote");
90300
});
91301
});

0 commit comments

Comments
 (0)