Skip to content

Commit 14527c4

Browse files
committed
working inference, kinda
1 parent 0dcb26e commit 14527c4

File tree

7 files changed

+234
-75
lines changed

7 files changed

+234
-75
lines changed

godot/autoloads/llama.tscn

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[gd_scene load_steps=2 format=3 uid="uid://bxobxniygk7jm"]
22

3-
[ext_resource type="LlamaModel" path="res://models/stablelm-2-zephyr-1_6b-Q4_K_M.gguf" id="1_8pggd"]
3+
[ext_resource type="LlamaModel" path="res://models/OGNO-7B-Q4_K_M.gguf" id="1_vd8h8"]
44

55
[node name="LlamaContext" type="LlamaContext"]
6-
model = ExtResource("1_8pggd")
6+
model = ExtResource("1_vd8h8")

godot/main.gd

+26-11
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,31 @@
11
extends Node
22

3-
@onready var input: TextEdit = $"Form/Input"
4-
5-
# Called when the node enters the scene tree for the first time.
6-
func _ready():
7-
pass # Replace with function body.
8-
9-
10-
# Called every frame. 'delta' is the elapsed time since the previous frame.
11-
func _process(delta):
12-
pass
13-
3+
@onready var input: TextEdit = %Input
4+
@onready var submit_button: Button = %SubmitButton
5+
@onready var output: Label = %Output
146

157
## Handler for the "pressed" signal of the submit button (wired up in main.tscn).
## Simply forwards to the shared submit routine.
func _on_button_pressed():
	handle_submit()
9+
10+
#func _unhandled_key_input(event: InputEvent) -> void:
11+
#if (event.is_action_released("submit_form") and input.has_focus()):
12+
#handle_submit()
13+
14+
## Sends the text currently in the input box to the `Llama` autoload and
## streams the generated pieces into the output label.
##
## The form (text box and button) is disabled while a completion is in flight
## and re-enabled once a `completion_generated` emission arrives with its
## second argument (`is_final`) set to true.
func handle_submit():
	# A disabled button means a completion is already streaming; ignore
	# re-entrant submits (e.g. from a keyboard shortcut).
	if submit_button.disabled:
		return

	var prompt: String = input.text
	# Nothing to ask: do not lock the UI or bother the model.
	if prompt.strip_edges().is_empty():
		return

	print(prompt)

	# Lock the form before dispatching so no second request can sneak in.
	input.clear()
	input.editable = false
	submit_button.disabled = true
	output.text = "..."

	Llama.request_completion(prompt)

	# Each emission is awaited as [piece, is_final].
	var completion = await Llama.completion_generated
	output.text = ""
	while !completion[1]:
		print(completion[0])
		output.text += completion[0]
		completion = await Llama.completion_generated

	input.editable = true
	submit_button.disabled = false

godot/main.tscn

+76-50
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
[gd_scene load_steps=3 format=3 uid="uid://7oo8yj56scb1"]
1+
[gd_scene load_steps=4 format=3 uid="uid://7oo8yj56scb1"]
22

