Commit d5d1781
Committed Mar 6, 2024

Fix LlamaContext initialization and batch handling

1 parent 34d83c2 · commit d5d1781
2 files changed: +11 −6 lines changed
 

src/llama_context.cpp (+11 −5)
@@ -30,6 +30,7 @@ void LlamaContext::_ready() {
 		return;
 	}
 
+	ctx_params.seed = -1;
 	ctx_params.n_ctx = 2048;
 	int32_t n_threads = OS::get_singleton()->get_processor_count();
 	ctx_params.n_threads = n_threads;
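The new ctx_params.seed = -1; stores an all-ones value in llama.cpp's unsigned seed field, which the llama.cpp of this period (LLAMA_DEFAULT_SEED) interprets as "pick a random seed per context". A natural follow-up, sketched below purely as an assumption and not part of this commit, would be exposing the seed as a property so completions can be reproduced; set_seed/get_seed are hypothetical names:

// Hypothetical accessors (not in this commit): let scripts pin the RNG seed
// for reproducible sampling, keeping -1 as "randomize".
void LlamaContext::set_seed(const int p_seed) {
	ctx_params.seed = p_seed; // takes effect when the context is created
}

int LlamaContext::get_seed() const {
	return (int)ctx_params.seed;
}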
@@ -45,9 +46,9 @@ void LlamaContext::_ready() {
 
 Variant LlamaContext::request_completion(const String &prompt) {
 	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
+	if (task_id) {
+		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
+	}
 	task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
 	return OK;
 }
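The replaced lines above are textually identical in this extraction, so the change appears to be whitespace-only; the logic is unchanged: before queuing a new completion, any in-flight task is joined so only one _fulfill_completion runs at a time. A minimal standalone sketch of that wait-then-enqueue pattern, assuming Godot 4's WorkerThreadPool API and a zero sentinel for "no task yet":

#include <godot_cpp/classes/worker_thread_pool.hpp>

using namespace godot;

// Sketch (not from the repo): serialize background jobs through the pool.
// Godot's WorkerThreadPool hands out positive task ids, so 0 can serve as a
// "nothing in flight" sentinel.
static int64_t pending_task = 0;

void run_serialized(const Callable &p_job) {
	WorkerThreadPool *pool = WorkerThreadPool::get_singleton();
	if (pending_task != 0) {
		pool->wait_for_task_completion(pending_task); // join the previous job
	}
	pending_task = pool->add_task(p_job);
}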
@@ -65,9 +66,12 @@ void LlamaContext::_fulfill_completion(const String &prompt) {
 		return;
 	}
 
+	llama_batch batch = llama_batch_init(tokens_list.size(), 0, 1);
+
 	for (size_t i = 0; i < tokens_list.size(); i++) {
 		llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
 	}
+
 	batch.logits[batch.n_tokens - 1] = true;
 
 	int decode_res = llama_decode(ctx, batch);
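This hunk creates the batch inside the request and sizes it to the tokenized prompt rather than relying on the fixed 512-token member it replaces. For context: in llama_batch_init(n_tokens, embd, n_seq_max), n_tokens is the capacity, embd == 0 means the batch carries token ids rather than embeddings, and n_seq_max caps the sequence ids per token; every init must be paired with llama_batch_free. A compact sketch of that lifecycle, assuming the llama.cpp of this commit's era and its common-helper llama_batch_add:

#include "common.h" // llama_batch_add lives in llama.cpp's common helpers
#include "llama.h"

#include <vector>

// Sketch of the init/fill/decode/free lifecycle the hunk follows.
void decode_prompt(llama_context *ctx, const std::vector<llama_token> &tokens) {
	llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
	for (size_t i = 0; i < tokens.size(); i++) {
		// arguments: token id, position, sequence ids, keep logits?
		llama_batch_add(batch, tokens[i], i, { 0 }, false);
	}
	batch.logits[batch.n_tokens - 1] = true; // only the last token needs logits
	llama_decode(ctx, batch);
	llama_batch_free(batch); // pairs with llama_batch_init
}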
@@ -79,6 +83,7 @@ void LlamaContext::_fulfill_completion(const String &prompt) {
 	int n_cur = batch.n_tokens;
 	int n_decode = 0;
 	llama_model *llama_model = model->model;
+
 	while (n_cur <= n_len) {
 		// sample the next token
 		{
@@ -121,9 +126,11 @@ void LlamaContext::_fulfill_completion(const String &prompt) {
 		int decode_res = llama_decode(ctx, batch);
 		if (decode_res != 0) {
 			UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
-			return;
+			break;
 		}
 	}
+
+	llama_batch_free(batch);
 }
 
 void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
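Switching return to break is what makes the new llama_batch_free(batch) at the end of _fulfill_completion run even when a decode fails, closing the leak on the error path. An alternative worth noting, shown only as a sketch and not something this commit does, is an RAII guard that frees the batch on every exit path, which would let the error branch keep its early return:

#include "llama.h"

// Sketch (not in this commit): tie the batch's lifetime to a scope so early
// returns cannot leak it.
struct BatchGuard {
	llama_batch batch;

	BatchGuard(int32_t n_tokens, int32_t embd, int32_t n_seq_max) :
			batch(llama_batch_init(n_tokens, embd, n_seq_max)) {}
	~BatchGuard() { llama_batch_free(batch); }

	BatchGuard(const BatchGuard &) = delete;
	BatchGuard &operator=(const BatchGuard &) = delete;
};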
@@ -138,7 +145,6 @@ LlamaContext::~LlamaContext() {
 	if (ctx) {
 		llama_free(ctx);
 	}
-	llama_batch_free(batch);
 	if (task_id) {
 		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
 	}

src/llama_context.h (−1)
@@ -13,7 +13,6 @@ class LlamaContext : public Node {
 	Ref<LlamaModel> model;
 	llama_context *ctx = nullptr;
 	llama_context_params ctx_params = llama_context_default_params();
-	llama_batch batch = llama_batch_init(512, 0, 1);
 	int task_id;
 
 protected:
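With the batch now created per request in _fulfill_completion, the header drops the member that eagerly allocated 512 token slots whenever a LlamaContext was constructed, and the matching llama_batch_free leaves the destructor (see the .cpp hunks above). The resulting member layout, reconstructed from this hunk:

Ref<LlamaModel> model;
llama_context *ctx = nullptr;
llama_context_params ctx_params = llama_context_default_params();
int task_id; // still uninitialized here; initializing to 0 would make the
             // if (task_id) sentinel check in request_completion reliable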
