
Commit 6960f3b: working inference

1 parent bc4b614

6 files changed: +164 -21 lines changed

godot/examples/simple/TextEdit.gd (+4 -1)

@@ -10,8 +10,11 @@ func _gui_input(event: InputEvent) -> void:
 		accept_event()
 	if keycode == KEY_ENTER | KEY_MASK_SHIFT and event.is_pressed():
 		insert_text_at_caret("\n")
+		accept_event()
+
+func _on_button_pressed() -> void:
+	handle_submit()
 
 func handle_submit() -> void:
 	submit.emit(text)
 	text = ""
-
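With this change, Enter submits the prompt, Shift+Enter inserts a newline, and the send button routes through the new `_on_button_pressed` so both paths end in `handle_submit()`, which emits the `submit` signal with the current text. A minimal sketch of the receiving side, assuming the `_on_text_edit_submit` handler that simple.tscn connects to the scene root and the `handle_input` function in simple.gd:

```gdscript
# Hypothetical receiving end of TextEdit's `submit` signal; the connection
# itself is made in simple.tscn (see below).
func _on_text_edit_submit(text: String) -> void:
	handle_input(text)
```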

godot/examples/simple/simple.gd (+4 -3)

@@ -16,6 +16,7 @@ func handle_input(input: String) -> void:
 	var id = llama_context.request_completion(input)
 	print("request id: ", id)
 
-	var chunk = await llama_context.completion_generated
-	print('new chunk: ', chunk)
-
+
+
+func _on_llama_context_completion_generated(chunk: Dictionary) -> void:
+	print("new chunk: ", chunk)

godot/examples/simple/simple.tscn (+4 -10)

@@ -3,8 +3,8 @@
 [ext_resource type="LlamaModel" path="res://models/Phi-3-mini-128k-instruct.Q5_K_M.gguf" id="1_ff70a"]
 [ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_gjsev"]
 [ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"]
+[ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"]
 [ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"]
-[ext_resource type="Script" path="res://examples/simple/form.gd" id="2_p1ih5"]
 
 [node name="Node" type="Node"]
 script = ExtResource("1_sruc3")

@@ -41,20 +41,14 @@ layout_mode = 2
 size_flags_horizontal = 3
 size_flags_vertical = 3
 
-[node name="RichTextLabel" type="RichTextLabel" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer"]
+[node name="RichTextLabel2" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer" instance=ExtResource("2_7iip7")]
 layout_mode = 2
-focus_mode = 2
 text = "How can I help you?"
-fit_content = true
-scroll_active = false
-selection_enabled = true
 
 [node name="HBoxContainer" type="HBoxContainer" parent="Panel/MarginContainer/VBoxContainer"]
 layout_mode = 2
-script = ExtResource("2_p1ih5")
 
 [node name="TextEdit" type="TextEdit" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"]
-unique_name_in_owner = true
 custom_minimum_size = Vector2(2.08165e-12, 100)
 layout_mode = 2
 size_flags_horizontal = 3

@@ -73,5 +67,5 @@ model = ExtResource("1_ff70a")
 unique_name_in_owner = true
 
 [connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"]
-[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer" method="_on_button_pressed"]
-[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" method="_on_pressed"]
+[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" method="_on_button_pressed"]
+[connection signal="completion_generated" from="LlamaContext" to="." method="_on_llama_context_completion_generated"]

godot/project.godot (+1)

@@ -11,6 +11,7 @@ config_version=5
 [application]
 
 config/name="godot-llama-cpp"
+run/main_scene="res://examples/simple/simple.tscn"
 config/features=PackedStringArray("4.2", "Forward Plus")
 config/icon="res://icon.svg"

src/llama_context.cpp (+142 -6)
@@ -2,12 +2,13 @@
 #include "common.h"
 #include "llama.h"
 #include "llama_model.h"
+#include <algorithm>
 #include <godot_cpp/classes/engine.hpp>
 #include <godot_cpp/classes/os.hpp>
 #include <godot_cpp/classes/worker_thread_pool.hpp>
 #include <godot_cpp/core/class_db.hpp>
-#include <godot_cpp/variant/utility_functions.hpp>
 #include <godot_cpp/variant/dictionary.hpp>
+#include <godot_cpp/variant/utility_functions.hpp>
 
 using namespace godot;

@@ -24,6 +25,10 @@ void LlamaContext::_bind_methods() {
 	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
 	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
 
+	ClassDB::bind_method(D_METHOD("get_n_len"), &LlamaContext::get_n_len);
+	ClassDB::bind_method(D_METHOD("set_n_len", "n_len"), &LlamaContext::set_n_len);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_len"), "set_n_len", "get_n_len");
+
 	ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
 	ClassDB::bind_method(D_METHOD("__thread_loop"), &LlamaContext::__thread_loop);
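Binding the getter/setter pair and registering the property exposes `n_len` (the cap on generated token positions) to the inspector and to scripts. A hypothetical usage sketch from GDScript:

```gdscript
@onready var llama_context = $LlamaContext  # illustrative node path

func _ready() -> void:
	llama_context.n_len = 256  # cap completions for shorter replies
```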

@@ -63,6 +68,9 @@ void LlamaContext::_ready() {
 		UtilityFunctions::printerr(vformat("%s: Failed to initialize llama context, null ctx", __func__));
 		return;
 	}
+
+	sampling_ctx = llama_sampling_init(sampling_params);
+
 	UtilityFunctions::print(vformat("%s: Context initialized", __func__));
 
 	thread->start(callable_mp(this, &LlamaContext::__thread_loop));
@@ -73,6 +81,10 @@ void LlamaContext::__thread_loop() {
 		semaphore->wait();
 
 		mutex->lock();
+		if (exit_thread) {
+			mutex->unlock();
+			break;
+		}
 		if (completion_requests.size() == 0) {
 			mutex->unlock();
 			continue;
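The loop now re-checks `exit_thread` under the mutex each time the semaphore wakes it, which is what lets `_exit_tree` (further down) shut the worker down cleanly: set the flag, post the semaphore, join. The same pattern expressed as a GDScript sketch, since Godot exposes Thread, Mutex, and Semaphore to scripts; all names here are illustrative, not the plugin's API:

```gdscript
var thread := Thread.new()
var mutex := Mutex.new()
var semaphore := Semaphore.new()
var exit_requested := false
var requests: Array = []

func _worker_loop() -> void:
	while true:
		semaphore.wait()        # sleep until work arrives or shutdown is signaled
		mutex.lock()
		if exit_requested:      # checked under the lock, as in the C++ loop
			mutex.unlock()
			break
		if requests.is_empty():
			mutex.unlock()
			continue
		var req = requests.pop_front()
		mutex.unlock()
		# ... run the completion for req ...

func _shutdown() -> void:
	mutex.lock()
	exit_requested = true
	mutex.unlock()
	semaphore.post()            # wake the worker so it observes the flag
	thread.wait_to_finish()
```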
@@ -83,10 +95,115 @@
 
 		UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id));
 
-		Dictionary chunk;
-		chunk["id"] = req.id;
-		chunk["text"] = "Hello, world!";
-		call_deferred("emit_signal", "completion_generated", chunk);
+		std::vector<llama_token> request_tokens;
+		request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true);
+
+		size_t shared_prefix_idx = 0;
+		auto diff = std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end());
+		if (diff.first != context_tokens.end()) {
+			shared_prefix_idx = std::distance(context_tokens.begin(), diff.first);
+		} else {
+			shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
+		}
+
+		bool rm_success = llama_kv_cache_seq_rm(ctx, 0, shared_prefix_idx, -1);
+		if (!rm_success) {
+			UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "Failed to remove tokens from kv cache";
+			call_deferred("emit_signal", "completion_generated", response);
+			continue;
+		}
+		context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
+		request_tokens.erase(request_tokens.begin(), request_tokens.begin() + shared_prefix_idx);
+
+		uint batch_size = std::min(ctx_params.n_batch, (uint)request_tokens.size());
+
+		llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+		// chunk request_tokens into sequences of size batch_size
+		std::vector<std::vector<llama_token>> sequences;
+		for (size_t i = 0; i < request_tokens.size(); i += batch_size) {
+			sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
+		}
+
+		int curr_token_pos = context_tokens.size();
+		bool decode_failed = false;
+
+		for (size_t i = 0; i < sequences.size(); i++) {
+			llama_batch_clear(batch);
+
+			std::vector<llama_token> sequence = sequences[i];
+
+			for (size_t j = 0; j < sequence.size(); j++) {
+				llama_batch_add(batch, sequence[j], j + curr_token_pos, { 0 }, false);
+				curr_token_pos++;
+			}
+
+			if (i == sequences.size() - 1) {
+				batch.logits[batch.n_tokens - 1] = true;
+			}
+
+			if (llama_decode(ctx, batch) != 0) {
+				decode_failed = true;
+				break;
+			}
+		}
+
+		if (decode_failed) {
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "llama_decode() failed";
+			call_deferred("emit_signal", "completion_generated", response);
+			continue;
+		}
+
+		context_tokens.insert(context_tokens.end(), request_tokens.begin(), request_tokens.end());
+
+		while (true) {
+			if (exit_thread) {
+				return;
+			}
+			llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1);
+			llama_sampling_accept(sampling_ctx, ctx, new_token_id, true);
+
+			Dictionary response;
+			response["id"] = req.id;
+
+			if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
+				response["done"] = true;
+				call_deferred("emit_signal", "completion_generated", response);
+				break;
+			}
+
+			context_tokens.push_back(new_token_id);
+
+			response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
+			response["done"] = false;
+			call_deferred("emit_signal", "completion_generated", response);
+
+			llama_batch_clear(batch);
+
+			llama_batch_add(batch, new_token_id, curr_token_pos, { 0 }, true);
+
+			curr_token_pos++;
+
+			if (llama_decode(ctx, batch) != 0) {
+				decode_failed = true;
+				break;
+			}
+		}
+
+		if (decode_failed) {
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "llama_decode() failed";
+			call_deferred("emit_signal", "completion_generated", response);
+			continue;
+		}
+
+		llama_sampling_reset(sampling_ctx);
 	}
 }
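The heart of the new completion path is prefix reuse: `std::mismatch` finds where the cached conversation and the new prompt diverge, `llama_kv_cache_seq_rm` evicts everything past that point, and only the unseen suffix is split into batches and decoded. A GDScript sketch of just the prefix computation, with plain int arrays standing in for the llama_token vectors:

```gdscript
func shared_prefix_length(context_tokens: Array, request_tokens: Array) -> int:
	var n: int = min(context_tokens.size(), request_tokens.size())
	for i in range(n):
		if context_tokens[i] != request_tokens[i]:
			return i  # first index where cache and request diverge
	return n          # one sequence is a prefix of the other

# Example: cached [1, 2, 3, 4] vs request [1, 2, 9] -> 2.
# Tokens 0..1 stay in the KV cache; positions >= 2 are evicted and
# only the request's remaining suffix [9] has to be decoded again.
```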

@@ -134,7 +251,26 @@ void LlamaContext::set_n_ctx(int n_ctx) {
 	ctx_params.n_ctx = n_ctx;
 }
 
-LlamaContext::~LlamaContext() {
+int LlamaContext::get_n_len() {
+	return n_len;
+}
+void LlamaContext::set_n_len(int n_len) {
+	this->n_len = n_len;
+}
+
+void LlamaContext::_exit_tree() {
+	if (Engine::get_singleton()->is_editor_hint()) {
+		return;
+	}
+
+	mutex->lock();
+	exit_thread = true;
+	mutex->unlock();
+
+	semaphore->post();
+
+	thread->wait_to_finish();
+
 	if (ctx) {
 		llama_free(ctx);
 	}

src/llama_context.h (+9 -1)

@@ -2,6 +2,7 @@
 #define LLAMA_CONTEXT_H
 
 #include "llama.h"
+#include "common.h"
 #include "llama_model.h"
 #include <godot_cpp/classes/mutex.hpp>
 #include <godot_cpp/classes/node.hpp>

@@ -21,13 +22,18 @@ class LlamaContext : public Node {
 private:
 	Ref<LlamaModel> model;
 	llama_context *ctx = nullptr;
+	llama_sampling_context *sampling_ctx = nullptr;
 	llama_context_params ctx_params;
+	llama_sampling_params sampling_params;
+	int n_len = 1024;
 	int request_id = 0;
 	Vector<completion_request> completion_requests;
 
 	Ref<Thread> thread;
 	Ref<Semaphore> semaphore;
 	Ref<Mutex> mutex;
+	std::vector<llama_token> context_tokens;
+	bool exit_thread = false;
 
 protected:
 	static void _bind_methods();

@@ -43,11 +49,13 @@ class LlamaContext : public Node {
 	void set_seed(int seed);
 	int get_n_ctx();
 	void set_n_ctx(int n_ctx);
+	int get_n_len();
+	void set_n_len(int n_len);
 
 	virtual PackedStringArray _get_configuration_warnings() const override;
 	virtual void _ready() override;
+	virtual void _exit_tree() override;
 	LlamaContext();
-	~LlamaContext();
 };
 } //namespace godot
