
Commit 4e149e2

start refactor
1 parent a243391 commit 4e149e2

File tree

5 files changed: +146 -134 lines changed


build.zig (-2)

@@ -251,8 +251,6 @@ pub fn build(b: *std.Build) !void {
         extension.linkFramework("MetalKit");
         extension.linkFramework("Foundation");
         extension.linkFramework("Accelerate");
-        // b.installFile("llama.cpp/ggml-metal.metal", b.pathJoin(&.{ std.fs.path.basename(b.lib_dir), "ggml-metal.metal" }));
-        // b.installFile("llama.cpp/ggml-common.h", b.pathJoin(&.{ std.fs.path.basename(b.lib_dir), "ggml-common.h" }));
     } else {
         if (target.result.os.tag == .windows) {
             const vk_path = b.graph.env_map.get("VK_SDK_PATH") orelse @panic("VK_SDK_PATH not set");

godot/main.gd (+6 -6)

@@ -9,19 +9,19 @@ func _on_button_pressed():
 
 func handle_submit():
 	print(input.text)
-	Llama.request_completion(input.text)
+	Llama.prompt(input.text)
 
 	input.clear()
 	input.editable = false
 	submit_button.disabled = true
 	output.text = "..."
 
-	var completion = await Llama.completion_generated
+	var completion = await Llama.prompt_generated
 	output.text = ""
-	while !completion[1]:
-		print(completion[0])
-		output.text += completion[0]
-		completion = await Llama.completion_generated
+	# while !completion[1]:
+	# 	print(completion[0])
+	# 	output.text += completion[0]
+	# 	completion = await Llama.prompt_generated
 
 	input.editable = true
 	submit_button.disabled = false

llama.cpp (submodule pointer updated)

src/llama_context.cpp (+105 -110)

@@ -4,7 +4,8 @@
 #include "llama_model.h"
 #include <godot_cpp/classes/engine.hpp>
 #include <godot_cpp/classes/os.hpp>
-#include <godot_cpp/classes/worker_thread_pool.hpp>
+#include <godot_cpp/classes/semaphore.hpp>
+#include <godot_cpp/classes/thread.hpp>
 #include <godot_cpp/core/class_db.hpp>
 #include <godot_cpp/variant/utility_functions.hpp>
 
@@ -15,29 +16,42 @@ void LlamaContext::_bind_methods() {
 	ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
 	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");
 
-	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
-	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
+	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
+	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
 
-	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
-	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
+	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
 
-	ClassDB::bind_method(D_METHOD("get_n_threads"), &LlamaContext::get_n_threads);
-	ClassDB::bind_method(D_METHOD("set_n_threads", "n_threads"), &LlamaContext::set_n_threads);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads"), "set_n_threads", "get_n_threads");
+	ClassDB::bind_method(D_METHOD("get_temperature"), &LlamaContext::get_temperature);
+	ClassDB::bind_method(D_METHOD("set_temperature", "temperature"), &LlamaContext::set_temperature);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "temperature"), "set_temperature", "get_temperature");
 
-	ClassDB::bind_method(D_METHOD("get_n_threads_batch"), &LlamaContext::get_n_threads_batch);
-	ClassDB::bind_method(D_METHOD("set_n_threads_batch", "n_threads_batch"), &LlamaContext::set_n_threads_batch);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads_batch"), "set_n_threads_batch", "get_n_threads_batch");
+	ClassDB::bind_method(D_METHOD("get_top_p"), &LlamaContext::get_top_p);
+	ClassDB::bind_method(D_METHOD("set_top_p", "top_p"), &LlamaContext::set_top_p);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "top_p"), "set_top_p", "get_top_p");
 
-	ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
-	ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
+	ClassDB::bind_method(D_METHOD("get_top_k"), &LlamaContext::get_top_k);
+	ClassDB::bind_method(D_METHOD("set_top_k", "top_k"), &LlamaContext::set_top_k);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "top_k"), "set_top_k", "get_top_k");
 
-	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
+	ClassDB::bind_method(D_METHOD("get_presence_penalty"), &LlamaContext::get_presence_penalty);
+	ClassDB::bind_method(D_METHOD("set_presence_penalty", "presence_penalty"), &LlamaContext::set_presence_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "presence_penalty"), "set_presence_penalty", "get_presence_penalty");
+
+	ClassDB::bind_method(D_METHOD("get_frequency_penalty"), &LlamaContext::get_frequency_penalty);
+	ClassDB::bind_method(D_METHOD("set_frequency_penalty", "frequency_penalty"), &LlamaContext::set_frequency_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "frequency_penalty"), "set_frequency_penalty", "get_frequency_penalty");
+
+	ClassDB::bind_method(D_METHOD("prompt", "prompt", "max_new_tokens"), &LlamaContext::prompt);
+	ClassDB::bind_method(D_METHOD("_run_prompts"), &LlamaContext::_run_prompts);
+
+	ADD_SIGNAL(MethodInfo("prompt_completion", PropertyInfo(Variant::STRING, "prompt_id"), PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
 }
 
-LlamaContext::LlamaContext() {
+LlamaContext::LlamaContext() :
+		sampling_params() {
 	batch = llama_batch_init(4096, 0, 1);
 
 	ctx_params = llama_context_default_params();
@@ -66,100 +80,57 @@ void LlamaContext::_ready() {
 		return;
 	}
 	UtilityFunctions::print(vformat("%s: Context initialized", __func__));
-}
-
-PackedStringArray LlamaContext::_get_configuration_warnings() const {
-	PackedStringArray warnings;
-	if (model == NULL) {
-		warnings.push_back("Model resource property not defined");
-	}
-	return warnings;
-}
-
-Variant LlamaContext::request_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
-	task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
-	return OK;
-}
-
-void LlamaContext::_fulfill_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
-	std::vector<llama_token> tokens_list;
-	tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
 
-	const int n_len = 128;
-	const int n_ctx = llama_n_ctx(ctx);
-	const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-	if (n_kv_req > n_ctx) {
-		UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
-		return;
-	}
-
-	for (size_t i = 0; i < tokens_list.size(); i++) {
-		llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
-	}
+	sampling_ctx = llama_sampling_init(sampling_params);
 
-	batch.logits[batch.n_tokens - 1] = true;
+	semaphore.instantiate();
+	mutex.instantiate();
+	worker_thread.instantiate();
 
-	llama_kv_cache_clear(ctx);
+	worker_thread->start(Callable(this, "_run_prompts"));
+}
 
-	int decode_res = llama_decode(ctx, batch);
-	if (decode_res != 0) {
-		UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
-		return;
+PackedStringArray LlamaContext::_get_configuration_warnings() const {
+	PackedStringArray warnings;
+	if (model == NULL) {
+		warnings.push_back("Model resource property not defined");
 	}
+	return warnings;
+}
 
-	int n_cur = batch.n_tokens;
-	int n_decode = 0;
-	llama_model *llama_model = model->model;
-
-	while (n_cur <= n_len) {
-		// sample the next token
-		{
-			auto n_vocab = llama_n_vocab(llama_model);
-			auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
-			std::vector<llama_token_data> candidates;
-			candidates.reserve(n_vocab);
-
-			for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-				candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-			}
-
-			llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-			// sample the most likely token
-			const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
-
-			// is it an end of stream?
-			if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
-				call_thread_safe("emit_signal", "completion_generated", "\n", true);
+int LlamaContext::prompt(const String &prompt, int max_new_tokens) {
+	mutex->lock();
+	int prompt_id = n_prompts++;
+	prompts.push_back(prompt);
+	mutex->unlock();
 
-				break;
-			}
+	semaphore->post();
 
-			call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
+	UtilityFunctions::print(vformat("New prompt %d: %s", prompt_id, prompt));
 
-			// prepare the next batch
-			llama_batch_clear(batch);
+	return prompt_id;
+}
 
-			// push this new token for next evaluation
-			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+void LlamaContext::_run_prompts() {
+	while (true) {
+		semaphore->wait();
 
-			n_decode += 1;
+		mutex->lock();
+		if (should_exit) {
+			mutex->unlock();
+			break;
		}
+		if (prompts.is_empty()) {
+			mutex->unlock();
+			continue;
 		}
+		String prompt = prompts.get(0);
+		prompts.remove_at(0);
+		mutex->unlock();
 
-		n_cur += 1;
+		UtilityFunctions::print(vformat("Running prompt %s", prompt));
 
-		// evaluate the current batch with the transformer model
-		int decode_res = llama_decode(ctx, batch);
-		if (decode_res != 0) {
-			UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
-			break;
-		}
+		OS::get_singleton()->delay_msec(2000);
 	}
 }
@@ -184,28 +155,52 @@ void LlamaContext::set_n_ctx(int n_ctx) {
 	ctx_params.n_ctx = n_ctx;
 }
 
-int LlamaContext::get_n_threads() {
-	return ctx_params.n_threads;
+float LlamaContext::get_temperature() {
+	return sampling_params.temp;
 }
-void LlamaContext::set_n_threads(int n_threads) {
-	ctx_params.n_threads = n_threads;
+void LlamaContext::set_temperature(float temperature) {
+	sampling_params.temp = temperature;
+}
+
+float LlamaContext::get_top_p() {
+	return sampling_params.top_p;
+}
+void LlamaContext::set_top_p(float top_p) {
+	sampling_params.top_p = top_p;
+}
+
+int LlamaContext::get_top_k() {
+	return sampling_params.top_k;
+}
+void LlamaContext::set_top_k(int top_k) {
+	sampling_params.top_k = top_k;
 }
 
-int LlamaContext::get_n_threads_batch() {
-	return ctx_params.n_threads_batch;
+float LlamaContext::get_presence_penalty() {
+	return sampling_params.penalty_present;
 }
-void LlamaContext::set_n_threads_batch(int n_threads_batch) {
-	ctx_params.n_threads_batch = n_threads_batch;
+void LlamaContext::set_presence_penalty(float presence_penalty) {
+	sampling_params.penalty_present = presence_penalty;
+}
+
+float LlamaContext::get_frequency_penalty() {
+	return sampling_params.penalty_freq;
+}
+void LlamaContext::set_frequency_penalty(float frequency_penalty) {
+	sampling_params.penalty_freq = frequency_penalty;
 }
 
 LlamaContext::~LlamaContext() {
+	llama_batch_free(batch);
+	llama_sampling_free(sampling_ctx);
 	if (ctx) {
 		llama_free(ctx);
 	}
 
-	llama_batch_free(batch);
-
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
+	mutex->lock();
+	prompts.clear();
+	should_exit = true;
+	mutex->unlock();
+	semaphore->post();
+	worker_thread->wait_to_finish();
 }
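
For now the worker only drains the queue and sleeps; the generation code deleted from _fulfill_completion still has to be ported into _run_prompts. A rough sketch of how that port could look in place of the delay_msec(2000) placeholder (not code from this commit): it reuses the removed tokenize/decode steps but samples through the new sampling_ctx via llama_sampling_sample / llama_sampling_accept / llama_sampling_reset, the common-library helpers that pair with llama_sampling_init, and it assumes the queue is extended to carry prompt_id and max_new_tokens, which prompt() currently drops.

		// Hypothetical sketch only: runs inside the while (true) loop of _run_prompts,
		// assuming prompt_id and max_new_tokens were dequeued alongside the prompt text.
		std::vector<llama_token> tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);

		llama_batch_clear(batch);
		for (size_t i = 0; i < tokens_list.size(); i++) {
			llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
		}
		batch.logits[batch.n_tokens - 1] = true; // request logits for the last prompt token

		llama_kv_cache_clear(ctx);
		if (llama_decode(ctx, batch) != 0) {
			continue; // prompt decode failed; drop it and wait for the next one
		}

		int n_cur = batch.n_tokens;
		const int n_len = n_cur + max_new_tokens;
		while (n_cur <= n_len) {
			// sample with the configured temperature/top_p/top_k/penalties
			const llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, nullptr);
			llama_sampling_accept(sampling_ctx, ctx, new_token_id, false);

			if (new_token_id == llama_token_eos(model->model) || n_cur == n_len) {
				call_thread_safe("emit_signal", "prompt_completion", vformat("%d", prompt_id), "\n", true);
				break;
			}
			call_thread_safe("emit_signal", "prompt_completion", vformat("%d", prompt_id), vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);

			// feed the sampled token back in for the next decode step
			llama_batch_clear(batch);
			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
			n_cur += 1;
			if (llama_decode(ctx, batch) != 0) {
				break;
			}
		}
		llama_sampling_reset(sampling_ctx); // clear penalty history between prompts

As in the old implementation, routing the signal through call_thread_safe keeps delivery on the main thread even though generation runs on the worker.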

src/llama_context.h (+34 -15)

@@ -2,19 +2,31 @@
 #define LLAMA_CONTEXT_H
 
 #include "llama.h"
+#include "common.h"
 #include "llama_model.h"
+#include <godot_cpp/classes/mutex.hpp>
 #include <godot_cpp/classes/node.hpp>
-
+#include <godot_cpp/classes/semaphore.hpp>
+#include <godot_cpp/classes/thread.hpp>
+#include <godot_cpp/templates/vector.hpp>
 namespace godot {
+
 class LlamaContext : public Node {
 	GDCLASS(LlamaContext, Node)
 
private:
 	Ref<LlamaModel> model;
-	llama_context *ctx = nullptr;
+	Ref<Thread> worker_thread;
+	Ref<Semaphore> semaphore;
+	Ref<Mutex> mutex;
+	bool should_exit = false;
 	llama_context_params ctx_params;
+	llama_sampling_params sampling_params;
+	llama_context *ctx = nullptr;
+	llama_sampling_context *sampling_ctx = nullptr;
 	llama_batch batch;
-	int task_id;
+	Vector<String> prompts;
+	int n_prompts = 0;
 
protected:
 	static void _bind_methods();
@@ -23,21 +35,28 @@ class LlamaContext : public Node {
 	void set_model(const Ref<LlamaModel> model);
 	Ref<LlamaModel> get_model();
 
-	Variant request_completion(const String &prompt);
-	void _fulfill_completion(const String &prompt);
+	int prompt(const String &prompt, int max_new_tokens);
+	void _run_prompts();
+
+	int get_seed();
+	void set_seed(int seed);
+	int get_n_ctx();
+	void set_n_ctx(int n_ctx);
 
-	int get_seed();
-	void set_seed(int seed);
-	int get_n_ctx();
-	void set_n_ctx(int n_ctx);
-	int get_n_threads();
-	void set_n_threads(int n_threads);
-	int get_n_threads_batch();
-	void set_n_threads_batch(int n_threads_batch);
+	float get_temperature();
+	void set_temperature(float temperature);
+	float get_top_p();
+	void set_top_p(float top_p);
+	int get_top_k();
+	void set_top_k(int top_k);
+	float get_presence_penalty();
+	void set_presence_penalty(float presence_penalty);
+	float get_frequency_penalty();
+	void set_frequency_penalty(float frequency_penalty);
 
-	virtual PackedStringArray _get_configuration_warnings() const override;
+	virtual PackedStringArray _get_configuration_warnings() const override;
 	virtual void _ready() override;
-	LlamaContext();
+	LlamaContext();
 	~LlamaContext();
 };
 } //namespace godot
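
With the header in this shape, the caller-facing contract becomes: set the sampling properties, call prompt(), and listen for prompt_completion, with one long-lived worker replacing the per-request WorkerThreadPool tasks. A minimal hypothetical C++ consumer, for illustration only (the ChatUI class, the "Llama" node path, and the handler are invented, not part of this repo):

// Hypothetical consumer of the refactored API; assumes a LlamaContext child node named "Llama".
#include "llama_context.h"
#include <godot_cpp/core/class_db.hpp>

using namespace godot;

class ChatUI : public Node {
	GDCLASS(ChatUI, Node)

protected:
	static void _bind_methods() {
		ClassDB::bind_method(D_METHOD("_on_prompt_completion", "prompt_id", "completion", "is_final"), &ChatUI::_on_prompt_completion);
	}

public:
	void _ready() override {
		LlamaContext *llama = get_node<LlamaContext>("Llama");
		llama->set_temperature(0.8f);
		llama->connect("prompt_completion", Callable(this, "_on_prompt_completion"));
		// returns a queue id immediately; generation happens on the worker thread
		int id = llama->prompt("Once upon a time", 128);
		UtilityFunctions::print(vformat("queued prompt %d", id));
	}

	void _on_prompt_completion(const String &prompt_id, const String &completion, bool is_final) {
		// append streamed text; is_final == true marks the end of this prompt's output
	}
};

The GDScript in main.gd above follows the same submit-and-await pattern from script.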
