Skip to content

Commit 40c7e86

Browse files
Merge pull request #496 from ai16z/feat/embeddings
fix: improve embeddings
2 parents 5b920ec + abd2c43 commit 40c7e86

File tree

13 files changed

+458
-134
lines changed

13 files changed

+458
-134
lines changed

.github/workflows/pre-release.yml

+75
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,75 @@
1+
name: Release
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
workflow_dispatch:
8+
inputs:
9+
release_type:
10+
description: "Type of release (prerelease, prepatch, patch, minor, preminor, major)"
11+
required: true
12+
default: "prerelease"
13+
14+
jobs:
15+
release:
16+
runs-on: ubuntu-latest
17+
steps:
18+
- uses: actions/checkout@v4
19+
with:
20+
fetch-depth: 0
21+
22+
- uses: pnpm/action-setup@v3
23+
with:
24+
version: 8
25+
26+
- name: Configure Git
27+
run: |
28+
git config user.name "${{ github.actor }}"
29+
git config user.email "${{ github.actor }}@users.noreply.github.com"
30+
31+
- name: "Setup npm for npmjs"
32+
run: |
33+
npm config set registry https://registry.npmjs.org/
34+
echo "//registry.npmjs.org/:_authToken=${{ secrets.NPM_TOKEN }}" > ~/.npmrc
35+
36+
- name: Install Protobuf Compiler
37+
run: sudo apt-get install -y protobuf-compiler
38+
39+
- name: Install dependencies
40+
run: pnpm install
41+
42+
- name: Build packages
43+
run: pnpm run build
44+
45+
- name: Tag and Publish Packages
46+
id: tag_publish
47+
run: |
48+
RELEASE_TYPE=${{ github.event_name == 'push' && 'prerelease' || github.event.inputs.release_type }}
49+
npx lerna version $RELEASE_TYPE --conventional-commits --yes --no-private --force-publish
50+
npx lerna publish from-git --yes --dist-tag next
51+
52+
- name: Get Version Tag
53+
id: get_tag
54+
run: echo "TAG=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT
55+
56+
- name: Generate Release Body
57+
id: release_body
58+
run: |
59+
if [ -f CHANGELOG.md ]; then
60+
echo "body=$(cat CHANGELOG.md)" >> $GITHUB_OUTPUT
61+
else
62+
echo "body=No changelog provided for this release." >> $GITHUB_OUTPUT
63+
fi
64+
65+
- name: Create GitHub Release
66+
uses: actions/create-release@v1
67+
env:
68+
GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
69+
PNPM_HOME: /home/runner/setup-pnpm/node_modules/.bin
70+
with:
71+
tag_name: ${{ steps.get_tag.outputs.TAG }}
72+
release_name: Release
73+
body_path: CHANGELOG.md
74+
draft: false
75+
prerelease: ${{ github.event_name == 'push' }}

packages/adapter-postgres/src/index.ts

