Skip to content

Commit e75b354

Browse files
committed
remove agentId from memory manager params and always use runtime agentId
add agentId in all queries
1 parent 4ace32e commit e75b354

File tree

14 files changed

+57
-86
lines changed

14 files changed

+57
-86
lines changed

packages/adapter-postgres/src/index.ts

+13-19
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ export class PostgresDatabaseAdapter
8282
return true;
8383
} catch (error) {
8484
console.error("Database connection test failed:", error);
85-
throw new Error(`Failed to connect to database: ${error.message}`);
85+
throw new Error(
86+
`Failed to connect to database: ${(error as Error).message}`
87+
);
8688
} finally {
8789
if (client) client.release();
8890
}
@@ -119,22 +121,17 @@ export class PostgresDatabaseAdapter
119121
}
120122

121123
async getMemoriesByRoomIds(params: {
124+
agentId: UUID;
122125
roomIds: UUID[];
123-
agentId?: UUID;
124126
tableName: string;
125127
}): Promise<Memory[]> {
126128
if (params.roomIds.length === 0) return [];
127129
const placeholders = params.roomIds
128-
.map((_, i) => `$${i + 2}`)
130+
.map((_, i) => `$${i + 3}`)
129131
.join(", ");
130132

131-
let query = `SELECT * FROM memories WHERE type = $1 AND "roomId" IN (${placeholders})`;
132-
let queryParams = [params.tableName, ...params.roomIds];
133-
134-
if (params.agentId) {
135-
query += ` AND "agentId" = $${params.roomIds.length + 2}`;
136-
queryParams = [...queryParams, params.agentId];
137-
}
133+
let query = `SELECT * FROM memories WHERE type = $1 AND "agentId" = $2 AND "roomId" IN (${placeholders})`;
134+
let queryParams = [params.tableName, params.agentId, ...params.roomIds];
138135

139136
const { rows } = await this.query(query, queryParams);
140137
return rows.map((row) => ({
@@ -244,6 +241,7 @@ export class PostgresDatabaseAdapter
244241
memory.embedding,
245242
{
246243
tableName,
244+
agentId: memory.agentId,
247245
roomId: memory.roomId,
248246
match_threshold: 0.95,
249247
count: 1,
@@ -272,6 +270,7 @@ export class PostgresDatabaseAdapter
272270

273271
async searchMemories(params: {
274272
tableName: string;
273+
agentId: UUID;
275274
roomId: UUID;
276275
embedding: number[];
277276
match_threshold: number;
@@ -281,6 +280,7 @@ export class PostgresDatabaseAdapter
281280
return await this.searchMemoriesByEmbedding(params.embedding, {
282281
match_threshold: params.match_threshold,
283282
count: params.match_count,
283+
agentId: params.agentId,
284284
roomId: params.roomId,
285285
unique: params.unique,
286286
tableName: params.tableName,
@@ -292,14 +292,14 @@ export class PostgresDatabaseAdapter
292292
count?: number;
293293
unique?: boolean;
294294
tableName: string;
295-
agentId?: UUID;
295+
agentId: UUID;
296296
start?: number;
297297
end?: number;
298298
}): Promise<Memory[]> {
299299
if (!params.tableName) throw new Error("tableName is required");
300300
if (!params.roomId) throw new Error("roomId is required");
301-
let sql = `SELECT * FROM memories WHERE type = $1 AND "roomId" = $2`;
302-
const values: any[] = [params.tableName, params.roomId];
301+
let sql = `SELECT * FROM memories WHERE type = $1 AND agentId = $2 AND "roomId" = $3`;
302+
const values: any[] = [params.tableName, params.agentId, params.roomId];
303303
let paramCount = 2;
304304

305305
if (params.start) {
@@ -318,12 +318,6 @@ export class PostgresDatabaseAdapter
318318
sql += ` AND "unique" = true`;
319319
}
320320

321-
if (params.agentId) {
322-
paramCount++;
323-
sql += ` AND "agentId" = $${paramCount}`;
324-
values.push(params.agentId);
325-
}
326-
327321
sql += ' ORDER BY "createdAt" DESC';
328322

329323
if (params.count) {

packages/adapter-sqlite/src/index.ts

+16-24
Original file line numberDiff line numberDiff line change
@@ -143,22 +143,17 @@ export class SqliteDatabaseAdapter
143143
}
144144

145145
async getMemoriesByRoomIds(params: {
146+
agentId: UUID;
146147
roomIds: UUID[];
147148
tableName: string;
148-
agentId?: UUID;
149149
}): Promise<Memory[]> {
150150
if (!params.tableName) {
151151
// default to messages
152152
params.tableName = "messages";
153153
}
154154
const placeholders = params.roomIds.map(() => "?").join(", ");
155-
let sql = `SELECT * FROM memories WHERE type = ? AND roomId IN (${placeholders})`;
156-
let queryParams = [params.tableName, ...params.roomIds];
157-
158-
if (params.agentId) {
159-
sql += ` AND agentId = ?`;
160-
queryParams.push(params.agentId);
161-
}
155+
let sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId IN (${placeholders})`;
156+
let queryParams = [params.tableName, params.agentId, ...params.roomIds];
162157

163158
const stmt = this.db.prepare(sql);
164159
const rows = stmt.all(...queryParams) as (Memory & {
@@ -189,8 +184,8 @@ export class SqliteDatabaseAdapter
189184

190185
async createMemory(memory: Memory, tableName: string): Promise<void> {
191186
// Delete any existing memory with the same ID first
192-
const deleteSql = `DELETE FROM memories WHERE id = ? AND type = ?`;
193-
this.db.prepare(deleteSql).run(memory.id, tableName);
187+
// const deleteSql = `DELETE FROM memories WHERE id = ? AND type = ?`;
188+
// this.db.prepare(deleteSql).run(memory.id, tableName);
194189

195190
let isUnique = true;
196191

@@ -200,6 +195,7 @@ export class SqliteDatabaseAdapter
200195
memory.embedding,
201196
{
202197
tableName,
198+
agentId: memory.agentId,
203199
roomId: memory.roomId,
204200
match_threshold: 0.95, // 5% similarity threshold
205201
count: 1,
@@ -281,7 +277,7 @@ export class SqliteDatabaseAdapter
281277
match_threshold?: number;
282278
count?: number;
283279
roomId?: UUID;
284-
agentId?: UUID;
280+
agentId: UUID;
285281
unique?: boolean;
286282
tableName: string;
287283
}
@@ -290,20 +286,17 @@ export class SqliteDatabaseAdapter
290286
// JSON.stringify(embedding),
291287
new Float32Array(embedding),
292288
params.tableName,
289+
params.agentId,
293290
];
294291

295292
let sql = `
296293
SELECT *, vec_distance_L2(embedding, ?) AS similarity
297294
FROM memories
298-
WHERE type = ?`;
295+
WHERE embedding IS NOT NULL type = ? AND agentId = ?`;
299296

300297
if (params.unique) {
301298
sql += " AND `unique` = 1";
302299
}
303-
if (params.agentId) {
304-
sql += " AND agentId = ?";
305-
queryParams.push(params.agentId);
306-
}
307300

308301
if (params.roomId) {
309302
sql += " AND roomId = ?";
@@ -418,7 +411,7 @@ export class SqliteDatabaseAdapter
418411
count?: number;
419412
unique?: boolean;
420413
tableName: string;
421-
agentId?: UUID;
414+
agentId: UUID;
422415
start?: number;
423416
end?: number;
424417
}): Promise<Memory[]> {
@@ -428,19 +421,18 @@ export class SqliteDatabaseAdapter
428421
if (!params.roomId) {
429422
throw new Error("roomId is required");
430423
}
431-
let sql = `SELECT * FROM memories WHERE type = ? AND roomId = ?`;
424+
let sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId = ?`;
432425

433-
const queryParams = [params.tableName, params.roomId] as any[];
426+
const queryParams = [
427+
params.tableName,
428+
params.agentId,
429+
params.roomId,
430+
] as any[];
434431

435432
if (params.unique) {
436433
sql += " AND `unique` = 1";
437434
}
438435

439-
if (params.agentId) {
440-
sql += " AND agentId = ?";
441-
queryParams.push(params.agentId);
442-
}
443-
444436
if (params.start) {
445437
sql += ` AND createdAt >= ?`;
446438
queryParams.push(params.start);

packages/client-discord/src/actions/summarize_conversation.ts

-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ const summarizeAction = {
220220
// 2. get these memories from the database
221221
const memories = await runtime.messageManager.getMemories({
222222
roomId,
223-
agentId: runtime.agentId,
224223
// subtract start from current time
225224
start: parseInt(start as string),
226225
end: parseInt(end as string),

packages/client-twitter/src/base.ts

-3
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,6 @@ export class ClientBase extends EventEmitter {
319319
// Get the existing memories from the database
320320
const existingMemories =
321321
await this.runtime.messageManager.getMemoriesByRoomIds({
322-
agentId: this.runtime.agentId,
323322
roomIds: cachedTimeline.map((tweet) =>
324323
stringToUuid(
325324
tweet.conversationId + "-" + this.runtime.agentId
@@ -462,7 +461,6 @@ export class ClientBase extends EventEmitter {
462461
// Check the existing memories in the database
463462
const existingMemories =
464463
await this.runtime.messageManager.getMemoriesByRoomIds({
465-
agentId: this.runtime.agentId,
466464
roomIds: Array.from(roomIds),
467465
});
468466

@@ -564,7 +562,6 @@ export class ClientBase extends EventEmitter {
564562
const recentMessage = await this.runtime.messageManager.getMemories(
565563
{
566564
roomId: message.roomId,
567-
agentId: this.runtime.agentId,
568565
count: 1,
569566
unique: false,
570567
}

packages/client-twitter/src/utils.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ export async function buildConversationThread(
7272
"twitter"
7373
);
7474

75-
client.runtime.messageManager.createMemory({
75+
await client.runtime.messageManager.createMemory({
7676
id: stringToUuid(
7777
currentTweet.id + "-" + client.runtime.agentId
7878
),

packages/core/src/database.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,15 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
3939
* @returns A Promise that resolves to an array of Memory objects.
4040
*/
4141
abstract getMemories(params: {
42+
agentId: UUID;
4243
roomId: UUID;
4344
count?: number;
4445
unique?: boolean;
4546
tableName: string;
4647
}): Promise<Memory[]>;
4748

4849
abstract getMemoriesByRoomIds(params: {
49-
agentId?: UUID;
50+
agentId: UUID;
5051
roomIds: UUID[];
5152
tableName: string;
5253
}): Promise<Memory[]>;
@@ -105,6 +106,7 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
105106
*/
106107
abstract searchMemories(params: {
107108
tableName: string;
109+
agentId: UUID;
108110
roomId: UUID;
109111
embedding: number[];
110112
match_threshold: number;
@@ -188,6 +190,7 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
188190
* @returns A Promise that resolves to an array of Goal objects.
189191
*/
190192
abstract getGoals(params: {
193+
agentId: UUID;
191194
roomId: UUID;
192195
userId?: UUID | null;
193196
onlyInProgress?: boolean;

packages/core/src/goals.ts

+3
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,22 @@ import {
66
} from "./types.ts";
77

88
export const getGoals = async ({
9+
agentId,
910
runtime,
1011
roomId,
1112
userId,
1213
onlyInProgress = true,
1314
count = 5,
1415
}: {
1516
runtime: IAgentRuntime;
17+
agentId: UUID;
1618
roomId: UUID;
1719
userId?: UUID;
1820
onlyInProgress?: boolean;
1921
count?: number;
2022
}) => {
2123
return runtime.databaseAdapter.getGoals({
24+
agentId,
2225
roomId,
2326
userId,
2427
onlyInProgress,

0 commit comments

Comments
 (0)