
Commit 9a04908

Merge pull request #246 from Oneirocom/google-and-system-prompt

Support google models in generation

2 parents: ce4d327 + c7e9bf0

7 files changed: +39 -40 lines

.env.example (+1)

@@ -5,6 +5,7 @@ OPENAI_API_KEY=sk-* # OpenAI API key, starting with sk-
 REDPILL_API_KEY= # REDPILL API Key
 GROQ_API_KEY=gsk_*
 OPENROUTER_API_KEY=
+GOOGLE_GENERATIVE_AI_API_KEY= # Gemini API key
 
 ELEVENLABS_XI_API_KEY= # API key from elevenlabs
 

README.md (+1)

@@ -83,6 +83,7 @@ DISCORD_APPLICATION_ID=
 DISCORD_API_TOKEN= # Bot token
 OPENAI_API_KEY=sk-* # OpenAI API key, starting with sk-
 ELEVENLABS_XI_API_KEY= # API key from elevenlabs
+GOOGLE_GENERATIVE_AI_API_KEY= # Gemini API key
 
 # ELEVENLABS SETTINGS
 ELEVENLABS_MODEL_ID=eleven_multilingual_v2
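The new variable is consumed by the provider factory added in generation.ts below: createGoogleGenerativeAI() called with no options falls back to reading GOOGLE_GENERATIVE_AI_API_KEY from the environment. A minimal sketch of supplying the key explicitly instead, assuming the standard @ai-sdk/google options object:

import { createGoogleGenerativeAI } from "@ai-sdk/google";

// Explicit form of what the zero-argument call does by default:
// the provider reads GOOGLE_GENERATIVE_AI_API_KEY from the environment.
const google = createGoogleGenerativeAI({
    apiKey: process.env.GOOGLE_GENERATIVE_AI_API_KEY,
});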

packages/core/src/generation.ts (+25 -4)

@@ -11,6 +11,7 @@ import { default as tiktoken, TiktokenModel } from "tiktoken";
 import Together from "together-ai";
 import { elizaLogger } from "./index.ts";
 import models from "./models.ts";
+import { createGoogleGenerativeAI } from "@ai-sdk/google";
 import {
     parseBooleanFromText,
     parseJsonArrayFromText,
@@ -104,6 +105,25 @@ export async function generateText({
             break;
         }
 
+        case ModelProviderName.GOOGLE:
+            const google = createGoogleGenerativeAI();
+
+            const { text: anthropicResponse } = await aiGenerateText({
+                model: google(model),
+                prompt: context,
+                system:
+                    runtime.character.system ??
+                    settings.SYSTEM_PROMPT ??
+                    undefined,
+                temperature: temperature,
+                maxTokens: max_response_length,
+                frequencyPenalty: frequency_penalty,
+                presencePenalty: presence_penalty,
+            });
+
+            response = anthropicResponse;
+            break;
+
         case ModelProviderName.ANTHROPIC: {
             elizaLogger.log("Initializing Anthropic model.");
 
@@ -214,7 +234,6 @@
             break;
         }
 
-
         case ModelProviderName.OPENROUTER: {
             elizaLogger.log("Initializing OpenRouter model.");
             const serverUrl = models[provider].endpoint;
@@ -238,7 +257,6 @@
             break;
         }
 
-
         case ModelProviderName.OLLAMA:
             {
                 console.log("Initializing Ollama model.");
@@ -425,10 +443,13 @@ export async function generateTrueOrFalse({
     modelClass: string;
 }): Promise<boolean> {
     let retryDelay = 1000;
-    console.log("modelClass", modelClass)
+    console.log("modelClass", modelClass);
 
     const stop = Array.from(
-        new Set([...(models[runtime.modelProvider].settings.stop || []), ["\n"]])
+        new Set([
+            ...(models[runtime.modelProvider].settings.stop || []),
+            ["\n"],
+        ])
     ) as string[];
 
     while (true) {
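Note that the reformatted stop list still puts the array literal ["\n"] into the Set alongside the provider's string stop sequences, so the cast to string[] hides a nested array element. If a literal newline stop sequence is intended, the element would presumably be the string "\n"; a sketch of that reading:

// Sketch, assuming "\n" (a string) was intended rather than ["\n"]:
const stop = Array.from(
    new Set([
        ...(models[runtime.modelProvider].settings.stop || []),
        "\n",
    ])
) as string[];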

packages/core/src/models.ts (+4 -5)

@@ -137,9 +137,9 @@ const models: Models = {
             temperature: 0.3,
         },
         model: {
-            [ModelClass.SMALL]: "gemini-1.5-flash",
-            [ModelClass.MEDIUM]: "gemini-1.5-flash",
-            [ModelClass.LARGE]: "gemini-1.5-pro",
+            [ModelClass.SMALL]: "gemini-1.5-flash-latest",
+            [ModelClass.MEDIUM]: "gemini-1.5-flash-latest",
+            [ModelClass.LARGE]: "gemini-1.5-pro-latest",
             [ModelClass.EMBEDDING]: "text-embedding-004",
         },
     },
@@ -187,8 +187,7 @@ const models: Models = {
                 settings.LARGE_OPENROUTER_MODEL ||
                 settings.OPENROUTER_MODEL ||
                 "nousresearch/hermes-3-llama-3.1-405b",
-            [ModelClass.EMBEDDING]:
-                "text-embedding-3-small",
+            [ModelClass.EMBEDDING]: "text-embedding-3-small",
         },
     },
     [ModelProviderName.OLLAMA]: {
packages/core/src/runtime.ts (+3 -3)

@@ -498,14 +498,14 @@ export class AgentRuntime implements IAgentRuntime {
      * @returns The results of the evaluation.
      */
     async evaluate(message: Memory, state?: State, didRespond?: boolean) {
-        console.log("Evaluate: ", didRespond)
+        console.log("Evaluate: ", didRespond);
         const evaluatorPromises = this.evaluators.map(
             async (evaluator: Evaluator) => {
-                console.log("Evaluating", evaluator.name)
+                console.log("Evaluating", evaluator.name);
                 if (!evaluator.handler) {
                     return null;
                 }
-                if(!didRespond && !evaluator.alwaysRun) {
+                if (!didRespond && !evaluator.alwaysRun) {
                     return null;
                 }
                 const result = await evaluator.validate(this, message, state);

packages/core/src/types.ts (+5 -1)

@@ -550,7 +550,11 @@ export interface IAgentRuntime {
         state?: State,
         callback?: HandlerCallback
     ): Promise<void>;
-    evaluate(message: Memory, state?: State, didRespond?: boolean): Promise<string[]>;
+    evaluate(
+        message: Memory,
+        state?: State,
+        didRespond?: boolean
+    ): Promise<string[]>;
     ensureParticipantExists(userId: UUID, roomId: UUID): Promise<void>;
     ensureUserExists(
         userId: UUID,
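With the widened signature, callers can report whether the agent actually responded; per the runtime.ts change above, evaluators without alwaysRun are skipped when didRespond is false. A hypothetical call site, with variable names assumed for illustration:

// Hypothetical call site: didRespond reflects whether a reply was sent;
// evaluators marked alwaysRun execute regardless.
const ranEvaluators = await runtime.evaluate(message, state, didRespond);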

pnpm-lock.yaml (-27)

Generated lockfile; diff not rendered.