33
[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_ojdoj"]
44
[ext_resource type="Script" path="res://main.gd" id="1_vvrqe"]
55

6+
[sub_resource type="StyleBoxFlat" id="StyleBoxFlat_3e37a"]
7+
corner_radius_top_left = 5
8+
corner_radius_top_right = 5
9+
corner_radius_bottom_right = 5
10+
corner_radius_bottom_left = 5
11+
612
[node name="Main" type="Node"]
713
script = ExtResource("1_vvrqe")
814

@@ -14,63 +20,83 @@ grow_horizontal = 2
1420
grow_vertical = 2
1521
color = Color(0.980392, 0.952941, 0.929412, 1)
1622

17-
[node name="Form" type="HBoxContainer" parent="."]
18-
custom_minimum_size = Vector2(300, 50)
19-
anchors_preset = -1
20-
anchor_top = 0.7
21-
anchor_right = 1.0
22-
anchor_bottom = 0.7
23-
offset_left = 350.0
24-
offset_top = -1.66893e-05
25-
offset_right = -350.0
26-
offset_bottom = 50.0
23+
[node name="CenterContainer" type="CenterContainer" parent="."]
24+
anchors_preset = 8
25+
anchor_left = 0.5
26+
anchor_top = 0.5
27+
anchor_right = 0.5
28+
anchor_bottom = 0.5
29+
offset_left = -400.0
30+
offset_top = -479.0
31+
offset_right = 400.0
32+
offset_bottom = 479.0
2733
grow_horizontal = 2
28-
grow_vertical = 0
29-
alignment = 1
34+
grow_vertical = 2
3035

31-
[node name="Input" type="TextEdit" parent="Form"]
36+
[node name="VBoxContainer" type="VBoxContainer" parent="CenterContainer"]
37+
custom_minimum_size = Vector2(500, 0)
3238
layout_mode = 2
33-
size_flags_horizontal = 3
34-
size_flags_stretch_ratio = 3.0
35-
placeholder_text = "Ask me anything..."
39+
theme_override_constants/separation = 10
40+
alignment = 1
3641

37-
[node name="SubmitButton" type="Button" parent="Form"]
42+
[node name="Name" type="Label" parent="CenterContainer/VBoxContainer"]
3843
layout_mode = 2
39-
size_flags_horizontal = 3
40-
text = "Submit"
44+
theme_override_colors/font_color = Color(0.101961, 0.0823529, 0.0627451, 1)
45+
theme_override_font_sizes/font_size = 32
46+
text = "godot-llama-cpp"
47+
horizontal_alignment = 1
4148

42-
[node name="SpriteContainer" type="CenterContainer" parent="."]
43-
anchors_preset = -1
44-
anchor_left = 0.5
45-
anchor_top = 0.4
46-
anchor_right = 0.5
47-
anchor_bottom = 0.4
48-
offset_left = -20.0
49-
offset_top = -20.0
50-
offset_right = 20.0
51-
offset_bottom = 20.0
52-
grow_horizontal = 2
53-
grow_vertical = 2
49+
[node name="MarginContainer" type="MarginContainer" parent="CenterContainer/VBoxContainer"]
50+
layout_mode = 2
51+
theme_override_constants/margin_left = 100
52+
theme_override_constants/margin_right = 100
5453

55-
[node name="GodotLlamaSprite" type="Sprite2D" parent="SpriteContainer"]
56-
position = Vector2(20, 20)
57-
scale = Vector2(0.2, 0.2)
54+
[node name="TextureRect" type="TextureRect" parent="CenterContainer/VBoxContainer/MarginContainer"]
55+
layout_mode = 2
5856
texture = ExtResource("1_ojdoj")
57+
expand_mode = 4
5958

60-
[node name="Label" type="Label" parent="."]
61-
anchors_preset = -1
62-
anchor_left = 0.5
63-
anchor_top = 0.6
64-
anchor_right = 0.5
65-
anchor_bottom = 0.6
66-
offset_left = -127.0
67-
offset_top = -22.5
68-
offset_right = 127.0
69-
offset_bottom = 22.5
70-
grow_horizontal = 2
71-
grow_vertical = 2
59+
[node name="ScrollContainer" type="ScrollContainer" parent="CenterContainer/VBoxContainer"]
60+
custom_minimum_size = Vector2(0, 60)
61+
layout_mode = 2
62+
horizontal_scroll_mode = 0
63+
64+
[node name="Panel" type="PanelContainer" parent="CenterContainer/VBoxContainer/ScrollContainer"]
65+
layout_mode = 2
66+
size_flags_horizontal = 3
67+
size_flags_vertical = 3
68+
theme_override_styles/panel = SubResource("StyleBoxFlat_3e37a")
69+
70+
[node name="MarginContainer" type="MarginContainer" parent="CenterContainer/VBoxContainer/ScrollContainer/Panel"]
71+
layout_mode = 2
72+
theme_override_constants/margin_left = 20
73+
theme_override_constants/margin_right = 20
74+
75+
[node name="Output" type="Label" parent="CenterContainer/VBoxContainer/ScrollContainer/Panel/MarginContainer"]
76+
unique_name_in_owner = true
77+
custom_minimum_size = Vector2(200, 0)
78+
layout_mode = 2
7279
theme_override_colors/font_color = Color(0.101961, 0.0823529, 0.0627451, 1)
73-
theme_override_font_sizes/font_size = 32
74-
text = "godot-llama-cpp"
80+
text = "Ask me anything!"
81+
autowrap_mode = 3
82+
83+
[node name="Form" type="HBoxContainer" parent="CenterContainer/VBoxContainer"]
84+
custom_minimum_size = Vector2(500, 60)
85+
layout_mode = 2
86+
size_flags_horizontal = 4
87+
alignment = 1
88+
89+
[node name="Input" type="TextEdit" parent="CenterContainer/VBoxContainer/Form"]
90+
unique_name_in_owner = true
91+
layout_mode = 2
92+
size_flags_horizontal = 3
93+
size_flags_stretch_ratio = 3.0
94+
placeholder_text = "Why do cows moo?"
95+
96+
[node name="SubmitButton" type="Button" parent="CenterContainer/VBoxContainer/Form"]
97+
unique_name_in_owner = true
98+
layout_mode = 2
99+
size_flags_horizontal = 3
100+
text = "Submit"
75101

76-
[connection signal="pressed" from="Form/SubmitButton" to="." method="_on_button_pressed"]
102+
[connection signal="pressed" from="CenterContainer/VBoxContainer/Form/SubmitButton" to="." method="_on_button_pressed"]

godot/project.godot

+21
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,33 @@ config_version=5
1313
config/name="godot-llama-cpp"
1414
run/main_scene="res://main.tscn"
1515
config/features=PackedStringArray("4.2", "Forward Plus")
16+
config/icon="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg"
1617

1718
[autoload]
1819

1920
__LlamaBackend="*res://addons/godot-llama-cpp/autoloads/llama-backend.gd"
2021
Llama="*res://autoloads/llama.tscn"
2122

23+
[display]
24+
25+
window/size/viewport_width=1280
26+
window/size/viewport_height=720
27+
2228
[editor_plugins]
2329

2430
enabled=PackedStringArray("res://addons/godot-llama-cpp/plugin.cfg")
31+
32+
[input]
33+
34+
submit_form={
35+
"deadzone": 0.5,
36+
"events": [Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":4194309,"key_label":0,"unicode":0,"echo":false,"script":null)
37+
]
38+
}
39+
40+
[rendering]
41+
42+
anti_aliasing/quality/msaa_2d=3
43+
anti_aliasing/quality/msaa_3d=3
44+
anti_aliasing/quality/screen_space_aa=1
45+
anti_aliasing/quality/use_taa=true

src/llama_context.cpp

+104-12
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
#include "llama_context.h"
2+
#include "common.h"
23
#include "llama.h"
34
#include "llama_model.h"
45
#include <godot_cpp/classes/engine.hpp>
56
#include <godot_cpp/classes/os.hpp>
7+
#include <godot_cpp/classes/worker_thread_pool.hpp>
68
#include <godot_cpp/core/class_db.hpp>
79
#include <godot_cpp/variant/utility_functions.hpp>
810

911
using namespace godot;
1012

11-
void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
12-
model = p_model;
13-
}
14-
15-
Ref<LlamaModel> LlamaContext::get_model() {
16-
return model;
13+
void LlamaContext::_bind_methods() {
14+
ClassDB::bind_method(D_METHOD("set_model", "model"), &LlamaContext::set_model);
15+
ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
16+
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");
17+
ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
18+
ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
19+
ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
1720
}
1821

1922
void LlamaContext::_ready() {
@@ -40,14 +43,103 @@ void LlamaContext::_ready() {
4043
UtilityFunctions::print(vformat("%s: Context initialized", __func__));
4144
}
4245

46+
// Queues a completion for `prompt` on the WorkerThreadPool.
//
// If a previously queued task id is stored, this call first blocks until that
// task has completed, then submits `_fulfill_completion(prompt)` as a new task
// and remembers its id. Always returns OK.
Variant LlamaContext::request_completion(const String &prompt) {
	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));

	WorkerThreadPool *pool = WorkerThreadPool::get_singleton();
	if (task_id != 0) {
		pool->wait_for_task_completion(task_id);
	}

	const Callable job = Callable(this, "_fulfill_completion").bind(prompt);
	task_id = pool->add_task(job);

	return OK;
}
54+
55+
void LlamaContext::_fulfill_completion(const String &prompt) {
56+
UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
57+
std::vector<llama_token> tokens_list;
58+
tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
59+
60+
const int n_len = 128;
61+
const int n_ctx = llama_n_ctx(ctx);
62+
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
63+
if (n_kv_req > n_ctx) {
64+
UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
65+
return;
66+
}
67+
68+
for (size_t i = 0; i < tokens_list.size(); i++) {
69+
llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
70+
}
71+
batch.logits[batch.n_tokens - 1] = true;
72+
73+
int decode_res = llama_decode(ctx, batch);
74+
if (decode_res != 0) {
75+
UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
76+
return;
77+
}
78+
79+
int n_cur = batch.n_tokens;
80+
int n_decode = 0;
81+
llama_model *llama_model = model->model;
82+
while (n_cur <= n_len) {
83+
// sample the next token
84+
{
85+
auto n_vocab = llama_n_vocab(llama_model);
86+
auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
87+
88+
std::vector<llama_token_data> candidates;
89+
candidates.reserve(n_vocab);
90+
91+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
92+
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
93+
}
94+
95+
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
96+
97+
// sample the most likely token
98+
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
99+
100+
// is it an end of stream?
101+
if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
102+
call_thread_safe("emit_signal", "completion_generated", "\n", true);
103+
104+
break;
105+
}
106+
107+
call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
108+
109+
// prepare the next batch
110+
llama_batch_clear(batch);
111+
112+
// push this new token for next evaluation
113+
llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
114+
115+
n_decode += 1;
116+
}
117+
118+
n_cur += 1;
119+
120+
// evaluate the current batch with the transformer model
121+
int decode_res = llama_decode(ctx, batch);
122+
if (decode_res != 0) {
123+
UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
124+
return;
125+
}
126+
}
127+
}
128+
129+
void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
130+
model = p_model;
131+
}
132+
133+
// Getter backing the `model` property: returns the stored LlamaModel reference.
Ref<LlamaModel> LlamaContext::get_model() {
	return this->model;
}
136+
43137
// Releases the llama.cpp resources owned by this node.
//
// A completion task that is still running uses `ctx` and `batch`, so the task
// must be waited for BEFORE either of them is freed.
LlamaContext::~LlamaContext() {
	if (task_id) {
		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
		task_id = 0;
	}

	llama_batch_free(batch);

	if (ctx) {
		llama_free(ctx);
		ctx = nullptr;
	}
}

src/llama_context.h

+4
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@ class LlamaContext : public Node {
1313
Ref<LlamaModel> model;
1414
llama_context *ctx = nullptr;
1515
llama_context_params ctx_params = llama_context_default_params();
16+
llama_batch batch = llama_batch_init(512, 0, 1);
17+
int task_id = 0; // id of the last queued completion task; 0 means "none" (it is tested with `if (task_id)`)
1618

1719
protected:
1820
static void _bind_methods();
1921

2022
public:
2123
void set_model(const Ref<LlamaModel> model);
2224
Ref<LlamaModel> get_model();
25+
Variant request_completion(const String &prompt);
26+
void _fulfill_completion(const String &prompt);
2327
virtual void _ready() override;
2428
~LlamaContext();
2529
};

0 commit comments

Comments
 (0)