Skip to content

Commit 9123996

Browse files
Merge pull request #472 from tarrencev/main
feat: Improve knowledge embeddings
2 parents 8450877 + 0d75334 commit 9123996

File tree

10 files changed

+201
-137
lines changed

10 files changed

+201
-137
lines changed

packages/client-discord/src/actions/summarize_conversation.ts

+6-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,12 @@ const summarizeAction = {
251251
const model = models[runtime.character.settings.model];
252252
const chunkSize = model.settings.maxContextLength - 1000;
253253

254-
const chunks = await splitChunks(formattedMemories, chunkSize, 0);
254+
const chunks = await splitChunks(
255+
formattedMemories,
256+
chunkSize,
257+
"gpt-4o-mini",
258+
0
259+
);
255260

256261
const datestr = new Date().toUTCString().replace(/:/g, "-");
257262

packages/client-discord/src/messages.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -430,13 +430,13 @@ export class MessageManager {
430430
await this.runtime.messageManager.createMemory(memory);
431431
}
432432

433-
let state = (await this.runtime.composeState(userMessage, {
433+
let state = await this.runtime.composeState(userMessage, {
434434
discordClient: this.client,
435435
discordMessage: message,
436436
agentName:
437437
this.runtime.character.name ||
438438
this.client.user?.displayName,
439-
})) as State;
439+
});
440440

441441
if (!canSendMessage(message.channel).canSend) {
442442
return elizaLogger.warn(
@@ -649,6 +649,7 @@ export class MessageManager {
649649
message: DiscordMessage
650650
): Promise<{ processedContent: string; attachments: Media[] }> {
651651
let processedContent = message.content;
652+
652653
let attachments: Media[] = [];
653654

654655
// Process code blocks in the message content

packages/client-github/src/index.ts

+3-42
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,8 @@ import {
1010
AgentRuntime,
1111
Client,
1212
IAgentRuntime,
13-
Content,
14-
Memory,
13+
knowledge,
1514
stringToUuid,
16-
embeddingZeroVector,
17-
splitChunks,
18-
embed,
1915
} from "@ai16z/eliza";
2016
import { validateGithubConfig } from "./enviroment";
2117

@@ -112,11 +108,8 @@ export class GitHubClient {
112108
relativePath
113109
);
114110

115-
const memory: Memory = {
111+
await knowledge.set(this.runtime, {
116112
id: knowledgeId,
117-
agentId: this.runtime.agentId,
118-
userId: this.runtime.agentId,
119-
roomId: this.runtime.agentId,
120113
content: {
121114
text: content,
122115
hash: contentHash,
@@ -128,39 +121,7 @@ export class GitHubClient {
128121
owner: this.config.owner,
129122
},
130123
},
131-
embedding: embeddingZeroVector,
132-
};
133-
134-
await this.runtime.documentsManager.createMemory(memory);
135-
136-
// Only split if content exceeds 4000 characters
137-
const fragments =
138-
content.length > 4000
139-
? await splitChunks(content, 2000, 200)
140-
: [content];
141-
142-
for (const fragment of fragments) {
143-
// Skip empty fragments
144-
if (!fragment.trim()) continue;
145-
146-
// Add file path context to the fragment before embedding
147-
const fragmentWithPath = `File: ${relativePath}\n\n${fragment}`;
148-
const embedding = await embed(this.runtime, fragmentWithPath);
149-
150-
await this.runtime.knowledgeManager.createMemory({
151-
// We namespace the knowledge base uuid to avoid id
152-
// collision with the document above.
153-
id: stringToUuid(knowledgeId + fragment),
154-
roomId: this.runtime.agentId,
155-
agentId: this.runtime.agentId,
156-
userId: this.runtime.agentId,
157-
content: {
158-
source: knowledgeId,
159-
text: fragment,
160-
},
161-
embedding,
162-
});
163-
}
124+
});
164125
}
165126
}
166127

packages/core/src/generation.ts

+15-11
Original file line numberDiff line numberDiff line change
@@ -463,34 +463,38 @@ export async function generateShouldRespond({
463463
* Splits content into chunks of specified size with optional overlapping bleed sections
464464
* @param content - The text content to split into chunks
465465
* @param chunkSize - The maximum size of each chunk in tokens
466-
* @param bleed - Number of characters to overlap between chunks (default: 100)
467466
* @param model - The model name to use for tokenization (default: runtime.model)
467+
* @param bleed - Number of characters to overlap between chunks (default: 100)
468468
* @returns Promise resolving to array of text chunks with bleed sections
469469
*/
470470
export async function splitChunks(
471471
content: string,
472472
chunkSize: number,
473+
model: string,
473474
bleed: number = 100
474475
): Promise<string[]> {
475-
const encoding = encoding_for_model("gpt-4o-mini");
476-
476+
const encoding = encoding_for_model(model as TiktokenModel);
477477
const tokens = encoding.encode(content);
478478
const chunks: string[] = [];
479479
const textDecoder = new TextDecoder();
480480

481481
for (let i = 0; i < tokens.length; i += chunkSize) {
482-
const chunk = tokens.slice(i, i + chunkSize);
483-
const decodedChunk = textDecoder.decode(encoding.decode(chunk));
482+
let chunk = tokens.slice(i, i + chunkSize);
484483

485484
// Append bleed characters from the previous chunk
486-
const startBleed = i > 0 ? content.slice(i - bleed, i) : "";
485+
if (i > 0) {
486+
chunk = new Uint32Array([...tokens.slice(i - bleed, i), ...chunk]);
487+
}
488+
487489
// Append bleed characters from the next chunk
488-
const endBleed =
489-
i + chunkSize < tokens.length
490-
? content.slice(i + chunkSize, i + chunkSize + bleed)
491-
: "";
490+
if (i + chunkSize < tokens.length) {
491+
chunk = new Uint32Array([
492+
...chunk,
493+
...tokens.slice(i + chunkSize, i + chunkSize + bleed),
494+
]);
495+
}
492496

493-
chunks.push(startBleed + decodedChunk + endBleed);
497+
chunks.push(textDecoder.decode(encoding.decode(chunk)));
494498
}
495499

496500
return chunks;

packages/core/src/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ export * from "./parsing.ts";
2020
export * from "./uuid.ts";
2121
export * from "./enviroment.ts";
2222
export * from "./cache.ts";
23+
export { default as knowledge } from "./knowledge.ts";

packages/core/src/knowledge.ts

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import { UUID } from "crypto";
2+
3+
import { AgentRuntime } from "./runtime.ts";
4+
import { embed } from "./embedding.ts";
5+
import { Content, ModelClass, type Memory } from "./types.ts";
6+
import { stringToUuid } from "./uuid.ts";
7+
import { embeddingZeroVector } from "./memory.ts";
8+
import { splitChunks } from "./generation.ts";
9+
import { models } from "./models.ts";
10+
import elizaLogger from "./logger.ts";
11+
12+
async function get(runtime: AgentRuntime, message: Memory): Promise<string[]> {
13+
const processed = preprocess(message.content.text);
14+
elizaLogger.log(`Querying knowledge for: ${processed}`);
15+
const embedding = await embed(runtime, processed);
16+
const fragments = await runtime.knowledgeManager.searchMemoriesByEmbedding(
17+
embedding,
18+
{
19+
roomId: message.agentId,
20+
agentId: message.agentId,
21+
count: 3,
22+
match_threshold: 0.1,
23+
}
24+
);
25+
26+
const uniqueSources = [
27+
...new Set(
28+
fragments.map((memory) => {
29+
elizaLogger.log(
30+
`Matched fragment: ${memory.content.text} with similarity: ${message.similarity}`
31+
);
32+
return memory.content.source;
33+
})
34+
),
35+
];
36+
37+
const knowledgeDocuments = await Promise.all(
38+
uniqueSources.map((source) =>
39+
runtime.documentsManager.getMemoryById(source as UUID)
40+
)
41+
);
42+
43+
const knowledge = knowledgeDocuments
44+
.filter((memory) => memory !== null)
45+
.map((memory) => memory.content.text);
46+
return knowledge;
47+
}
48+
49+
/**
 * A single unit of knowledge to index: a stable document id plus the
 * content (text and arbitrary metadata) to store and embed via set().
 */
export type KnowledgeItem = {
    id: UUID;
    content: Content;
};
53+
54+
async function set(runtime: AgentRuntime, item: KnowledgeItem) {
55+
await runtime.documentsManager.createMemory({
56+
embedding: embeddingZeroVector,
57+
id: item.id,
58+
agentId: runtime.agentId,
59+
roomId: runtime.agentId,
60+
userId: runtime.agentId,
61+
createdAt: Date.now(),
62+
content: item.content,
63+
});
64+
65+
const preprocessed = preprocess(item.content.text);
66+
const fragments = await splitChunks(
67+
preprocessed,
68+
10,
69+
models[runtime.character.modelProvider].model?.[ModelClass.EMBEDDING],
70+
5
71+
);
72+
73+
for (const fragment of fragments) {
74+
const embedding = await embed(runtime, fragment);
75+
await runtime.knowledgeManager.createMemory({
76+
// We namespace the knowledge base uuid to avoid id
77+
// collision with the document above.
78+
id: stringToUuid(item.id + fragment),
79+
roomId: runtime.agentId,
80+
agentId: runtime.agentId,
81+
userId: runtime.agentId,
82+
createdAt: Date.now(),
83+
content: {
84+
source: item.id,
85+
text: fragment,
86+
},
87+
embedding,
88+
});
89+
}
90+
}
91+
92+
export function preprocess(content: string): string {
93+
return (
94+
content
95+
// Remove code blocks and their content
96+
.replace(/```[\s\S]*?```/g, "")
97+
// Remove inline code
98+
.replace(/`.*?`/g, "")
99+
// Convert headers to plain text with emphasis
100+
.replace(/#{1,6}\s*(.*)/g, "$1")
101+
// Remove image links but keep alt text
102+
.replace(/!\[(.*?)\]\(.*?\)/g, "$1")
103+
// Remove links but keep text
104+
.replace(/\[(.*?)\]\(.*?\)/g, "$1")
105+
// Remove HTML tags
106+
.replace(/<[^>]*>/g, "")
107+
// Remove horizontal rules
108+
.replace(/^\s*[-*_]{3,}\s*$/gm, "")
109+
// Remove comments
110+
.replace(/\/\*[\s\S]*?\*\//g, "")
111+
.replace(/\/\/.*/g, "")
112+
// Normalize whitespace
113+
.replace(/\s+/g, " ")
114+
// Remove multiple newlines
115+
.replace(/\n{3,}/g, "\n\n")
116+
// strip all special characters
117+
.replace(/[^a-zA-Z0-9\s]/g, "")
118+
// Remove Discord mentions
119+
.replace(/<@!?\d+>/g, "")
120+
.trim()
121+
.toLowerCase()
122+
);
123+
}
124+
125+
export default {
126+
get,
127+
set,
128+
process,
129+
};

0 commit comments

Comments
 (0)