
Commit 10dce46: "start refactor"
1 parent: a243391

5 files changed: +136, -151 lines

godot/autoloads/llama.tscn (+2, -2)

@@ -1,6 +1,6 @@
 [gd_scene load_steps=2 format=3 uid="uid://bxobxniygk7jm"]
 
-[ext_resource type="LlamaModel" path="res://models/OGNO-7B-Q4_K_M.gguf" id="1_vd8h8"]
+[ext_resource type="LlamaModel" path="res://models/stablelm-2-zephyr-1_6b-Q8_0.gguf" id="1_h1o7k"]
 
 [node name="LlamaContext" type="LlamaContext"]
-model = ExtResource("1_vd8h8")
+model = ExtResource("1_h1o7k")

godot/main.gd (+16, -16)

@@ -8,20 +8,20 @@ func _on_button_pressed():
 	handle_submit()
 
 func handle_submit():
-	print(input.text)
-	Llama.request_completion(input.text)
+	print(Llama.prompt(input.text))
+	#print(input.text)
 
-	input.clear()
-	input.editable = false
-	submit_button.disabled = true
-	output.text = "..."
-
-	var completion = await Llama.completion_generated
-	output.text = ""
-	while !completion[1]:
-		print(completion[0])
-		output.text += completion[0]
-		completion = await Llama.completion_generated
-
-	input.editable = true
-	submit_button.disabled = false
+	#input.clear()
+	#input.editable = false
+	#submit_button.disabled = true
+	#output.text = "..."
+	#
+	var completion = await Llama.text_generated
+	#output.text = ""
+	while !completion[2]:
+		print(completion)
+		#output.text += completion[0]
+		completion = await Llama.text_generated
+	#
+	#input.editable = true
+	#submit_button.disabled = false
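
Note on this change: the streaming UI path is commented out while the context API is reworked. handle_submit now just prints the request id returned by Llama.prompt and watches the new text_generated signal, whose third element is the is_final flag (hence the completion[2] check). A minimal sketch of how the handler might look once streaming is wired back up, assuming the signal keeps the [id, text, is_final] shape and the existing input, output, and submit_button nodes:

func handle_submit():
	var id = Llama.prompt(input.text)

	input.editable = false
	submit_button.disabled = true
	output.text = ""

	while true:
		var completion = await Llama.text_generated
		if completion[0] != id:
			continue # token belongs to a different queued prompt
		output.text += completion[1]
		if completion[2]:
			break # is_final

	input.editable = true
	submit_button.disabled = false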

llama.cpp

src/llama_context.cpp (+80, -118)

@@ -6,6 +6,7 @@
 #include <godot_cpp/classes/os.hpp>
 #include <godot_cpp/classes/worker_thread_pool.hpp>
 #include <godot_cpp/core/class_db.hpp>
+#include <godot_cpp/variant/string.hpp>
 #include <godot_cpp/variant/utility_functions.hpp>
 
 using namespace godot;
@@ -15,26 +16,18 @@ void LlamaContext::_bind_methods() {
 	ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
 	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");
 
-	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
-	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
+	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
+	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
 
-	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
-	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
+	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
 
-	ClassDB::bind_method(D_METHOD("get_n_threads"), &LlamaContext::get_n_threads);
-	ClassDB::bind_method(D_METHOD("set_n_threads", "n_threads"), &LlamaContext::set_n_threads);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads"), "set_n_threads", "get_n_threads");
+	ClassDB::bind_method(D_METHOD("prompt", "prompt", "max_new_tokens", "temperature", "top_p", "top_k", "presence_penalty", "frequency_penalty"), &LlamaContext::prompt, DEFVAL(32), DEFVAL(0.80f), DEFVAL(0.95f), DEFVAL(40), DEFVAL(0.0), DEFVAL(0.0));
+	ClassDB::bind_method(D_METHOD("_thread_prompt_loop"), &LlamaContext::_thread_prompt_loop);
 
-	ClassDB::bind_method(D_METHOD("get_n_threads_batch"), &LlamaContext::get_n_threads_batch);
-	ClassDB::bind_method(D_METHOD("set_n_threads_batch", "n_threads_batch"), &LlamaContext::set_n_threads_batch);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads_batch"), "set_n_threads_batch", "get_n_threads_batch");
-
-	ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
-	ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
-
-	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
+	ADD_SIGNAL(MethodInfo("text_generated", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::STRING, "text"), PropertyInfo(Variant::BOOL, "is_final")));
 }
 
 LlamaContext::LlamaContext() {
@@ -47,6 +40,11 @@ LlamaContext::LlamaContext() {
 	int32_t n_threads = OS::get_singleton()->get_processor_count();
 	ctx_params.n_threads = n_threads;
 	ctx_params.n_threads_batch = n_threads;
+
+	sampling_params = llama_sampling_params();
+
+	n_prompts = 0;
+	should_exit = false;
 }
 
 void LlamaContext::_ready() {
@@ -66,101 +64,66 @@ void LlamaContext::_ready() {
 		return;
 	}
 	UtilityFunctions::print(vformat("%s: Context initialized", __func__));
-}
 
-PackedStringArray LlamaContext::_get_configuration_warnings() const {
-	PackedStringArray warnings;
-	if (model == NULL) {
-		warnings.push_back("Model resource property not defined");
-	}
-	return warnings;
+	sampling_ctx = llama_sampling_init(sampling_params);
+
+	prompt_mutex.instantiate();
+	prompt_semaphore.instantiate();
+	prompt_thread.instantiate();
+
+	prompt_thread->start(callable_mp(this, &LlamaContext::_thread_prompt_loop));
 }
 
-Variant LlamaContext::request_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
-	task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
-	return OK;
+int LlamaContext::prompt(const String &prompt, const int max_new_tokens, const float temperature, const float top_p, const int top_k, const float presence_penalty, const float frequency_penalty) {
+	UtilityFunctions::print(vformat("%s: Prompting with prompt: %s, max_new_tokens: %d, temperature: %f, top_p: %f, top_k: %d, presence_penalty: %f, frequency_penalty: %f", __func__, prompt, max_new_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty));
+	prompt_mutex->lock();
+	int id = n_prompts++;
+	prompt_requests.push_back({ id, prompt, max_new_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty });
+	prompt_mutex->unlock();
+	prompt_semaphore->post();
+	return id;
 }
 
-void LlamaContext::_fulfill_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
-	std::vector<llama_token> tokens_list;
-	tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
+void LlamaContext::_thread_prompt_loop() {
+	while (true) {
+		prompt_semaphore->wait();
 
-	const int n_len = 128;
-	const int n_ctx = llama_n_ctx(ctx);
-	const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-	if (n_kv_req > n_ctx) {
-		UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
-		return;
-	}
+		prompt_mutex->lock();
+		if (should_exit) {
+			prompt_mutex->unlock();
+			return;
+		}
+		if (prompt_requests.size() == 0) {
+			prompt_mutex->unlock();
+			continue;
+		}
+		prompt_request req = prompt_requests.get(0);
+		prompt_requests.remove_at(0);
+		prompt_mutex->unlock();
 
-	for (size_t i = 0; i < tokens_list.size(); i++) {
-		llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
-	}
+		UtilityFunctions::print(vformat("%s: Running prompt %d: %s, max_new_tokens: %d, temperature: %f, top_p: %f, top_k: %d, presence_penalty: %f, frequency_penalty: %f", __func__, req.id, req.prompt, req.max_new_tokens, req.temperature, req.top_p, req.top_k, req.presence_penalty, req.frequency_penalty));
 
-	batch.logits[batch.n_tokens - 1] = true;
+		llama_sampling_reset(sampling_ctx);
+		llama_batch_clear(batch);
+		llama_kv_cache_clear(ctx);
 
-	llama_kv_cache_clear(ctx);
+		auto &params = sampling_ctx->params;
+		params.temp = req.temperature;
+		params.top_p = req.top_p;
+		params.top_k = req.top_k;
+		params.penalty_present = req.presence_penalty;
+		params.penalty_freq = req.frequency_penalty;
 
-	int decode_res = llama_decode(ctx, batch);
-	if (decode_res != 0) {
-		UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
-		return;
+		std::vector<llama_token> tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), false, true);
 	}
+}
 
-	int n_cur = batch.n_tokens;
-	int n_decode = 0;
-	llama_model *llama_model = model->model;
-
-	while (n_cur <= n_len) {
-		// sample the next token
-		{
-			auto n_vocab = llama_n_vocab(llama_model);
-			auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
-			std::vector<llama_token_data> candidates;
-			candidates.reserve(n_vocab);
-
-			for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-				candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-			}
-
-			llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-			// sample the most likely token
-			const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
-
-			// is it an end of stream?
-			if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
-				call_thread_safe("emit_signal", "completion_generated", "\n", true);
-
-				break;
-			}
-
-			call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
-
-			// prepare the next batch
-			llama_batch_clear(batch);
-
-			// push this new token for next evaluation
-			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
-
-			n_decode += 1;
-		}
-
-		n_cur += 1;
-
-		// evaluate the current batch with the transformer model
-		int decode_res = llama_decode(ctx, batch);
-		if (decode_res != 0) {
-			UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
-			break;
-		}
+PackedStringArray LlamaContext::_get_configuration_warnings() const {
+	PackedStringArray warnings;
+	if (model == NULL) {
+		warnings.push_back("Model resource property not defined");
 	}
+	return warnings;
 }
 
 void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
@@ -173,39 +136,38 @@ Ref<LlamaModel> LlamaContext::get_model() {
 int LlamaContext::get_seed() {
 	return ctx_params.seed;
 }
-void LlamaContext::set_seed(int seed) {
+void LlamaContext::set_seed(const int seed) {
 	ctx_params.seed = seed;
 }
 
 int LlamaContext::get_n_ctx() {
	return ctx_params.n_ctx;
 }
-void LlamaContext::set_n_ctx(int n_ctx) {
+void LlamaContext::set_n_ctx(const int n_ctx) {
 	ctx_params.n_ctx = n_ctx;
 }
 
-int LlamaContext::get_n_threads() {
-	return ctx_params.n_threads;
-}
-void LlamaContext::set_n_threads(int n_threads) {
-	ctx_params.n_threads = n_threads;
-}
+void LlamaContext::_exit_tree() {
+	prompt_mutex->lock();
+	should_exit = true;
+	prompt_requests.clear();
+	prompt_mutex->unlock();
 
-int LlamaContext::get_n_threads_batch() {
-	return ctx_params.n_threads_batch;
-}
-void LlamaContext::set_n_threads_batch(int n_threads_batch) {
-	ctx_params.n_threads_batch = n_threads_batch;
-}
+	prompt_semaphore->post();
+
+	if (prompt_thread.is_valid()) {
+		prompt_thread->wait_to_finish();
+	}
+	prompt_thread.unref();
 
-LlamaContext::~LlamaContext() {
 	if (ctx) {
 		llama_free(ctx);
 	}
+	if (sampling_ctx) {
+		llama_sampling_free(sampling_ctx);
+	}
+}
 
+LlamaContext::~LlamaContext() {
 	llama_batch_free(batch);
-
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
 }
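
Note on this change: prompt now queues a prompt_request and returns its id immediately, while _thread_prompt_loop drains the queue on a dedicated Thread guarded by a Mutex and Semaphore. In this commit the loop only resets the sampling context, clears the batch and KV cache, and tokenizes the request; nothing emits text_generated yet, which matches the "start refactor" message. From GDScript the binding can already be called with its bound defaults (max_new_tokens 32, temperature 0.80, top_p 0.95, top_k 40, penalties 0.0) or with explicit overrides. A small sketch, assuming the project's Llama autoload; _demo_prompt_calls is a hypothetical helper name:

func _demo_prompt_calls() -> void:
	# Only the prompt string is required; the other arguments fall back to the DEFVALs bound above.
	var id_a = Llama.prompt("Write a haiku about shaders.")

	# Full argument order follows the D_METHOD declaration:
	# prompt, max_new_tokens, temperature, top_p, top_k, presence_penalty, frequency_penalty
	var id_b = Llama.prompt("Write a haiku about shaders.", 64, 0.7, 0.9, 50, 0.1, 0.1)

	# Each call returns its own request id, so results can be matched to requests later.
	print(id_a, " ", id_b)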

src/llama_context.h (+37, -14)

@@ -1,20 +1,46 @@
 #ifndef LLAMA_CONTEXT_H
 #define LLAMA_CONTEXT_H
 
+#include "common.h"
 #include "llama.h"
 #include "llama_model.h"
+#include <godot_cpp/classes/mutex.hpp>
 #include <godot_cpp/classes/node.hpp>
+#include <godot_cpp/classes/semaphore.hpp>
+#include <godot_cpp/classes/thread.hpp>
+#include <godot_cpp/templates/vector.hpp>
 
 namespace godot {
+
+struct prompt_request {
+	int id;
+	String prompt;
+	int max_new_tokens;
+	float temperature;
+	float top_p;
+	int top_k;
+	float presence_penalty;
+	float frequency_penalty;
+};
+
 class LlamaContext : public Node {
 	GDCLASS(LlamaContext, Node)
 
 private:
 	Ref<LlamaModel> model;
-	llama_context *ctx = nullptr;
 	llama_context_params ctx_params;
+	llama_context *ctx = nullptr;
+	llama_sampling_params sampling_params;
+	llama_sampling_context *sampling_ctx = nullptr;
 	llama_batch batch;
-	int task_id;
+
+	Ref<Thread> prompt_thread;
+	Ref<Mutex> prompt_mutex;
+	Ref<Semaphore> prompt_semaphore;
+	bool should_exit;
+
+	Vector<prompt_request> prompt_requests;
+	int n_prompts;
 
 protected:
 	static void _bind_methods();
@@ -23,21 +49,18 @@ class LlamaContext : public Node {
 	void set_model(const Ref<LlamaModel> model);
 	Ref<LlamaModel> get_model();
 
-	Variant request_completion(const String &prompt);
-	void _fulfill_completion(const String &prompt);
+	int prompt(const String &prompt, const int max_new_tokens, const float temperature, const float top_p, const int top_k, const float presence_penalty, const float frequency_penalty);
+	void _thread_prompt_loop();
 
-	int get_seed();
-	void set_seed(int seed);
-	int get_n_ctx();
-	void set_n_ctx(int n_ctx);
-	int get_n_threads();
-	void set_n_threads(int n_threads);
-	int get_n_threads_batch();
-	void set_n_threads_batch(int n_threads_batch);
+	int get_seed();
+	void set_seed(const int seed);
+	int get_n_ctx();
+	void set_n_ctx(const int n_ctx);
 
-	virtual PackedStringArray _get_configuration_warnings() const override;
+	virtual PackedStringArray _get_configuration_warnings() const override;
 	virtual void _ready() override;
-	LlamaContext();
+	virtual void _exit_tree() override;
+	LlamaContext();
 	~LlamaContext();
 };
 } //namespace godot
