Skip to content

Commit eedf278

Browse files
authored
Merge pull request #2137 from GravitonINC/mistral-text-generation
feat: Add Mistral AI as new model provider
2 parents 5bade12 + a50c24f commit eedf278

File tree

8 files changed

+126
-7
lines changed

8 files changed

+126
-7
lines changed

.env.example

+6
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,12 @@ MEDIUM_GOOGLE_MODEL= # Default: gemini-1.5-flash-latest
184184
LARGE_GOOGLE_MODEL= # Default: gemini-1.5-pro-latest
185185
EMBEDDING_GOOGLE_MODEL= # Default: text-embedding-004
186186

187+
# Mistral Configuration
188+
MISTRAL_MODEL=
189+
SMALL_MISTRAL_MODEL= # Default: mistral-small-latest
190+
MEDIUM_MISTRAL_MODEL= # Default: mistral-large-latest
191+
LARGE_MISTRAL_MODEL= # Default: mistral-large-latest
192+
187193
# Groq Configuration
188194
GROQ_API_KEY= # Starts with gsk_
189195
SMALL_GROQ_MODEL= # Default: llama-3.1-8b-instant

agent/src/index.ts

+5
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,11 @@ export function getTokenForProvider(
398398
character.settings?.secrets?.GOOGLE_GENERATIVE_AI_API_KEY ||
399399
settings.GOOGLE_GENERATIVE_AI_API_KEY
400400
);
401+
case ModelProviderName.MISTRAL:
402+
return (
403+
character.settings?.secrets?.MISTRAL_API_KEY ||
404+
settings.MISTRAL_API_KEY
405+
);
401406
case ModelProviderName.LETZAI:
402407
return (
403408
character.settings?.secrets?.LETZAI_API_KEY ||

docs/docs/advanced/fine-tuning.md

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ enum ModelProviderName {
2222
LLAMACLOUD,
2323
LLAMALOCAL,
2424
GOOGLE,
25+
MISTRAL,
2526
REDPILL,
2627
OPENROUTER,
2728
HEURIST,

packages/core/package.json

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"@ai-sdk/google": "0.0.55",
7070
"@ai-sdk/google-vertex": "0.0.43",
7171
"@ai-sdk/groq": "0.0.3",
72+
"@ai-sdk/mistral": "^1.0.8",
7273
"@ai-sdk/openai": "1.0.5",
7374
"@anthropic-ai/sdk": "0.30.1",
7475
"@fal-ai/client": "1.2.0",

packages/core/src/generation.ts

+49
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { createAnthropic } from "@ai-sdk/anthropic";
22
import { createGoogleGenerativeAI } from "@ai-sdk/google";
3+
import { createMistral } from "@ai-sdk/mistral";
34
import { createGroq } from "@ai-sdk/groq";
45
import { createOpenAI } from "@ai-sdk/openai";
56
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
@@ -528,6 +529,27 @@ export async function generateText({
528529
break;
529530
}
530531

532+
case ModelProviderName.MISTRAL: {
533+
const mistral = createMistral();
534+
535+
const { text: mistralResponse } = await aiGenerateText({
536+
model: mistral(model),
537+
prompt: context,
538+
system:
539+
runtime.character.system ??
540+
settings.SYSTEM_PROMPT ??
541+
undefined,
542+
temperature: temperature,
543+
maxTokens: max_response_length,
544+
frequencyPenalty: frequency_penalty,
545+
presencePenalty: presence_penalty,
546+
});
547+
548+
response = mistralResponse;
549+
elizaLogger.debug("Received response from Mistral model.");
550+
break;
551+
}
552+
531553
case ModelProviderName.ANTHROPIC: {
532554
elizaLogger.debug("Initializing Anthropic model with Cloudflare check");
533555
const baseURL = getCloudflareGatewayBaseURL(runtime, 'anthropic') || "https://api.anthropic.com/v1";
@@ -1863,6 +1885,8 @@ export async function handleProvider(
18631885
});
18641886
case ModelProviderName.GOOGLE:
18651887
return await handleGoogle(options);
1888+
case ModelProviderName.MISTRAL:
1889+
return await handleMistral(options);
18661890
case ModelProviderName.REDPILL:
18671891
return await handleRedPill(options);
18681892
case ModelProviderName.OPENROUTER:
@@ -2019,6 +2043,31 @@ async function handleGoogle({
20192043
});
20202044
}
20212045

2046+
/**
2047+
* Handles object generation for Mistral models.
2048+
*
2049+
* @param {ProviderOptions} options - Options specific to Mistral.
2050+
* @returns {Promise<GenerateObjectResult<unknown>>} - A promise that resolves to generated objects.
2051+
*/
2052+
async function handleMistral({
2053+
model,
2054+
schema,
2055+
schemaName,
2056+
schemaDescription,
2057+
mode,
2058+
modelOptions,
2059+
}: ProviderOptions): Promise<GenerateObjectResult<unknown>> {
2060+
const mistral = createMistral();
2061+
return await aiGenerateObject({
2062+
model: mistral(model),
2063+
schema,
2064+
schemaName,
2065+
schemaDescription,
2066+
mode,
2067+
...modelOptions,
2068+
});
2069+
}
2070+
20222071
/**
20232072
* Handles object generation for Redpill models.
20242073
*

packages/core/src/models.ts

+40
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,46 @@ export const models: Models = {
378378
},
379379
},
380380
},
381+
[ModelProviderName.MISTRAL]: {
382+
model: {
383+
[ModelClass.SMALL]: {
384+
name:
385+
settings.SMALL_MISTRAL_MODEL ||
386+
settings.MISTRAL_MODEL ||
387+
"mistral-small-latest",
388+
stop: [],
389+
maxInputTokens: 128000,
390+
maxOutputTokens: 8192,
391+
frequency_penalty: 0.4,
392+
presence_penalty: 0.4,
393+
temperature: 0.7,
394+
},
395+
[ModelClass.MEDIUM]: {
396+
name:
397+
settings.MEDIUM_MISTRAL_MODEL ||
398+
settings.MISTRAL_MODEL ||
399+
"mistral-large-latest",
400+
stop: [],
401+
maxInputTokens: 128000,
402+
maxOutputTokens: 8192,
403+
frequency_penalty: 0.4,
404+
presence_penalty: 0.4,
405+
temperature: 0.7,
406+
},
407+
[ModelClass.LARGE]: {
408+
name:
409+
settings.LARGE_MISTRAL_MODEL ||
410+
settings.MISTRAL_MODEL ||
411+
"mistral-large-latest",
412+
stop: [],
413+
maxInputTokens: 128000,
414+
maxOutputTokens: 8192,
415+
frequency_penalty: 0.4,
416+
presence_penalty: 0.4,
417+
temperature: 0.7,
418+
}
419+
},
420+
},
381421
[ModelProviderName.REDPILL]: {
382422
endpoint: "https://api.red-pill.ai/v1",
383423
// Available models: https://docs.red-pill.ai/get-started/supported-models

packages/core/src/types.ts

+2
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ export type Models = {
210210
[ModelProviderName.TOGETHER]: Model;
211211
[ModelProviderName.LLAMALOCAL]: Model;
212212
[ModelProviderName.GOOGLE]: Model;
213+
[ModelProviderName.MISTRAL]: Model;
213214
[ModelProviderName.CLAUDE_VERTEX]: Model;
214215
[ModelProviderName.REDPILL]: Model;
215216
[ModelProviderName.OPENROUTER]: Model;
@@ -242,6 +243,7 @@ export enum ModelProviderName {
242243
TOGETHER = "together",
243244
LLAMALOCAL = "llama_local",
244245
GOOGLE = "google",
246+
MISTRAL = "mistral",
245247
CLAUDE_VERTEX = "claude_vertex",
246248
REDPILL = "redpill",
247249
OPENROUTER = "openrouter",

pnpm-lock.yaml

+22-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)