
Commit 154bdcb

Merge pull request #3 from hazelnutcloud/dev
Update inference code

authored May 21, 2024
2 parents a243391 + 17f069c, commit 154bdcb

30 files changed: +785 -574 lines
 

build.zig

+116 -237
Large diffs are not rendered by default.

godot/.gitattributes

+2

@@ -0,0 +1,2 @@
+# Normalize EOL for all files that Git considers text files.
+* text=auto eol=lf

godot/.gitignore

+2

@@ -0,0 +1,2 @@
+# Godot 4+ specific ignores
+.godot/

godot/addons/godot-llama-cpp/autoloads/llama-backend.gd

-10
This file was deleted.

(new file, +56 lines; filename not captured in this view)

@@ -0,0 +1,56 @@
+class_name ChatFormatter
+
+static func apply(format: String, messages: Array) -> String:
+	match format:
+		"llama3":
+			return format_llama3(messages)
+		"phi3":
+			return format_phi3(messages)
+		"mistral":
+			return format_mistral(messages)
+		_:
+			printerr("Unknown chat format: ", format)
+			return ""
+
+static func format_llama3(messages: Array) -> String:
+	var res = ""
+
+	for i in range(messages.size()):
+		match messages[i]:
+			{"text": var text, "sender": var sender}:
+				res += """<|start_header_id|>%s<|end_header_id|>
+
+%s<|eot_id|>
+""" % [sender, text]
+			_:
+				printerr("Invalid message at index ", i)
+
+	res += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+	return res
+
+static func format_phi3(messages: Array) -> String:
+	var res = ""
+
+	for i in range(messages.size()):
+		match messages[i]:
+			{"text": var text, "sender": var sender}:
+				res += "<|%s|>\n%s<|end|>\n" % [sender, text]
+			_:
+				printerr("Invalid message at index ", i)
+	res += "<|assistant|>\n"
+	return res
+
+static func format_mistral(messages: Array) -> String:
+	var res = ""
+
+	for i in range(messages.size()):
+		match messages[i]:
+			{"text": var text, "sender": var sender}:
+				if sender == "user":
+					res += "[INST] %s [/INST]" % text
+				else:
+					res += "%s</s>" % text
+			_:
+				printerr("Invalid message at index ", i)
+
+	return res

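For orientation, a minimal usage sketch of the new ChatFormatter (the message values here are hypothetical; class_name registers the class globally, so no preload is needed):

	var messages = [
		{ "sender": "system", "text": "You are a helpful assistant." },
		{ "sender": "user", "text": "Hello!" },
	]
	var prompt = ChatFormatter.apply("llama3", messages)
	# The result ends with "<|start_header_id|>assistant<|end_header_id|>\n\n",
	# ready to hand to LlamaContext.request_completion(prompt).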
godot/addons/godot-llama-cpp/plugin.cfg

+1 -1

@@ -4,4 +4,4 @@ name="godot-llama-cpp"
 description="Run large language models in Godot. Powered by llama.cpp."
 author="hazelnutcloud"
 version="0.0.1"
-script="godot-llama-cpp.gd"
+script="plugin.gd"

godot/addons/godot-llama-cpp/{godot-llama-cpp.gd → plugin.gd} (renamed)

+2 -2

@@ -4,9 +4,9 @@ extends EditorPlugin

 func _enter_tree():
 	# Initialization of the plugin goes here.
-	add_autoload_singleton("__LlamaBackend", "res://addons/godot-llama-cpp/autoloads/llama-backend.gd")
+	pass


 func _exit_tree():
 	# Clean-up of the plugin goes here.
-	remove_autoload_singleton("__LlamaBackend")
+	pass

godot/addons/godot-llama-cpp/{godot-llama-cpp.gdextension → plugin.gdextension} (renamed)

+2 -2

@@ -5,8 +5,8 @@ compatibility_minimum = "4.2"

 [libraries]

-macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-Debug.dylib"
-macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseFast.dylib"
+macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
+macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
 windows.debug.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_debug.x86_32.dll"
 windows.release.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_release.x86_32.dll"
 windows.debug.x86_64 = "res://addons/godot-llama-cpp/lib/godot-llama-cpp-x86_64-windows-gnu-Debug.dll"

godot/autoloads/llama.tscn

-6
This file was deleted.

godot/examples/simple/TextEdit.gd

+20

@@ -0,0 +1,20 @@
+extends TextEdit
+
+signal submit(input: String)
+
+func _gui_input(event: InputEvent) -> void:
+	if event is InputEventKey:
+		var keycode = event.get_keycode_with_modifiers()
+		if keycode == KEY_ENTER and event.is_pressed():
+			handle_submit()
+			accept_event()
+		if keycode == KEY_ENTER | KEY_MASK_SHIFT and event.is_pressed():
+			insert_text_at_caret("\n")
+			accept_event()
+
+func _on_button_pressed() -> void:
+	handle_submit()
+
+func handle_submit() -> void:
+	submit.emit(text)
+	text = ""

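A short sketch of consuming this TextEdit's submit signal from a parent script (node path assumed; the example scene below wires the same signal in the .tscn instead):

	@onready var text_edit = %TextEdit

	func _ready() -> void:
		text_edit.submit.connect(_on_text_edit_submit)

	func _on_text_edit_submit(input: String) -> void:
		print("user submitted: ", input)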
godot/examples/simple/form.gd

+6

@@ -0,0 +1,6 @@
+extends HBoxContainer
+
+@onready var text_edit = %TextEdit
+
+func _on_button_pressed() -> void:
+	text_edit.handle_submit()

godot/examples/simple/message.gd

+23

@@ -0,0 +1,23 @@
+class_name Message
+extends Node
+
+@onready var text_container = %Text
+@onready var icon = %Panel
+@export_enum("user", "assistant") var sender: String
+@export var include_in_prompt: bool = true
+var text:
+	get:
+		return text_container.text
+	set(value):
+		text_container.text = value
+
+var completion_id: int = -1
+var pending: bool = false
+var errored: bool = false
+
+func set_text(new_text: String):
+	text_container.text = new_text
+
+func append_text(new_text: String):
+	text_container.text += new_text

godot/examples/simple/message.tscn

+37

@@ -0,0 +1,37 @@
+[gd_scene load_steps=5 format=3 uid="uid://t862t0v8ht2q"]
+
+[ext_resource type="Script" path="res://examples/simple/message.gd" id="1_pko33"]
+[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="2_dvc7y"]
+
+[sub_resource type="StyleBoxTexture" id="StyleBoxTexture_t8bgj"]
+texture = ExtResource("2_dvc7y")
+
+[sub_resource type="Theme" id="Theme_bw3pb"]
+Panel/styles/panel = SubResource("StyleBoxTexture_t8bgj")
+
+[node name="RichTextLabel" type="HBoxContainer"]
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+size_flags_horizontal = 3
+theme_override_constants/separation = 20
+script = ExtResource("1_pko33")
+sender = "assistant"
+
+[node name="Panel" type="Panel" parent="."]
+unique_name_in_owner = true
+custom_minimum_size = Vector2(80, 80)
+layout_mode = 2
+size_flags_vertical = 0
+theme = SubResource("Theme_bw3pb")
+
+[node name="Text" type="RichTextLabel" parent="."]
+unique_name_in_owner = true
+layout_mode = 2
+size_flags_horizontal = 3
+focus_mode = 2
+text = "..."
+fit_content = true
+selection_enabled = true

godot/examples/simple/simple.gd

+53

@@ -0,0 +1,53 @@
+extends Node
+
+const message = preload("res://examples/simple/message.tscn")
+
+@onready var messages_container = %MessagesContainer
+@onready var llama_context = %LlamaContext
+
+func _on_text_edit_submit(input: String) -> void:
+	handle_input(input)
+
+func handle_input(input: String) -> void:
+	#var messages = [{ "sender": "system", "text": "You are a pirate chatbot who always responds in pirate speak!" }]
+
+	#var messages = [{ "sender": "system", "text": "You are a helpful chatbot assistant!" }]
+	var messages = []
+	messages.append_array(messages_container.get_children().filter(func(msg: Message): return msg.include_in_prompt).map(
+		func(msg: Message) -> Dictionary:
+			return { "text": msg.text, "sender": msg.sender }
+	))
+	messages.append({"text": input, "sender": "user"})
+	var prompt = ChatFormatter.apply("llama3", messages)
+	print("prompt: ", prompt)
+
+	var completion_id = llama_context.request_completion(prompt)
+
+	var user_message: Message = message.instantiate()
+	messages_container.add_child(user_message)
+	user_message.set_text(input)
+	user_message.sender = "user"
+	user_message.completion_id = completion_id
+
+	var ai_message: Message = message.instantiate()
+	messages_container.add_child(ai_message)
+	ai_message.sender = "assistant"
+	ai_message.completion_id = completion_id
+	ai_message.pending = true
+	ai_message.grab_focus()
+
+
+func _on_llama_context_completion_generated(chunk: Dictionary) -> void:
+	var completion_id = chunk.id
+	for msg: Message in messages_container.get_children():
+		if msg.completion_id != completion_id or msg.sender != "assistant":
+			continue
+		if chunk.has("error"):
+			msg.errored = true
+		elif chunk.has("text"):
+			if msg.pending:
+				msg.pending = false
+				msg.set_text(chunk["text"])
+			else:
+				msg.append_text(chunk["text"])

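For reference, the chunk dictionary consumed above always carries "id", plus either "error", or "text" while streaming with "done" set true on the final chunk (see the LlamaContext changes below). A minimal standalone handler sketch:

	func _on_completion_generated(chunk: Dictionary) -> void:
		if chunk.has("error"):
			push_error("completion %d failed: %s" % [chunk.id, chunk.error])
		elif chunk.get("done", false):
			print("completion %d finished" % chunk.id)
		else:
			print("piece: ", chunk.text)  # streamed token piece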
godot/examples/simple/simple.tscn

+79

@@ -0,0 +1,79 @@
+[gd_scene load_steps=6 format=3 uid="uid://c55kb4qvg6geq"]
+
+[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_gjsev"]
+[ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"]
+[ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"]
+[ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"]
+[ext_resource type="LlamaModel" path="res://models/meta-llama-3-8b-instruct.Q5_K_M.gguf" id="5_qov1l"]
+
+[node name="Node" type="Node"]
+script = ExtResource("1_sruc3")
+
+[node name="Panel" type="Panel" parent="."]
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+
+[node name="MarginContainer" type="MarginContainer" parent="Panel"]
+layout_mode = 1
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+theme_override_constants/margin_left = 10
+theme_override_constants/margin_top = 10
+theme_override_constants/margin_right = 10
+theme_override_constants/margin_bottom = 10
+
+[node name="VBoxContainer" type="VBoxContainer" parent="Panel/MarginContainer"]
+layout_mode = 2
+
+[node name="ScrollContainer" type="ScrollContainer" parent="Panel/MarginContainer/VBoxContainer"]
+layout_mode = 2
+size_flags_vertical = 3
+follow_focus = true
+
+[node name="MessagesContainer" type="VBoxContainer" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer"]
+unique_name_in_owner = true
+layout_mode = 2
+size_flags_horizontal = 3
+size_flags_vertical = 3
+theme_override_constants/separation = 30
+
+[node name="RichTextLabel2" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer" instance=ExtResource("2_7iip7")]
+layout_mode = 2
+include_in_prompt = false
+
+[node name="Text" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer/RichTextLabel2" index="1"]
+text = "How can I help you?"
+
+[node name="HBoxContainer" type="HBoxContainer" parent="Panel/MarginContainer/VBoxContainer"]
+layout_mode = 2
+
+[node name="TextEdit" type="TextEdit" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"]
+custom_minimum_size = Vector2(2.08165e-12, 100)
+layout_mode = 2
+size_flags_horizontal = 3
+placeholder_text = "Ask me anything..."
+wrap_mode = 1
+script = ExtResource("2_7usqw")
+
+[node name="Button" type="Button" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"]
+custom_minimum_size = Vector2(100, 2.08165e-12)
+layout_mode = 2
+icon = ExtResource("1_gjsev")
+expand_icon = true
+
+[node name="LlamaContext" type="LlamaContext" parent="."]
+model = ExtResource("5_qov1l")
+temperature = 0.9
+unique_name_in_owner = true
+
+[connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"]
+[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" method="_on_button_pressed"]
+[connection signal="completion_generated" from="LlamaContext" to="." method="_on_llama_context_completion_generated"]
+
+[editable path="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer/RichTextLabel2"]

godot/icon.svg

+1

godot/icon.svg.import

+37

@@ -0,0 +1,37 @@
+[remap]
+
+importer="texture"
+type="CompressedTexture2D"
+uid="uid://beeg0oqle7bnk"
+path="res://.godot/imported/icon.svg-218a8f2b3041327d8a5756f3a245f83b.ctex"
+metadata={
+"vram_texture": false
+}
+
+[deps]
+
+source_file="res://icon.svg"
+dest_files=["res://.godot/imported/icon.svg-218a8f2b3041327d8a5756f3a245f83b.ctex"]
+
+[params]
+
+compress/mode=0
+compress/high_quality=false
+compress/lossy_quality=0.7
+compress/hdr_compression=1
+compress/normal_map=0
+compress/channel_pack=0
+mipmaps/generate=false
+mipmaps/limit=-1
+roughness/mode=0
+roughness/src_normal=""
+process/fix_alpha_border=true
+process/premult_alpha=false
+process/normal_map_invert_y=false
+process/hdr_as_srgb=false
+process/hdr_clamp_exposure=false
+process/size_limit=0
+detect_3d/compress_to=1
+svg/scale=1.0
+editor/scale_with_editor_scale=false
+editor/convert_colors_with_editor_theme=false

godot/main.gd

-27
This file was deleted.

godot/main.tscn

-103
This file was deleted.

godot/project.godot

+4 -25

@@ -11,35 +11,14 @@ config_version=5
 [application]
 
 config/name="godot-llama-cpp"
-run/main_scene="res://main.tscn"
+run/main_scene="res://examples/simple/simple.tscn"
 config/features=PackedStringArray("4.2", "Forward Plus")
-config/icon="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg"
-
-[autoload]
-
-__LlamaBackend="*res://addons/godot-llama-cpp/autoloads/llama-backend.gd"
-Llama="*res://autoloads/llama.tscn"
-
-[display]
-
-window/size/viewport_width=1280
-window/size/viewport_height=720
+config/icon="res://icon.svg"
 
 [editor_plugins]
 
 enabled=PackedStringArray("res://addons/godot-llama-cpp/plugin.cfg")
 
-[input]
-
-submit_form={
-"deadzone": 0.5,
-"events": [Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":4194309,"key_label":0,"unicode":0,"echo":false,"script":null)
-]
-}
-
-[rendering]
+[gui]
 
-anti_aliasing/quality/msaa_2d=3
-anti_aliasing/quality/msaa_3d=3
-anti_aliasing/quality/screen_space_aa=1
-anti_aliasing/quality/use_taa=true
+theme/default_theme_scale=2.0

godot_cpp

Submodule godot_cpp updated 62 files

llama.cpp

src/llama_backend.cpp

-19
This file was deleted.

src/llama_backend.h

-19
This file was deleted.

src/llama_context.cpp

+215 -96

@@ -2,10 +2,12 @@
 #include "common.h"
 #include "llama.h"
 #include "llama_model.h"
+#include <algorithm>
 #include <godot_cpp/classes/engine.hpp>
 #include <godot_cpp/classes/os.hpp>
 #include <godot_cpp/classes/worker_thread_pool.hpp>
 #include <godot_cpp/core/class_db.hpp>
+#include <godot_cpp/variant/dictionary.hpp>
 #include <godot_cpp/variant/utility_functions.hpp>
 
 using namespace godot;
@@ -15,31 +17,41 @@ void LlamaContext::_bind_methods() {
 	ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
 	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");
 
-	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
-	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
+	ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
+	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
 
-	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
-	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+	ClassDB::bind_method(D_METHOD("get_temperature"), &LlamaContext::get_temperature);
+	ClassDB::bind_method(D_METHOD("set_temperature", "temperature"), &LlamaContext::set_temperature);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "temperature"), "set_temperature", "get_temperature");
 
-	ClassDB::bind_method(D_METHOD("get_n_threads"), &LlamaContext::get_n_threads);
-	ClassDB::bind_method(D_METHOD("set_n_threads", "n_threads"), &LlamaContext::set_n_threads);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads"), "set_n_threads", "get_n_threads");
+	ClassDB::bind_method(D_METHOD("get_top_p"), &LlamaContext::get_top_p);
+	ClassDB::bind_method(D_METHOD("set_top_p", "top_p"), &LlamaContext::set_top_p);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "top_p"), "set_top_p", "get_top_p");
 
-	ClassDB::bind_method(D_METHOD("get_n_threads_batch"), &LlamaContext::get_n_threads_batch);
-	ClassDB::bind_method(D_METHOD("set_n_threads_batch", "n_threads_batch"), &LlamaContext::set_n_threads_batch);
-	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads_batch"), "set_n_threads_batch", "get_n_threads_batch");
+	ClassDB::bind_method(D_METHOD("get_frequency_penalty"), &LlamaContext::get_frequency_penalty);
+	ClassDB::bind_method(D_METHOD("set_frequency_penalty", "frequency_penalty"), &LlamaContext::set_frequency_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "frequency_penalty"), "set_frequency_penalty", "get_frequency_penalty");
+
+	ClassDB::bind_method(D_METHOD("get_presence_penalty"), &LlamaContext::get_presence_penalty);
+	ClassDB::bind_method(D_METHOD("set_presence_penalty", "presence_penalty"), &LlamaContext::set_presence_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "presence_penalty"), "set_presence_penalty", "get_presence_penalty");
+
+	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
+	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+
+	ClassDB::bind_method(D_METHOD("get_n_len"), &LlamaContext::get_n_len);
+	ClassDB::bind_method(D_METHOD("set_n_len", "n_len"), &LlamaContext::set_n_len);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_len"), "set_n_len", "get_n_len");
 
 	ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
-	ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
+	ClassDB::bind_method(D_METHOD("__thread_loop"), &LlamaContext::__thread_loop);
 
-	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
+	ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::DICTIONARY, "chunk")));
 }
 
 LlamaContext::LlamaContext() {
-	batch = llama_batch_init(4096, 0, 1);
-
 	ctx_params = llama_context_default_params();
 	ctx_params.seed = -1;
 	ctx_params.n_ctx = 4096;
@@ -60,109 +72,186 @@ void LlamaContext::_ready() {
 		return;
 	}
 
+	mutex.instantiate();
+	semaphore.instantiate();
+	thread.instantiate();
+
+	llama_backend_init();
+	llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISABLED);
+
 	ctx = llama_new_context_with_model(model->model, ctx_params);
 	if (ctx == NULL) {
 		UtilityFunctions::printerr(vformat("%s: Failed to initialize llama context, null ctx", __func__));
 		return;
 	}
+
+	sampling_ctx = llama_sampling_init(sampling_params);
+
 	UtilityFunctions::print(vformat("%s: Context initialized", __func__));
-}
 
-PackedStringArray LlamaContext::_get_configuration_warnings() const {
-	PackedStringArray warnings;
-	if (model == NULL) {
-		warnings.push_back("Model resource property not defined");
-	}
-	return warnings;
+	thread->start(callable_mp(this, &LlamaContext::__thread_loop));
 }
 
-Variant LlamaContext::request_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
-	}
-	task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
-	return OK;
-}
+void LlamaContext::__thread_loop() {
+	while (true) {
+		semaphore->wait();
 
-void LlamaContext::_fulfill_completion(const String &prompt) {
-	UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
-	std::vector<llama_token> tokens_list;
-	tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
+		mutex->lock();
+		if (exit_thread) {
+			mutex->unlock();
+			break;
+		}
+		if (completion_requests.size() == 0) {
+			mutex->unlock();
+			continue;
+		}
+		completion_request req = completion_requests.get(0);
+		completion_requests.remove_at(0);
+		mutex->unlock();
 
-	const int n_len = 128;
-	const int n_ctx = llama_n_ctx(ctx);
-	const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
-	if (n_kv_req > n_ctx) {
-		UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
-		return;
-	}
+		UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id));
 
-	for (size_t i = 0; i < tokens_list.size(); i++) {
-		llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
-	}
+		std::vector<llama_token> request_tokens;
+		request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true, true);
 
-	batch.logits[batch.n_tokens - 1] = true;
+		size_t shared_prefix_idx = 0;
+		auto diff = std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end());
+		if (diff.first != context_tokens.end()) {
+			shared_prefix_idx = std::distance(context_tokens.begin(), diff.first);
+		} else {
+			shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
+		}
 
-	llama_kv_cache_clear(ctx);
+		bool rm_success = llama_kv_cache_seq_rm(ctx, -1, shared_prefix_idx, -1);
+		if (!rm_success) {
+			UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "Failed to remove tokens from kv cache";
+			call_thread_safe("emit_signal", "completion_generated", response);
+			continue;
+		}
+		context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
+		request_tokens.erase(request_tokens.begin(), request_tokens.begin() + shared_prefix_idx);
 
-	int decode_res = llama_decode(ctx, batch);
-	if (decode_res != 0) {
-		UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
-		return;
-	}
+		uint batch_size = std::min(ctx_params.n_batch, (uint)request_tokens.size());
+
+		llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+		// chunk request_tokens into sequences of size batch_size
+		std::vector<std::vector<llama_token>> sequences;
+		for (size_t i = 0; i < request_tokens.size(); i += batch_size) {
+			sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
+		}
+
+		printf("Request tokens: \n");
+		for (auto sequence : sequences) {
+			for (auto token : sequence) {
+				printf("%s", llama_token_to_piece(ctx, token).c_str());
+			}
+		}
+		printf("\n");
 
-	int n_cur = batch.n_tokens;
-	int n_decode = 0;
-	llama_model *llama_model = model->model;
+		int curr_token_pos = context_tokens.size();
+		bool decode_failed = false;
 
-	while (n_cur <= n_len) {
-		// sample the next token
-		{
-			auto n_vocab = llama_n_vocab(llama_model);
-			auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+		for (size_t i = 0; i < sequences.size(); i++) {
+			llama_batch_clear(batch);
 
-			std::vector<llama_token_data> candidates;
-			candidates.reserve(n_vocab);
+			std::vector<llama_token> sequence = sequences[i];
 
-			for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-				candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+			for (size_t j = 0; j < sequence.size(); j++) {
+				llama_batch_add(batch, sequence[j], j + curr_token_pos, { 0 }, false);
+				curr_token_pos++;
 			}
 
-			llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+			if (i == sequences.size() - 1) {
+				batch.logits[batch.n_tokens - 1] = true;
+			}
 
-			// sample the most likely token
-			const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+			if (llama_decode(ctx, batch) != 0) {
+				decode_failed = true;
+				break;
+			}
+		}
 
-			// is it an end of stream?
-			if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
-				call_thread_safe("emit_signal", "completion_generated", "\n", true);
+		if (decode_failed) {
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "llama_decode() failed";
+			call_thread_safe("emit_signal", "completion_generated", response);
+			continue;
+		}
+
+		context_tokens.insert(context_tokens.end(), request_tokens.begin(), request_tokens.end());
+
+		while (true) {
+			if (exit_thread) {
+				return;
+			}
+			llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1);
+			llama_sampling_accept(sampling_ctx, ctx, new_token_id, false);
 
+			Dictionary response;
+			response["id"] = req.id;
+
+			context_tokens.push_back(new_token_id);
+
+			if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
+				response["done"] = true;
+				call_thread_safe("emit_signal", "completion_generated", response);
 				break;
 			}
 
-			call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
+			response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
+			response["done"] = false;
+			call_thread_safe("emit_signal", "completion_generated", response);
 
-			// prepare the next batch
 			llama_batch_clear(batch);
 
-			// push this new token for next evaluation
-			llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+			llama_batch_add(batch, new_token_id, curr_token_pos, { 0 }, true);
 
-			n_decode += 1;
-		}
+			curr_token_pos++;
 
-		n_cur += 1;
+			if (llama_decode(ctx, batch) != 0) {
+				decode_failed = true;
+				break;
+			}
+		}
 
-		// evaluate the current batch with the transformer model
-		int decode_res = llama_decode(ctx, batch);
-		if (decode_res != 0) {
-			UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
-			break;
+		if (decode_failed) {
+			Dictionary response;
+			response["id"] = req.id;
+			response["error"] = "llama_decode() failed";
+			call_thread_safe("emit_signal", "completion_generated", response);
+			continue;
 		}
 	}
 }
 
+PackedStringArray LlamaContext::_get_configuration_warnings() const {
+	PackedStringArray warnings;
+	if (model == NULL) {
+		warnings.push_back("Model resource property not defined");
+	}
+	return warnings;
+}
+
+int LlamaContext::request_completion(const String &prompt) {
+	int id = request_id++;
+
+	UtilityFunctions::print(vformat("%s: Requesting completion for prompt id: %d", __func__, id));
+
+	mutex->lock();
+	completion_request req = { id, prompt };
+	completion_requests.append(req);
+	mutex->unlock();
+
+	semaphore->post();
+
+	return id;
+}
+
 void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
 	model = p_model;
 }
@@ -184,28 +273,58 @@ void LlamaContext::set_n_ctx(int n_ctx) {
 	ctx_params.n_ctx = n_ctx;
 }
 
-int LlamaContext::get_n_threads() {
-	return ctx_params.n_threads;
+int LlamaContext::get_n_len() {
+	return n_len;
 }
-void LlamaContext::set_n_threads(int n_threads) {
-	ctx_params.n_threads = n_threads;
+void LlamaContext::set_n_len(int n_len) {
+	this->n_len = n_len;
 }
 
-int LlamaContext::get_n_threads_batch() {
-	return ctx_params.n_threads_batch;
+float LlamaContext::get_temperature() {
+	return sampling_params.temp;
 }
-void LlamaContext::set_n_threads_batch(int n_threads_batch) {
-	ctx_params.n_threads_batch = n_threads_batch;
+void LlamaContext::set_temperature(float temperature) {
+	sampling_params.temp = temperature;
 }
 
-LlamaContext::~LlamaContext() {
-	if (ctx) {
-		llama_free(ctx);
+float LlamaContext::get_top_p() {
+	return sampling_params.top_p;
+}
+void LlamaContext::set_top_p(float top_p) {
+	sampling_params.top_p = top_p;
+}
+
+float LlamaContext::get_frequency_penalty() {
+	return sampling_params.penalty_freq;
+}
+void LlamaContext::set_frequency_penalty(float frequency_penalty) {
+	sampling_params.penalty_freq = frequency_penalty;
+}
+
+float LlamaContext::get_presence_penalty() {
+	return sampling_params.penalty_present;
+}
+void LlamaContext::set_presence_penalty(float presence_penalty) {
+	sampling_params.penalty_present = presence_penalty;
+}
+
+void LlamaContext::_exit_tree() {
+	if (Engine::get_singleton()->is_editor_hint()) {
+		return;
 	}
 
-	llama_batch_free(batch);
+	mutex->lock();
+	exit_thread = true;
+	mutex->unlock();
+
+	semaphore->post();
 
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
+	thread->wait_to_finish();
+
+	if (ctx) {
+		llama_free(ctx);
 	}
+
+	llama_sampling_free(sampling_ctx);
+	llama_backend_free();
 }

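Since temperature, top_p, frequency_penalty, presence_penalty and n_len are now exposed as properties, they can be tuned from script as well as in the inspector. A small GDScript sketch (node name and values illustrative):

	@onready var llama_context = %LlamaContext

	func _ready() -> void:
		llama_context.temperature = 0.9
		llama_context.top_p = 0.95
		llama_context.frequency_penalty = 0.2
		llama_context.presence_penalty = 0.0
		llama_context.n_len = 1024  # max tokens generated per completion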
src/llama_context.h

+41 -16

@@ -2,19 +2,38 @@
 #define LLAMA_CONTEXT_H
 
 #include "llama.h"
+#include "common.h"
 #include "llama_model.h"
+#include <godot_cpp/classes/mutex.hpp>
 #include <godot_cpp/classes/node.hpp>
-
+#include <godot_cpp/classes/semaphore.hpp>
+#include <godot_cpp/classes/thread.hpp>
+#include <godot_cpp/templates/vector.hpp>
 namespace godot {
+
+struct completion_request {
+	int id;
+	String prompt;
+};
+
 class LlamaContext : public Node {
 	GDCLASS(LlamaContext, Node)
 
 private:
 	Ref<LlamaModel> model;
 	llama_context *ctx = nullptr;
+	llama_sampling_context *sampling_ctx = nullptr;
 	llama_context_params ctx_params;
-	llama_batch batch;
-	int task_id;
+	llama_sampling_params sampling_params;
+	int n_len = 1024;
+	int request_id = 0;
+	Vector<completion_request> completion_requests;
+
+	Ref<Thread> thread;
+	Ref<Semaphore> semaphore;
+	Ref<Mutex> mutex;
+	std::vector<llama_token> context_tokens;
+	bool exit_thread = false;
 
 protected:
 	static void _bind_methods();
@@ -23,22 +42,28 @@ class LlamaContext : public Node {
 	void set_model(const Ref<LlamaModel> model);
 	Ref<LlamaModel> get_model();
 
-	Variant request_completion(const String &prompt);
-	void _fulfill_completion(const String &prompt);
+	int request_completion(const String &prompt);
+	void __thread_loop();
 
-	int get_seed();
-	void set_seed(int seed);
-	int get_n_ctx();
-	void set_n_ctx(int n_ctx);
-	int get_n_threads();
-	void set_n_threads(int n_threads);
-	int get_n_threads_batch();
-	void set_n_threads_batch(int n_threads_batch);
+	int get_seed();
+	void set_seed(int seed);
+	int get_n_ctx();
+	void set_n_ctx(int n_ctx);
+	int get_n_len();
+	void set_n_len(int n_len);
+	float get_temperature();
+	void set_temperature(float temperature);
+	float get_top_p();
+	void set_top_p(float top_p);
+	float get_frequency_penalty();
+	void set_frequency_penalty(float frequency_penalty);
+	float get_presence_penalty();
+	void set_presence_penalty(float presence_penalty);
 
-	virtual PackedStringArray _get_configuration_warnings() const override;
+	virtual PackedStringArray _get_configuration_warnings() const override;
 	virtual void _ready() override;
-	LlamaContext();
-	~LlamaContext();
+	virtual void _exit_tree() override;
+	LlamaContext();
 };
 } //namespace godot

src/llama_model.cpp

+6 -3

@@ -1,5 +1,6 @@
 #include "llama_model.h"
 #include "llama.h"
+#include <godot_cpp/classes/project_settings.hpp>
 #include <godot_cpp/core/class_db.hpp>
 #include <godot_cpp/variant/utility_functions.hpp>
 
@@ -22,14 +23,16 @@ void LlamaModel::load_model(const String &path) {
 		llama_free_model(model);
 	}
 
-	model = llama_load_model_from_file(path.utf8().get_data(), model_params);
+	String absPath = ProjectSettings::get_singleton()->globalize_path(path);
+
+	model = llama_load_model_from_file(absPath.utf8().get_data(), model_params);
 
 	if (model == NULL) {
-		UtilityFunctions::printerr(vformat("%s: Unable to load model from %s", __func__, path));
+		UtilityFunctions::printerr(vformat("%s: Unable to load model from %s", __func__, absPath));
 		return;
 	}
 
-	UtilityFunctions::print(vformat("%s: Model loaded from %s", __func__, path));
+	UtilityFunctions::print(vformat("%s: Model loaded from %s", __func__, absPath));
 }
 
 int LlamaModel::get_n_gpu_layers() {

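Because path globalization now happens inside LlamaModel::load_model, scripts can load a .gguf like any other resource and assign it before the context is ready. A sketch, with a hypothetical model path:

	var model = load("res://models/meta-llama-3-8b-instruct.Q5_K_M.gguf")  # returns a LlamaModel resource
	%LlamaContext.model = model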
src/llama_model_loader.cpp

+1 -4

@@ -2,7 +2,6 @@
 #include "llama_model.h"
 #include <godot_cpp/core/class_db.hpp>
 #include <godot_cpp/classes/file_access.hpp>
-#include <godot_cpp/classes/project_settings.hpp>
 #include <godot_cpp/classes/engine.hpp>
 
 using namespace godot;
@@ -24,9 +23,7 @@ Variant godot::LlamaModelLoader::_load(const String &path, const String &origina
 		return { model };
 	}
 
-	String absPath = ProjectSettings::get_singleton()->globalize_path(path);
-
-	model->load_model(absPath);
+	model->load_model(path);
 
 	return { model };
 }

src/register_types.cpp

-2

@@ -7,7 +7,6 @@
 #include "llama_model.h"
 #include "llama_model_loader.h"
 #include "llama_context.h"
-#include "llama_backend.h"
 
 using namespace godot;
 
@@ -24,7 +23,6 @@ void initialize_types(ModuleInitializationLevel p_level)
 
 	ClassDB::register_class<LlamaModel>();
 	ClassDB::register_class<LlamaContext>();
-	ClassDB::register_class<LlamaBackend>();
 }
 
 void uninitialize_types(ModuleInitializationLevel p_level) {

tools/expand_metal.zig

+79

@@ -0,0 +1,79 @@
+const std = @import("std");
+
+const usage =
+    \\Usage: ./embed_metal [options]
+    \\
+    \\Options:
+    \\  --metal-file ggml-metal.metal
+    \\  --common-file ggml-common.h
+    \\  --output-file ggml-metal-embed.metal
+    \\
+;
+
+pub fn main() !void {
+    var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
+    defer arena_state.deinit();
+    const arena = arena_state.allocator();
+
+    const args = try std.process.argsAlloc(arena);
+
+    var opt_metal_file_path: ?[]const u8 = null;
+    var opt_common_file_path: ?[]const u8 = null;
+    var opt_output_file_path: ?[]const u8 = null;
+
+    {
+        var i: usize = 1;
+        while (i < args.len) : (i += 1) {
+            const arg = args[i];
+            if (std.mem.eql(u8, "-h", arg) or std.mem.eql(u8, "--help", arg)) {
+                try std.io.getStdOut().writeAll(usage);
+                return std.process.cleanExit();
+            } else if (std.mem.eql(u8, "--metal-file", arg)) {
+                i += 1;
+                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+                if (opt_metal_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+                opt_metal_file_path = args[i];
+            } else if (std.mem.eql(u8, "--common-file", arg)) {
+                i += 1;
+                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+                if (opt_common_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+                opt_common_file_path = args[i];
+            } else if (std.mem.eql(u8, "--output-file", arg)) {
+                i += 1;
+                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+                if (opt_output_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+                opt_output_file_path = args[i];
+            } else {
+                std.debug.panic("unrecognized arg: '{s}'", .{arg});
+            }
+        }
+    }
+
+    const metal_file_path = opt_metal_file_path orelse std.debug.panic("missing --metal-file", .{});
+    const common_file_path = opt_common_file_path orelse std.debug.panic("missing --common-file", .{});
+    const output_file_path = opt_output_file_path orelse std.debug.panic("missing --output-file", .{});
+
+    const cwd = std.fs.cwd();
+
+    var metal_file = try cwd.openFile(metal_file_path, .{});
+    defer metal_file.close();
+
+    var common_file = try cwd.openFile(common_file_path, .{});
+    defer common_file.close();
+
+    const metal_size = (try metal_file.stat()).size;
+    const metal_contents = try arena.alloc(u8, metal_size);
+    defer arena.free(metal_contents);
+    _ = try metal_file.readAll(metal_contents);
+
+    const common_size = (try common_file.stat()).size;
+    const common_contents = try arena.alloc(u8, common_size);
+    defer arena.free(common_contents);
+    _ = try common_file.readAll(common_contents);
+
+    const output = try std.mem.replaceOwned(u8, arena, metal_contents, "#include \"ggml-common.h\"", common_contents);
+    defer arena.free(output);
+
+    const output_file = try cwd.createFile(output_file_path, .{});
+    try output_file.writeAll(output);
+}

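For reference, the tool is driven entirely by the flags in its usage string; a typical build-step invocation might look like (paths illustrative):

	./embed_metal --metal-file ggml-metal.metal --common-file ggml-common.h --output-file ggml-metal-embed.metal

The output is the Metal source with its #include "ggml-common.h" line replaced by the contents of that header, so the shader can be embedded as a single self-contained file.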