From 79287330039f338f89358b7762caa703c5ec6c15 Mon Sep 17 00:00:00 2001 From: zainuldeen <78583049+Zain-ul-din@users.noreply.github.com> Date: Fri, 22 Nov 2024 23:56:07 +0500 Subject: [PATCH] (#71) --- package.json | 2 ++ src/index.ts | 2 +- src/models/CustomModel.ts | 11 +++++++++-- src/models/GeminiModel.ts | 20 +++++++++++++++----- src/models/OpenAIModel.ts | 7 ++++++- src/types/Config.d.ts | 1 + 6 files changed, 34 insertions(+), 9 deletions(-) diff --git a/package.json b/package.json index e80104e..30d5388 100644 --- a/package.json +++ b/package.json @@ -12,6 +12,7 @@ "build": "npx yarn" }, "devDependencies": { + "@types/invariant": "^2.2.37", "@types/node": "^18.14.0", "@types/node-fetch": "^2.6.2", "prettier": "^2.8.4", @@ -24,6 +25,7 @@ "@whiskeysockets/baileys": "^6.7.7", "@whiskeysockets/libsignal-node": "github:WhiskeySockets/libsignal-node", "dotenv": "^16.0.3", + "invariant": "^2.2.4", "mongo-baileys": "^1.0.1", "mongodb": "^6.8.0", "openai": "^4.56.0", diff --git a/src/index.ts b/src/index.ts index f57cbc6..540ebec 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,5 +7,5 @@ // whatsappClient.messageEvent.on('self', welcomeUser); -import { connectToWhatsApp } from "./baileys"; +import { connectToWhatsApp } from './baileys'; connectToWhatsApp(); diff --git a/src/models/CustomModel.ts b/src/models/CustomModel.ts index e6c404a..f2f243a 100644 --- a/src/models/CustomModel.ts +++ b/src/models/CustomModel.ts @@ -25,13 +25,14 @@ class CustomAIModel extends AIModel { this.chatGPTModel = new ChatGPTModel(); } - private static constructInstructAblePrompt({ + private constructInstructAblePrompt({ prompt, instructions }: { prompt: string; instructions: string; }) { + if (!this.self.dangerouslyAllowFewShotApproach) return prompt; return ` ${instructions} @@ -46,16 +47,22 @@ prompt: public async sendMessage({ prompt, ...rest }: AIArguments, handle: AIHandle) { try { const instructions = await CustomAIModel.readContext(this.self); - const promptWithInstructions = CustomAIModel.constructInstructAblePrompt({ + const promptWithInstructions = this.constructInstructAblePrompt({ prompt, instructions: instructions }); switch (this.selectedBaseModel) { case 'ChatGPT': + this.chatGPTModel.instructions = !this.self.dangerouslyAllowFewShotApproach + ? instructions + : undefined; await this.chatGPTModel.sendMessage({ prompt: promptWithInstructions, ...rest }, handle); break; case 'Gemini': + this.geminiModel.instructions = !this.self.dangerouslyAllowFewShotApproach + ? instructions + : undefined; await this.geminiModel.sendMessage({ prompt: promptWithInstructions, ...rest }, handle); break; } diff --git a/src/models/GeminiModel.ts b/src/models/GeminiModel.ts index 734046f..ac7f8fb 100644 --- a/src/models/GeminiModel.ts +++ b/src/models/GeminiModel.ts @@ -10,20 +10,19 @@ import { downloadMediaMessage } from '@whiskeysockets/baileys'; /* Local modules */ import { AIModel, AIArguments, AIHandle, AIMetaData } from './BaseAiModel'; import { ENV } from '../baileys/env'; +import invariant from 'invariant'; /* Gemini Model */ class GeminiModel extends AIModel { /* Variables */ - private generativeModel: GenerativeModel; + private generativeModel?: GenerativeModel; private Gemini: GoogleGenerativeAI; public chats: { [from: string]: ChatSession }; + public instructions: string | undefined; public constructor() { super(ENV.API_KEY_GEMINI, 'Gemini', ENV.GEMINI_ICON_PREFIX); this.Gemini = new GoogleGenerativeAI(ENV.API_KEY_GEMINI as string); - - // https://ai.google.dev/gemini-api/docs/models/gemini - this.generativeModel = this.Gemini.getGenerativeModel({ model: 'gemini-1.5-flash' }); this.chats = {}; } @@ -31,7 +30,7 @@ class GeminiModel extends AIModel { public async generateCompletion(user: string, prompt: string): Promise { if (!this.sessionExists(user)) { this.sessionCreate(user); - this.chats[user] = this.generativeModel.startChat(); + this.chats[user] = this.generativeModel!.startChat(); } const chat = this.chats[user]; @@ -47,7 +46,18 @@ class GeminiModel extends AIModel { }; } + private initGenerativeModel() { + // https://ai.google.dev/gemini-api/docs/models/gemini + this.generativeModel = this.Gemini.getGenerativeModel({ + model: 'gemini-1.5-flash', + systemInstruction: this.instructions + }); + } + public async generateImageCompletion(prompt: string, metadata: AIMetaData): Promise { + this.initGenerativeModel(); + invariant(this.generativeModel, 'Unable to initialize Gemini Generative model'); + const { mimeType } = metadata.quoteMetaData.imgMetaData; if (mimeType === 'image/jpeg') { const buffer = await downloadMediaMessage( diff --git a/src/models/OpenAIModel.ts b/src/models/OpenAIModel.ts index 5fcdd94..55eb1e3 100644 --- a/src/models/OpenAIModel.ts +++ b/src/models/OpenAIModel.ts @@ -24,6 +24,8 @@ class ChatGPTModel extends AIModel { private OpenAI: OpenAI; public DalleSize: DalleSizeImage; + public instructions: string | undefined = undefined; + public constructor() { super(ENV.API_KEY_OPENAI, 'ChatGPT'); @@ -42,7 +44,10 @@ class ChatGPTModel extends AIModel { /* Methods */ public async generateCompletion(user: string): Promise { const completion = await this.OpenAI.chat.completions.create({ - messages: this.history[user], + messages: [ + ...(this.instructions ? [{ role: 'system', content: this.instructions }] : []), + ...this.history[user] + ], model: ENV.OPENAI_MODEL }); diff --git a/src/types/Config.d.ts b/src/types/Config.d.ts index 4c69812..3a0d648 100644 --- a/src/types/Config.d.ts +++ b/src/types/Config.d.ts @@ -18,6 +18,7 @@ export interface IModelType extends IModelConfig { context: string; includeSender?: boolean; baseModel: SupportedBaseModels; + dangerouslyAllowFewShotApproach?: boolean; } export interface IDefaultConfig {