Commit c06e598

Merge pull request #220 from ai16z/feat/lazy-load-llama
lazy load llama
2 parents c537cb3 + 0578cf8

File tree

1 file changed (+19, -36 lines)


core/src/services/llama.ts (+19, -36)
@@ -15,6 +15,7 @@ import fs from "fs";
 import https from "https";
 import si from "systeminformation";
 import { wordsToPunish } from "./wordsToPunish.ts";
+import { prettyConsole } from "../index.ts";
 
 const __dirname = path.dirname(fileURLToPath(import.meta.url));
 
@@ -67,28 +68,25 @@ class LlamaService {
     private modelInitialized: boolean = false;
 
     private constructor() {
-        console.log("Constructing");
         this.llama = undefined;
         this.model = undefined;
         this.modelUrl =
             "https://huggingface.co/NousResearch/Hermes-3-Llama-3.1-8B-GGUF/resolve/main/Hermes-3-Llama-3.1-8B.Q8_0.gguf?download=true";
         const modelName = "model.gguf";
-        console.log("modelName", modelName);
         this.modelPath = path.join(__dirname, modelName);
-        try {
-            this.initializeModel();
-        } catch (error) {
-            console.error("Error initializing model", error);
+
+    }
+    private async ensureInitialized() {
+        if (!this.modelInitialized) {
+            await this.initializeModel();
         }
     }
-
     public static getInstance(): LlamaService {
         if (!LlamaService.instance) {
             LlamaService.instance = new LlamaService();
         }
         return LlamaService.instance;
     }
-
     async initializeModel() {
         try {
             await this.checkModel();
@@ -99,30 +97,26 @@
             );
 
             if (hasCUDA) {
-                console.log("**** CUDA detected");
+                console.log("**** LlamaService: CUDA detected");
             } else {
-                console.log(
-                    "**** No CUDA detected - local response will be slow"
+                console.warn(
+                    "**** LlamaService: No CUDA detected - local response will be slow"
                 );
             }
 
             this.llama = await getLlama({
                 gpu: "cuda",
             });
-            console.log("Creating grammar");
             const grammar = new LlamaJsonSchemaGrammar(
                 this.llama,
                 jsonSchemaGrammar as GbnfJsonSchema
            );
             this.grammar = grammar;
-            console.log("Loading model");
-            console.log("this.modelPath", this.modelPath);
 
             this.model = await this.llama.loadModel({
                 modelPath: this.modelPath,
             });
-            console.log("Model GPU support", this.llama.getGpuDeviceNames());
-            console.log("Creating context");
+
             this.ctx = await this.model.createContext({ contextSize: 8192 });
             this.sequence = this.ctx.getSequence();
 
@@ -139,11 +133,7 @@
     }
 
     async checkModel() {
-        console.log("Checking model");
         if (!fs.existsSync(this.modelPath)) {
-            console.log("this.modelPath", this.modelPath);
-            console.log("Model not found. Downloading...");
-
             await new Promise<void>((resolve, reject) => {
                 const file = fs.createWriteStream(this.modelPath);
                 let downloadedSize = 0;
@@ -157,14 +147,9 @@
                     if (isRedirect) {
                         const redirectUrl = response.headers.location;
                         if (redirectUrl) {
-                            console.log(
-                                "Following redirect to:",
-                                redirectUrl
-                            );
                             downloadModel(redirectUrl);
                             return;
                         } else {
-                            console.error("Redirect URL not found");
                             reject(new Error("Redirect URL not found"));
                             return;
                         }
@@ -191,7 +176,6 @@
 
                     response.on("end", () => {
                         file.end();
-                        console.log("\nModel downloaded successfully.");
                         resolve();
                     });
                 })
@@ -211,14 +195,13 @@
                 });
             });
         } else {
-            console.log("Model already exists.");
+            prettyConsole.warn("Model already exists.");
         }
     }
 
     async deleteModel() {
         if (fs.existsSync(this.modelPath)) {
             fs.unlinkSync(this.modelPath);
-            console.log("Model deleted.");
         }
     }
 
@@ -230,7 +213,7 @@
         presence_penalty: number,
         max_tokens: number
     ): Promise<any> {
-        console.log("Queueing message generateText");
+        await this.ensureInitialized();
         return new Promise((resolve, reject) => {
             this.messageQueue.push({
                 context,
@@ -255,13 +238,15 @@
         presence_penalty: number,
         max_tokens: number
     ): Promise<string> {
+        await this.ensureInitialized();
+
         return new Promise((resolve, reject) => {
             this.messageQueue.push({
                 context,
                 temperature,
                 stop,
-                frequency_penalty,
-                presence_penalty,
+                frequency_penalty: frequency_penalty ?? 1.0,
+                presence_penalty: presence_penalty ?? 1.0,
                 max_tokens,
                 useGrammar: false,
                 resolve,
@@ -286,7 +271,6 @@
             const message = this.messageQueue.shift();
             if (message) {
                 try {
-                    console.log("Processing message");
                     const response = await this.getCompletionResponse(
                         message.context,
                         message.temperature,
@@ -334,7 +318,7 @@
         };
 
         const responseTokens: Token[] = [];
-        console.log("Evaluating tokens");
+
        for await (const token of this.sequence.evaluate(tokens, {
             temperature: Number(temperature),
             repeatPenalty: repeatPenalty,
@@ -374,7 +358,6 @@
         // try parsing response as JSON
         try {
             jsonString = JSON.stringify(JSON.parse(response));
-            console.log("parsedResponse", jsonString);
         } catch {
             throw new Error("JSON string not found");
         }
@@ -384,20 +367,19 @@
                 if (!parsedResponse) {
                     throw new Error("Parsed response is undefined");
                 }
-                console.log("AI: " + parsedResponse.content);
                 await this.sequence.clearHistory();
                 return parsedResponse;
             } catch (error) {
                 console.error("Error parsing JSON:", error);
             }
         } else {
-            console.log("AI: " + response);
             await this.sequence.clearHistory();
             return response;
         }
     }
 
     async getEmbeddingResponse(input: string): Promise<number[] | undefined> {
+        await this.ensureInitialized();
         if (!this.model) {
             throw new Error("Model not initialized. Call initialize() first.");
         }
@@ -409,3 +391,4 @@
 }
 
 export default LlamaService;
+
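The change replaces eager model setup in the constructor with a lazy-initialization guard: generateText, queueTextCompletion, and getEmbeddingResponse each await ensureInitialized(), which runs initializeModel() only on first use. Below is a minimal, self-contained TypeScript sketch of that pattern combined with the singleton getInstance() accessor; the names LazyService and loadResource are illustrative stand-ins, not part of this diff. The sketch also mirrors the diff's `frequency_penalty ?? 1.0` defaulting: `??` falls back only on null/undefined, so an explicit 0 passed by a caller is preserved (unlike `||`).

// Minimal sketch of the lazy-init + singleton pattern applied in this commit.
// LazyService and loadResource are illustrative, not from the diff.
class LazyService {
    private static instance: LazyService | undefined;
    private initialized = false;
    private resource: string | undefined;

    // The constructor stays cheap: no downloads, no model loading.
    private constructor() {}

    public static getInstance(): LazyService {
        if (!LazyService.instance) {
            LazyService.instance = new LazyService();
        }
        return LazyService.instance;
    }

    // Every public entry point awaits this guard, so the expensive setup
    // runs at most once, and only when the service is actually used.
    private async ensureInitialized(): Promise<void> {
        if (!this.initialized) {
            this.resource = await this.loadResource();
            this.initialized = true;
        }
    }

    // Stand-in for checkModel() + loadModel(): a slow asynchronous step.
    private async loadResource(): Promise<string> {
        return new Promise<string>((resolve) =>
            setTimeout(() => resolve("model ready"), 10)
        );
    }

    public async generate(
        prompt: string,
        frequency_penalty?: number
    ): Promise<string> {
        await this.ensureInitialized();
        // As in the diff: `??` defaults only on null/undefined, keeping 0.
        const penalty = frequency_penalty ?? 1.0;
        return `${this.resource}: ${prompt} (penalty=${penalty})`;
    }
}

// First call pays the initialization cost; later calls skip it.
LazyService.getInstance()
    .generate("hello")
    .then((out) => console.log(out));

One caveat worth noting: a plain boolean flag does not deduplicate concurrent callers, so two overlapping calls made before initialization completes could both enter initializeModel(). A common variant caches the initialization promise itself and awaits it everywhere.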
