Skip to content

Commit ae56659

Browse files
authored
Merge pull request #1944 from shlokkhemani/add-embedding-tests
chore: add embedding tests
2 parents 533eb51 + e6a6ff2 commit ae56659

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed
+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
import { describe, test, expect, vi, beforeEach } from "vitest";
2+
import {
3+
embed,
4+
getEmbeddingConfig,
5+
getEmbeddingType,
6+
getEmbeddingZeroVector,
7+
} from "../embedding.ts";
8+
import { IAgentRuntime, ModelProviderName } from "../types.ts";
9+
import settings from "../settings.ts";
10+
11+
// Mock environment-related settings
12+
vi.mock("../settings", () => ({
13+
default: {
14+
USE_OPENAI_EMBEDDING: "false",
15+
USE_OLLAMA_EMBEDDING: "false",
16+
USE_GAIANET_EMBEDDING: "false",
17+
OPENAI_API_KEY: "mock-openai-key",
18+
OPENAI_API_URL: "https://api.openai.com/v1",
19+
GAIANET_API_KEY: "mock-gaianet-key",
20+
OLLAMA_EMBEDDING_MODEL: "mxbai-embed-large",
21+
GAIANET_EMBEDDING_MODEL: "nomic-embed",
22+
},
23+
}));
24+
25+
// Mock fastembed module for local embeddings
26+
vi.mock("fastembed", () => ({
27+
FlagEmbedding: {
28+
init: vi.fn().mockResolvedValue({
29+
queryEmbed: vi
30+
.fn()
31+
.mockResolvedValue(new Float32Array(384).fill(0.1)),
32+
}),
33+
},
34+
EmbeddingModel: {
35+
BGESmallENV15: "BGE-small-en-v1.5",
36+
},
37+
}));
38+
39+
// Mock global fetch for remote embedding requests
40+
const mockFetch = vi.fn();
41+
(global as any).fetch = mockFetch;
42+
43+
describe("Embedding Module", () => {
44+
let mockRuntime: IAgentRuntime;
45+
46+
beforeEach(() => {
47+
// Prepare a mock runtime
48+
mockRuntime = {
49+
character: {
50+
modelProvider: ModelProviderName.OLLAMA,
51+
modelEndpointOverride: null,
52+
},
53+
token: "mock-token",
54+
messageManager: {
55+
getCachedEmbeddings: vi.fn().mockResolvedValue([]),
56+
},
57+
} as unknown as IAgentRuntime;
58+
59+
vi.clearAllMocks();
60+
mockFetch.mockReset();
61+
});
62+
63+
describe("getEmbeddingConfig", () => {
64+
test("should return BGE config by default", () => {
65+
const config = getEmbeddingConfig();
66+
expect(config.dimensions).toBe(384);
67+
expect(config.model).toBe("BGE-small-en-v1.5");
68+
expect(config.provider).toBe("BGE");
69+
});
70+
71+
test("should return OpenAI config when USE_OPENAI_EMBEDDING is true", () => {
72+
vi.mocked(settings).USE_OPENAI_EMBEDDING = "true";
73+
const config = getEmbeddingConfig();
74+
expect(config.dimensions).toBe(1536);
75+
expect(config.model).toBe("text-embedding-3-small");
76+
expect(config.provider).toBe("OpenAI");
77+
});
78+
});
79+
80+
describe("getEmbeddingType", () => {
81+
test("should return 'remote' for Ollama provider", () => {
82+
const type = getEmbeddingType(mockRuntime);
83+
expect(type).toBe("remote");
84+
});
85+
86+
test("should return 'remote' for OpenAI provider", () => {
87+
mockRuntime.character.modelProvider = ModelProviderName.OPENAI;
88+
const type = getEmbeddingType(mockRuntime);
89+
expect(type).toBe("remote");
90+
});
91+
});
92+
93+
describe("getEmbeddingZeroVector", () => {
94+
beforeEach(() => {
95+
vi.mocked(settings).USE_OPENAI_EMBEDDING = "false";
96+
vi.mocked(settings).USE_OLLAMA_EMBEDDING = "false";
97+
vi.mocked(settings).USE_GAIANET_EMBEDDING = "false";
98+
});
99+
100+
test("should return 384-length zero vector by default (BGE)", () => {
101+
const vector = getEmbeddingZeroVector();
102+
expect(vector).toHaveLength(384);
103+
expect(vector.every((val) => val === 0)).toBe(true);
104+
});
105+
106+
test("should return 1536-length zero vector for OpenAI if enabled", () => {
107+
vi.mocked(settings).USE_OPENAI_EMBEDDING = "true";
108+
const vector = getEmbeddingZeroVector();
109+
expect(vector).toHaveLength(1536);
110+
expect(vector.every((val) => val === 0)).toBe(true);
111+
});
112+
});
113+
114+
describe("embed function", () => {
115+
beforeEach(() => {
116+
// Mock a successful remote response with an example 384-dim embedding
117+
mockFetch.mockResolvedValue({
118+
ok: true,
119+
json: () =>
120+
Promise.resolve({
121+
data: [{ embedding: new Array(384).fill(0.1) }],
122+
}),
123+
});
124+
});
125+
126+
test("should return an empty array for empty input text", async () => {
127+
const result = await embed(mockRuntime, "");
128+
expect(result).toEqual([]);
129+
});
130+
131+
test("should return cached embedding if it already exists", async () => {
132+
const cachedEmbedding = new Array(384).fill(0.5);
133+
mockRuntime.messageManager.getCachedEmbeddings = vi
134+
.fn()
135+
.mockResolvedValue([{ embedding: cachedEmbedding }]);
136+
137+
const result = await embed(mockRuntime, "test input");
138+
expect(result).toBe(cachedEmbedding);
139+
});
140+
141+
test("should handle local embedding successfully (fastembed fallback)", async () => {
142+
// By default, it tries local first if in Node.
143+
// Then uses the mock fastembed response above.
144+
const result = await embed(mockRuntime, "test input");
145+
expect(result).toHaveLength(384);
146+
expect(result.every((v) => typeof v === "number")).toBe(true);
147+
});
148+
149+
test("should fallback to remote if local embedding fails", async () => {
150+
// Force fastembed import to fail
151+
vi.mock("fastembed", () => {
152+
throw new Error("Module not found");
153+
});
154+
155+
// Mock a valid remote response
156+
const mockResponse = {
157+
ok: true,
158+
json: () =>
159+
Promise.resolve({
160+
data: [{ embedding: new Array(384).fill(0.1) }],
161+
}),
162+
};
163+
mockFetch.mockResolvedValueOnce(mockResponse);
164+
165+
const result = await embed(mockRuntime, "test input");
166+
expect(result).toHaveLength(384);
167+
expect(mockFetch).toHaveBeenCalled();
168+
});
169+
170+
test("should throw on remote embedding if fetch fails", async () => {
171+
mockFetch.mockRejectedValueOnce(new Error("API Error"));
172+
vi.mocked(settings).USE_OPENAI_EMBEDDING = "true"; // Force remote
173+
174+
await expect(embed(mockRuntime, "test input")).rejects.toThrow(
175+
"API Error"
176+
);
177+
});
178+
179+
test("should throw on non-200 remote response", async () => {
180+
const errorResponse = {
181+
ok: false,
182+
status: 400,
183+
statusText: "Bad Request",
184+
text: () => Promise.resolve("Invalid input"),
185+
};
186+
mockFetch.mockResolvedValueOnce(errorResponse);
187+
vi.mocked(settings).USE_OPENAI_EMBEDDING = "true"; // Force remote
188+
189+
await expect(embed(mockRuntime, "test input")).rejects.toThrow(
190+
"Embedding API Error"
191+
);
192+
});
193+
194+
test("should handle concurrent embedding requests", async () => {
195+
const promises = Array(5)
196+
.fill(null)
197+
.map(() => embed(mockRuntime, "concurrent test"));
198+
await expect(Promise.all(promises)).resolves.toBeDefined();
199+
});
200+
});
201+
});

0 commit comments

Comments
 (0)