Feat/lazy load llama #219

Closed · wants to merge 5 commits
2 changes: 1 addition & 1 deletion .gitignore
@@ -19,7 +19,7 @@ tweets/
*.onnx
*.wav
*.mp3

*.code-workspace
logs/

test-report.json
55 changes: 19 additions & 36 deletions core/src/services/llama.ts
@@ -15,6 +15,7 @@ import fs from "fs";
import https from "https";
import si from "systeminformation";
import { wordsToPunish } from "./wordsToPunish.ts";
import { prettyConsole } from "../index.ts";

const __dirname = path.dirname(fileURLToPath(import.meta.url));

@@ -67,28 +68,25 @@ class LlamaService {
private modelInitialized: boolean = false;

private constructor() {
console.log("Constructing");
this.llama = undefined;
this.model = undefined;
this.modelUrl =
"https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B-GGUF/resolve/main/Hermes-3-Llama-3.1-8B.Q8_0.gguf?download=true";
const modelName = "model.gguf";
console.log("modelName", modelName);
this.modelPath = path.join(__dirname, modelName);
try {
this.initializeModel();
} catch (error) {
console.error("Error initializing model", error);

}
private async ensureInitialized() {
if (!this.modelInitialized) {
await this.initializeModel();
}
}

public static getInstance(): LlamaService {
if (!LlamaService.instance) {
LlamaService.instance = new LlamaService();
}
return LlamaService.instance;
}

async initializeModel() {
try {
await this.checkModel();
@@ -99,30 +97,26 @@ class LlamaService {
);

if (hasCUDA) {
console.log("**** CUDA detected");
console.log("**** LlamaService: CUDA detected");
} else {
console.log(
"**** No CUDA detected - local response will be slow"
console.warn(
"**** LlamaService: No CUDA detected - local response will be slow"
);
}

this.llama = await getLlama({
gpu: "cuda",
});
console.log("Creating grammar");
const grammar = new LlamaJsonSchemaGrammar(
this.llama,
jsonSchemaGrammar as GbnfJsonSchema
);
this.grammar = grammar;
console.log("Loading model");
console.log("this.modelPath", this.modelPath);

this.model = await this.llama.loadModel({
modelPath: this.modelPath,
});
console.log("Model GPU support", this.llama.getGpuDeviceNames());
console.log("Creating context");

this.ctx = await this.model.createContext({ contextSize: 8192 });
this.sequence = this.ctx.getSequence();

@@ -139,11 +133,7 @@ class LlamaService {
}

async checkModel() {
console.log("Checking model");
if (!fs.existsSync(this.modelPath)) {
console.log("this.modelPath", this.modelPath);
console.log("Model not found. Downloading...");

await new Promise<void>((resolve, reject) => {
const file = fs.createWriteStream(this.modelPath);
let downloadedSize = 0;
@@ -157,14 +147,9 @@ class LlamaService {
if (isRedirect) {
const redirectUrl = response.headers.location;
if (redirectUrl) {
console.log(
"Following redirect to:",
redirectUrl
);
downloadModel(redirectUrl);
return;
} else {
console.error("Redirect URL not found");
reject(new Error("Redirect URL not found"));
return;
}
@@ -191,7 +176,6 @@ class LlamaService {

response.on("end", () => {
file.end();
console.log("\nModel downloaded successfully.");
resolve();
});
})
@@ -211,14 +195,13 @@ class LlamaService {
});
});
} else {
console.log("Model already exists.");
prettyConsole.warn("Model already exists.");
}
}

async deleteModel() {
if (fs.existsSync(this.modelPath)) {
fs.unlinkSync(this.modelPath);
console.log("Model deleted.");
}
}

@@ -230,7 +213,7 @@ class LlamaService {
presence_penalty: number,
max_tokens: number
): Promise<any> {
console.log("Queueing message generateText");
await this.ensureInitialized();
return new Promise((resolve, reject) => {
this.messageQueue.push({
context,
@@ -255,13 +238,15 @@ class LlamaService {
presence_penalty: number,
max_tokens: number
): Promise<string> {
await this.ensureInitialized();

return new Promise((resolve, reject) => {
this.messageQueue.push({
context,
temperature,
stop,
frequency_penalty,
presence_penalty,
frequency_penalty: frequency_penalty ?? 1.0,
presence_penalty: presence_penalty ?? 1.0,
max_tokens,
useGrammar: false,
resolve,
@@ -286,7 +271,6 @@ class LlamaService {
const message = this.messageQueue.shift();
if (message) {
try {
console.log("Processing message");
const response = await this.getCompletionResponse(
message.context,
message.temperature,
@@ -334,7 +318,7 @@ class LlamaService {
};

const responseTokens: Token[] = [];
console.log("Evaluating tokens");

for await (const token of this.sequence.evaluate(tokens, {
temperature: Number(temperature),
repeatPenalty: repeatPenalty,
@@ -374,7 +358,6 @@ class LlamaService {
// try parsing response as JSON
try {
jsonString = JSON.stringify(JSON.parse(response));
console.log("parsedResponse", jsonString);
} catch {
throw new Error("JSON string not found");
}
@@ -384,20 +367,19 @@ class LlamaService {
if (!parsedResponse) {
throw new Error("Parsed response is undefined");
}
console.log("AI: " + parsedResponse.content);
await this.sequence.clearHistory();
return parsedResponse;
} catch (error) {
console.error("Error parsing JSON:", error);
}
} else {
console.log("AI: " + response);
await this.sequence.clearHistory();
return response;
}
}

async getEmbeddingResponse(input: string): Promise<number[] | undefined> {
await this.ensureInitialized();
if (!this.model) {
throw new Error("Model not initialized. Call initialize() first.");
}
@@ -409,3 +391,4 @@ class LlamaService {
}

export default LlamaService;
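
For readers skimming the diff: the core change is that LlamaService no longer loads the model in its constructor. Instead, the text-generation and embedding entry points each await ensureInitialized(), which runs initializeModel() only if it has not completed yet. Below is a minimal sketch of that lazy-initialization pattern; the class and method names mirror llama.ts, but the body of initializeModel() is a placeholder, not the real download/load logic.

// Minimal, illustrative sketch of the lazy-load pattern introduced above.
class LazyLlamaService {
    private static instance: LazyLlamaService | undefined;
    private modelInitialized = false;

    private constructor() {
        // Construction stays cheap: no model download or load happens here.
    }

    public static getInstance(): LazyLlamaService {
        if (!LazyLlamaService.instance) {
            LazyLlamaService.instance = new LazyLlamaService();
        }
        return LazyLlamaService.instance;
    }

    private async ensureInitialized(): Promise<void> {
        // Assumption: initializeModel() flips modelInitialized on success,
        // as the real service does, so later calls skip straight through.
        if (!this.modelInitialized) {
            await this.initializeModel();
        }
    }

    private async initializeModel(): Promise<void> {
        // Placeholder for the checkModel() / getLlama() / loadModel() work in llama.ts.
        this.modelInitialized = true;
    }

    // Every public entry point awaits ensureInitialized() first, so the
    // model is fetched and loaded only on first use, not at construction.
    public async generateText(context: string): Promise<string> {
        await this.ensureInitialized();
        return `echo: ${context}`;
    }
}

One design note on this sketch: with a plain boolean flag, two concurrent first calls could both enter initializeModel() before the flag flips; caching the in-flight promise would guard against that, but the sketch keeps the same shape as the diff.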

65 changes: 37 additions & 28 deletions package.json
@@ -1,31 +1,40 @@
{
"name": "eliza",
"scripts": {
"preinstall": "npx only-allow pnpm",
"build": "pnpm --dir core build",
"build-docs": "pnpm --dir docs build",
"start:all": "pnpm --dir core start:all --isRoot",
"stop:all": "pnpm --dir core stop:all --isRoot",
"start:service:all": "pnpm --dir core start:service:all --isRoot",
"stop:service:all": "pnpm --dir core stop:service:all --isRoot",
"start": "pnpm --dir core start --isRoot",
"dev": "pnpm --dir core dev --isRoot",
"lint": "pnpm --dir core lint",
"prettier-check": "npx prettier --check .",
"prettier": "npx prettier --write .",
"clean": "bash ./scripts/clean.sh"
},
"devDependencies": {
"husky": "^9.1.6",
"lerna": "^8.1.5",
"only-allow": "^1.2.1",
"prettier": "^3.3.3",
"typedoc": "^0.26.11"
},
"engines": {
"node": ">=22"
},
"dependencies": {
"typescript": "5.6.3"
"name": "eliza",
"version": "1.0.0",
"scripts": {
"preinstall": "npx only-allow pnpm",
"build": "pnpm --dir core build",
"build-docs": "pnpm --dir docs build",
"start:all": "pnpm --dir core start:all",
"stop:all": "pnpm --dir core stop:all",
"start:service:all": "pnpm --dir core start:service:all",
"stop:service:all": "pnpm --dir core stop:service:all",
"start": "pnpm --dir core start",
"dev": "pnpm --dir core dev",
"lint": "pnpm --dir core lint",
"prettier-check": "npx prettier --check .",
"prettier": "npx prettier --write .",
"clean": "bash ./scripts/clean.sh"
},
"dependencies": {
"onnxruntime-node": "^1.20.0",
"optional": "^0.1.4",
"sharp": "^0.33.5",
"typescript": "5.6.3"
},
"devDependencies": {
"husky": "^9.1.6",
"lerna": "^8.1.5",
"only-allow": "^1.2.1",
"prettier": "^3.3.3",
"typedoc": "^0.26.11"
},
"engines": {
"node": ">=22"
},
"pnpm": {
"overrides": {
"onnxruntime-node": "1.20.0"
}
}
}