+49-10
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ import {
1111
type IDatabaseCacheAdapter,
1212
Participant,
1313
DatabaseAdapter,
14+
elizaLogger,
1415
} from "@ai16z/eliza";
1516
import fs from "fs";
1617
import { fileURLToPath } from "url";
@@ -28,15 +29,50 @@ export class PostgresDatabaseAdapter
2829
constructor(connectionConfig: any) {
2930
super();
3031

31-
this.pool = new pg.Pool({
32-
...connectionConfig,
32+
const defaultConfig = {
3333
max: 20,
3434
idleTimeoutMillis: 30000,
3535
connectionTimeoutMillis: 2000,
36+
};
37+
38+
this.pool = new pg.Pool({
39+
...defaultConfig,
40+
...connectionConfig, // Allow overriding defaults
3641
});
3742

38-
this.pool.on("error", (err) => {
39-
console.error("Unexpected error on idle client", err);
43+
this.pool.on("error", async (err) => {
44+
elizaLogger.error("Unexpected error on idle client", err);
45+
46+
// Attempt to reconnect with exponential backoff
47+
let retryCount = 0;
48+
const maxRetries = 5;
49+
const baseDelay = 1000; // Start with 1 second delay
50+
51+
while (retryCount < maxRetries) {
52+
try {
53+
const delay = baseDelay * Math.pow(2, retryCount);
54+
elizaLogger.log(`Attempting to reconnect in ${delay}ms...`);
55+
await new Promise((resolve) => setTimeout(resolve, delay));
56+
57+
// Create new pool with same config
58+
this.pool = new pg.Pool(this.pool.options);
59+
await this.testConnection();
60+
61+
elizaLogger.log("Successfully reconnected to database");
62+
return;
63+
} catch (error) {
64+
retryCount++;
65+
elizaLogger.error(
66+
`Reconnection attempt ${retryCount} failed:`,
67+
error
68+
);
69+
}
70+
}
71+
72+
elizaLogger.error(
73+
`Failed to reconnect after ${maxRetries} attempts`
74+
);
75+
throw new Error("Database connection lost and unable to reconnect");
4076
});
4177
}
4278

@@ -51,7 +87,7 @@ export class PostgresDatabaseAdapter
5187
);
5288
await client.query(schema);
5389
} catch (error) {
54-
console.error(error);
90+
elizaLogger.error(error);
5591
throw error;
5692
}
5793
}
@@ -61,10 +97,13 @@ export class PostgresDatabaseAdapter
6197
try {
6298
client = await this.pool.connect();
6399
const result = await client.query("SELECT NOW()");
64-
console.log("Database connection test successful:", result.rows[0]);
100+
elizaLogger.log(
101+
"Database connection test successful:",
102+
result.rows[0]
103+
);
65104
return true;
66105
} catch (error) {
67-
console.error("Database connection test failed:", error);
106+
elizaLogger.error("Database connection test failed:", error);
68107
throw new Error(`Failed to connect to database: ${error.message}`);
69108
} finally {
70109
if (client) client.release();
@@ -187,7 +226,7 @@ export class PostgresDatabaseAdapter
187226
if (rows.length === 0) return null;
188227

189228
const account = rows[0];
190-
console.log("account", account);
229+
elizaLogger.log("account", account);
191230
return {
192231
...account,
193232
details:
@@ -217,7 +256,7 @@ export class PostgresDatabaseAdapter
217256
);
218257
return true;
219258
} catch (error) {
220-
console.log("Error creating account", error);
259+
elizaLogger.log("Error creating account", error);
221260
return false;
222261
} finally {
223262
client.release();
@@ -370,7 +409,7 @@ export class PostgresDatabaseAdapter
370409
values.push(params.count);
371410
}
372411

373-
console.log("sql", sql, values);
412+
elizaLogger.log("sql", sql, values);
374413

375414
const { rows } = await client.query(sql, values);
376415
return rows.map((row) => ({

packages/client-discord/src/actions/summarize_conversation.ts

+1-6
Original file line number | Diff line number | Diff line change
@@ -251,12 +251,7 @@ const summarizeAction = {
251251
const model = models[runtime.character.settings.model];
252252
const chunkSize = model.settings.maxContextLength - 1000;
253253

254-
const chunks = await splitChunks(
255-
formattedMemories,
256-
chunkSize,
257-
"gpt-4o-mini",
258-
0
259-
);
254+
const chunks = await splitChunks(formattedMemories, chunkSize, 0);
260255

261256
const datestr = new Date().toUTCString().replace(/:/g, "-");
262257

packages/core/package.json

+1
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,7 @@
7070
"gaxios": "6.7.1",
7171
"glob": "11.0.0",
7272
"js-sha1": "0.7.0",
73+
"langchain": "^0.3.6",
7374
"ollama-ai-provider": "^0.16.1",
7475
"openai": "4.69.0",
7576
"tiktoken": "1.0.17",

packages/core/src/embedding.ts

+30-14
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ async function getRemoteEmbedding(
3737
: {}),
3838
},
3939
body: JSON.stringify({
40-
input,
40+
input: trimTokens(input, 8191, "gpt-4o-mini"),
4141
model: options.model,
4242
length: options.length || 384,
4343
}),
@@ -70,25 +70,39 @@ async function getRemoteEmbedding(
7070
* @param input The input to be embedded.
7171
* @returns The embedding of the input.
7272
*/
73+
/**
74+
* Generate embeddings for input text using configured model provider
75+
* @param runtime The agent runtime containing model configuration
76+
* @param input The text to generate embeddings for
77+
* @returns Array of embedding numbers
78+
*/
7379
export async function embed(runtime: IAgentRuntime, input: string) {
80+
// Get model provider configuration
7481
const modelProvider = models[runtime.character.modelProvider];
75-
//need to have env override for this to select what to use for embedding if provider doesnt provide or using openai
82+
83+
// Determine which embedding model to use:
84+
// 1. OpenAI if USE_OPENAI_EMBEDDING is true
85+
// 2. Provider's own embedding model if available
86+
// 3. Fallback to OpenAI embedding model
7687
const embeddingModel = settings.USE_OPENAI_EMBEDDING
77-
? "text-embedding-3-small" // Use OpenAI if specified
78-
: modelProvider.model?.[ModelClass.EMBEDDING] || // Use provider's embedding model if available
79-
models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING]; // Fallback to OpenAI
88+
? "text-embedding-3-small"
89+
: modelProvider.model?.[ModelClass.EMBEDDING] ||
90+
models[ModelProviderName.OPENAI].model[ModelClass.EMBEDDING];
8091

8192
if (!embeddingModel) {
8293
throw new Error("No embedding model configured");
8394
}
8495

85-
// // Try local embedding first
86-
// Check if we're in Node.js environment
96+
// Check if running in Node.js environment
8797
const isNode =
8898
typeof process !== "undefined" &&
8999
process.versions != null &&
90100
process.versions.node != null;
91101

102+
// Use local embedding if:
103+
// - Running in Node.js
104+
// - Not using OpenAI provider
105+
// - Not forcing OpenAI embeddings
92106
if (
93107
isNode &&
94108
runtime.character.modelProvider !== ModelProviderName.OPENAI &&
@@ -97,28 +111,30 @@ export async function embed(runtime: IAgentRuntime, input: string) {
97111
return await getLocalEmbedding(input);
98112
}
99113

100-
// Check cache
114+
// Try to get cached embedding first
101115
const cachedEmbedding = await retrieveCachedEmbedding(runtime, input);
102116
if (cachedEmbedding) {
103117
return cachedEmbedding;
104118
}
105119

106-
// Get remote embedding
120+
// Generate new embedding remotely
107121
return await getRemoteEmbedding(input, {
108122
model: embeddingModel,
123+
// Use OpenAI endpoint if specified, otherwise use provider endpoint
109124
endpoint: settings.USE_OPENAI_EMBEDDING
110-
? "https://api.openai.com/v1" // Always use OpenAI endpoint when USE_OPENAI_EMBEDDING is true
125+
? "https://api.openai.com/v1"
111126
: runtime.character.modelEndpointOverride || modelProvider.endpoint,
127+
// Use OpenAI API key if specified, otherwise use runtime token
112128
apiKey: settings.USE_OPENAI_EMBEDDING
113-
? settings.OPENAI_API_KEY // Use OpenAI key from settings when USE_OPENAI_EMBEDDING is true
114-
: runtime.token, // Use runtime token for other providers
129+
? settings.OPENAI_API_KEY
130+
: runtime.token,
131+
// Special handling for Ollama provider
115132
isOllama:
116133
runtime.character.modelProvider === ModelProviderName.OLLAMA &&
117134
!settings.USE_OPENAI_EMBEDDING,
118135
});
119136
}
120137

121-
// TODO: Add back in when it can work in browser and locally
122138
async function getLocalEmbedding(input: string): Promise<number[]> {
123139
// Check if we're in Node.js environment
124140
const isNode =
@@ -153,7 +169,7 @@ async function getLocalEmbedding(input: string): Promise<number[]> {
153169
cacheDir: cacheDir,
154170
});
155171

156-
const trimmedInput = trimTokens(input, 8000, "gpt-4o-mini");
172+
const trimmedInput = trimTokens(input, 8191, "gpt-4o-mini");
157173
const embedding = await embeddingModel.queryEmbed(trimmedInput);
158174
return embedding;
159175
} else {

0 commit comments

Comments (0)