
Commit 17f069c

expose some sampling params
1 parent ad808a1 commit 17f069c

7 files changed: +98 −18

godot/addons/godot-llama-cpp/chat/chat_formatter.gd (+19 −2)
@@ -6,12 +6,14 @@ static func apply(format: String, messages: Array) -> String:
 			return format_llama3(messages)
 		"phi3":
 			return format_phi3(messages)
+		"mistral":
+			return format_mistral(messages)
 		_:
 			printerr("Unknown chat format: ", format)
 			return ""
 
 static func format_llama3(messages: Array) -> String:
-	var res = "<|begin_of_text|>"
+	var res = ""
 
 	for i in range(messages.size()):
 		match messages[i]:
@@ -27,7 +29,7 @@ static func format_llama3(messages: Array) -> String:
 	return res
 
 static func format_phi3(messages: Array) -> String:
-	var res = "<s>"
+	var res = ""
 
 	for i in range(messages.size()):
 		match messages[i]:
@@ -37,3 +39,18 @@ static func format_phi3(messages: Array) -> String:
 			printerr("Invalid message at index ", i)
 	res += "<|assistant|>\n"
 	return res
+
+static func format_mistral(messages: Array) -> String:
+	var res = ""
+
+	for i in range(messages.size()):
+		match messages[i]:
+			{"text": var text, "sender": var sender}:
+				if sender == "user":
+					res += "[INST] %s [/INST]" % text
+				else:
+					res += "%s</s>" % text
+			_:
+				printerr("Invalid message at index ", i)
+
+	return res
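
For reference, a minimal usage sketch (not part of the commit) of the new "mistral" branch, assuming ChatFormatter is called the same way as for the existing formats; user turns are wrapped in [INST] … [/INST] and assistant turns are closed with </s>:

var messages = [
	{ "sender": "user", "text": "Hello" },
	{ "sender": "assistant", "text": "Hi there!" },
	{ "sender": "user", "text": "What is Godot?" },
]
var prompt = ChatFormatter.apply("mistral", messages)
# Expected: "[INST] Hello [/INST]Hi there!</s>[INST] What is Godot? [/INST]"
print(prompt)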

godot/addons/godot-llama-cpp/plugin.gdextension (+1 −1)
@@ -5,7 +5,7 @@ compatibility_minimum = "4.2"
 
 [libraries]
 
-macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-Debug.dylib"
+macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
 macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
 windows.debug.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_debug.x86_32.dll"
 windows.release.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_release.x86_32.dll"

godot/examples/simple/simple.gd (+5 −2)
@@ -9,13 +9,16 @@ func _on_text_edit_submit(input: String) -> void:
 	handle_input(input)
 
 func handle_input(input: String) -> void:
-	var messages = [{ "sender": "system", "text": "You are a helpful assistant" }]
+	#var messages = [{ "sender": "system", "text": "You are a pirate chatbot who always responds in pirate speak!" }]
+
+	#var messages = [{ "sender": "system", "text": "You are a helpful chatbot assistant!" }]
+	var messages = []
 	messages.append_array(messages_container.get_children().filter(func(msg: Message): return msg.include_in_prompt).map(
 		func(msg: Message) -> Dictionary:
 			return { "text": msg.text, "sender": msg.sender }
 	))
 	messages.append({"text": input, "sender": "user"})
-	var prompt = ChatFormatter.apply("phi3", messages)
+	var prompt = ChatFormatter.apply("llama3", messages)
 	print("prompt: ", prompt)
 
 	var completion_id = llama_context.request_completion(prompt)
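
With the hardcoded system message dropped, the prompt is now assembled only from the Message nodes already in the scene plus the new user input, and formatted for Llama 3 instead of Phi-3. Restoring a system prompt is a one-line change, as the commented-out examples above suggest; a sketch:

# Re-enable a system prompt by seeding the messages array:
var messages = [{ "sender": "system", "text": "You are a helpful chatbot assistant!" }]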

godot/examples/simple/simple.tscn (+3 −2)
@@ -4,7 +4,7 @@
 [ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"]
 [ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"]
 [ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"]
-[ext_resource type="LlamaModel" path="res://models/Phi-3-mini-128k-instruct.Q5_K_M.gguf" id="5_qpeda"]
+[ext_resource type="LlamaModel" path="res://models/meta-llama-3-8b-instruct.Q5_K_M.gguf" id="5_qov1l"]
 
 [node name="Node" type="Node"]
 script = ExtResource("1_sruc3")
@@ -68,7 +68,8 @@ icon = ExtResource("1_gjsev")
 expand_icon = true
 
 [node name="LlamaContext" type="LlamaContext" parent="."]
-model = ExtResource("5_qpeda")
+model = ExtResource("5_qov1l")
+temperature = 0.9
 unique_name_in_owner = true
 
 [connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"]
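
The new temperature = 0.9 line works because this commit registers temperature as a script-visible property on LlamaContext (see src/llama_context.cpp below). The same value could equally be set from GDScript; a sketch assuming the scene's unique-named node (unique_name_in_owner is enabled above):

@onready var llama_context: LlamaContext = %LlamaContext

func _ready() -> void:
	llama_context.temperature = 0.9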

llama.cpp (submodule updated)

src/llama_context.cpp (+61 −10)
@@ -21,6 +21,22 @@ void LlamaContext::_bind_methods() {
 	ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
 	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
 
+	ClassDB::bind_method(D_METHOD("get_temperature"), &LlamaContext::get_temperature);
+	ClassDB::bind_method(D_METHOD("set_temperature", "temperature"), &LlamaContext::set_temperature);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "temperature"), "set_temperature", "get_temperature");
+
+	ClassDB::bind_method(D_METHOD("get_top_p"), &LlamaContext::get_top_p);
+	ClassDB::bind_method(D_METHOD("set_top_p", "top_p"), &LlamaContext::set_top_p);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "top_p"), "set_top_p", "get_top_p");
+
+	ClassDB::bind_method(D_METHOD("get_frequency_penalty"), &LlamaContext::get_frequency_penalty);
+	ClassDB::bind_method(D_METHOD("set_frequency_penalty", "frequency_penalty"), &LlamaContext::set_frequency_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "frequency_penalty"), "set_frequency_penalty", "get_frequency_penalty");
+
+	ClassDB::bind_method(D_METHOD("get_presence_penalty"), &LlamaContext::get_presence_penalty);
+	ClassDB::bind_method(D_METHOD("set_presence_penalty", "presence_penalty"), &LlamaContext::set_presence_penalty);
+	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "presence_penalty"), "set_presence_penalty", "get_presence_penalty");
+
 	ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
 	ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
 	ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
@@ -106,13 +122,13 @@ void LlamaContext::__thread_loop() {
 			shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
 		}
 
-		bool rm_success = llama_kv_cache_seq_rm(ctx, 0, shared_prefix_idx, -1);
+		bool rm_success = llama_kv_cache_seq_rm(ctx, -1, shared_prefix_idx, -1);
 		if (!rm_success) {
 			UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
 			Dictionary response;
 			response["id"] = req.id;
 			response["error"] = "Failed to remove tokens from kv cache";
-			call_deferred("emit_signal", "completion_generated", response);
+			call_thread_safe("emit_signal", "completion_generated", response);
 			continue;
 		}
 		context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
@@ -128,6 +144,14 @@ void LlamaContext::__thread_loop() {
 			sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
 		}
 
+		printf("Request tokens: \n");
+		for (auto sequence : sequences) {
+			for (auto token : sequence) {
+				printf("%s", llama_token_to_piece(ctx, token).c_str());
+			}
+		}
+		printf("\n");
+
 		int curr_token_pos = context_tokens.size();
 		bool decode_failed = false;
 
@@ -155,7 +179,7 @@ void LlamaContext::__thread_loop() {
 			Dictionary response;
 			response["id"] = req.id;
 			response["error"] = "llama_decode() failed";
-			call_deferred("emit_signal", "completion_generated", response);
+			call_thread_safe("emit_signal", "completion_generated", response);
 			continue;
 		}
 
@@ -171,17 +195,17 @@ void LlamaContext::__thread_loop() {
 			Dictionary response;
 			response["id"] = req.id;
 
+			context_tokens.push_back(new_token_id);
+
 			if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
 				response["done"] = true;
-				call_deferred("emit_signal", "completion_generated", response);
+				call_thread_safe("emit_signal", "completion_generated", response);
 				break;
 			}
 
-			context_tokens.push_back(new_token_id);
-
 			response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
 			response["done"] = false;
-			call_deferred("emit_signal", "completion_generated", response);
+			call_thread_safe("emit_signal", "completion_generated", response);
 
 			llama_batch_clear(batch);
 
@@ -199,11 +223,9 @@ void LlamaContext::__thread_loop() {
 			Dictionary response;
 			response["id"] = req.id;
 			response["error"] = "llama_decode() failed";
-			call_deferred("emit_signal", "completion_generated", response);
+			call_thread_safe("emit_signal", "completion_generated", response);
 			continue;
 		}
-
-		llama_sampling_reset(sampling_ctx);
 	}
 }
 
@@ -258,6 +280,34 @@ void LlamaContext::set_n_len(int n_len) {
 	this->n_len = n_len;
 }
 
+float LlamaContext::get_temperature() {
+	return sampling_params.temp;
+}
+void LlamaContext::set_temperature(float temperature) {
+	sampling_params.temp = temperature;
+}
+
+float LlamaContext::get_top_p() {
+	return sampling_params.top_p;
+}
+void LlamaContext::set_top_p(float top_p) {
+	sampling_params.top_p = top_p;
+}
+
+float LlamaContext::get_frequency_penalty() {
+	return sampling_params.penalty_freq;
+}
+void LlamaContext::set_frequency_penalty(float frequency_penalty) {
+	sampling_params.penalty_freq = frequency_penalty;
+}
+
+float LlamaContext::get_presence_penalty() {
+	return sampling_params.penalty_present;
+}
+void LlamaContext::set_presence_penalty(float presence_penalty) {
+	sampling_params.penalty_present = presence_penalty;
+}
+
 void LlamaContext::_exit_tree() {
 	if (Engine::get_singleton()->is_editor_hint()) {
 		return;
@@ -275,5 +325,6 @@ void LlamaContext::_exit_tree() {
 		llama_free(ctx);
 	}
 
+	llama_sampling_free(sampling_ctx);
 	llama_backend_free();
 }
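
Once bound in _bind_methods, the four sampling parameters act like ordinary Godot properties, so they can be tuned per LlamaContext from GDScript. A hedged sketch (llama_context is assumed to reference the node, e.g. via %LlamaContext; the "id", "text", "done", and "error" keys mirror the response dictionaries built above):

func _ready() -> void:
	llama_context.temperature = 0.9
	llama_context.top_p = 0.95
	llama_context.frequency_penalty = 0.2
	llama_context.presence_penalty = 0.1
	llama_context.completion_generated.connect(_on_completion_generated)

func _on_completion_generated(response: Dictionary) -> void:
	if response.has("error"):
		printerr(response["error"])
	elif not response["done"]:
		print(response["text"]) # stream pieces until done == true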

src/llama_context.h (+8 −0)
@@ -51,6 +51,14 @@ class LlamaContext : public Node {
 	void set_n_ctx(int n_ctx);
 	int get_n_len();
 	void set_n_len(int n_len);
+	float get_temperature();
+	void set_temperature(float temperature);
+	float get_top_p();
+	void set_top_p(float top_p);
+	float get_frequency_penalty();
+	void set_frequency_penalty(float frequency_penalty);
+	float get_presence_penalty();
+	void set_presence_penalty(float presence_penalty);
 
 	virtual PackedStringArray _get_configuration_warnings() const override;
 	virtual void _ready() override;
