Skip to content

Commit ad808a1

Browse files
committed
add chat formatters
1 parent f73ef8c commit ad808a1

File tree

6 files changed

+70
-13
lines changed

6 files changed

+70
-13
lines changed
Original file line numberDiff line numberDiff line change
## Formats an Array of chat messages into a model-specific prompt string.
## Each message is a Dictionary with at least "text" and "sender" keys.
class_name ChatFormatter

## Dispatches to the formatter for `format` ("llama3" or "phi3").
## Returns the assembled prompt, or "" (with an error printed) for an
## unknown format name.
static func apply(format: String, messages: Array) -> String:
	match format:
		"llama3":
			return format_llama3(messages)
		"phi3":
			return format_phi3(messages)
		_:
			printerr("Unknown chat format: ", format)
			return ""

## Builds a prompt using the Llama 3 instruct template and appends the
## assistant header so the model continues as the assistant.
static func format_llama3(messages: Array) -> String:
	var res = "<|begin_of_text|>"

	for i in range(messages.size()):
		match messages[i]:
			# `..` makes the dictionary pattern open-ended: without it the
			# match requires *exactly* the listed keys, so a message carrying
			# any extra key would be rejected as invalid.
			{"text": var text, "sender": var sender, ..}:
				res += """<|start_header_id|>%s<|end_header_id|>

%s<|eot_id|>
""" % [sender, text]
			_:
				printerr("Invalid message at index ", i)

	res += "<|start_header_id|>assistant<|end_header_id|>\n\n"
	return res

## Builds a prompt using the Phi-3 instruct template and appends the
## assistant tag so the model continues as the assistant.
static func format_phi3(messages: Array) -> String:
	var res = "<s>"

	for i in range(messages.size()):
		match messages[i]:
			# Open-ended pattern: tolerate messages with extra keys.
			{"text": var text, "sender": var sender, ..}:
				res += "<|%s|>\n%s<|end|>\n" % [sender, text]
			_:
				printerr("Invalid message at index ", i)
	res += "<|assistant|>\n"
	return res

godot/examples/simple/message.gd

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ extends Node
44
@onready var text_container = %Text
55
@onready var icon = %Panel
66
@export_enum("user", "assistant") var sender: String
7+
@export var include_in_prompt: bool = true
8+
var text:
9+
get:
10+
return text_container.text
11+
set(value):
12+
text_container.text = value
713

814
var completion_id: int = -1
915
var pending: bool = false

godot/examples/simple/simple.gd

+18-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@ func _on_text_edit_submit(input: String) -> void:
99
handle_input(input)
1010

1111
func handle_input(input: String) -> void:
12-
var completion_id = llama_context.request_completion(input)
12+
var messages = [{ "sender": "system", "text": "You are a helpful assistant" }]
13+
messages.append_array(messages_container.get_children().filter(func(msg: Message): return msg.include_in_prompt).map(
14+
func(msg: Message) -> Dictionary:
15+
return { "text": msg.text, "sender": msg.sender }
16+
))
17+
messages.append({"text": input, "sender": "user"})
18+
var prompt = ChatFormatter.apply("phi3", messages)
19+
print("prompt: ", prompt)
20+
21+
var completion_id = llama_context.request_completion(prompt)
1322

1423
var user_message: Message = message.instantiate()
1524
messages_container.add_child(user_message)
@@ -22,19 +31,20 @@ func handle_input(input: String) -> void:
2231
ai_message.sender = "assistant"
2332
ai_message.completion_id = completion_id
2433
ai_message.pending = true
34+
ai_message.grab_focus()
2535

2636

2737

2838
func _on_llama_context_completion_generated(chunk: Dictionary) -> void:
2939
var completion_id = chunk.id
30-
for message: Message in messages_container.get_children():
31-
if message.completion_id != completion_id or message.sender != "assistant":
40+
for msg: Message in messages_container.get_children():
41+
if msg.completion_id != completion_id or msg.sender != "assistant":
3242
continue
3343
if chunk.has("error"):
34-
message.errored = true
44+
msg.errored = true
3545
elif chunk.has("text"):
36-
if message.pending:
37-
message.pending = false
38-
message.set_text(chunk["text"])
46+
if msg.pending:
47+
msg.pending = false
48+
msg.set_text(chunk["text"])
3949
else:
40-
message.append_text(chunk["text"])
50+
msg.append_text(chunk["text"])

godot/examples/simple/simple.tscn

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
[gd_scene load_steps=6 format=3 uid="uid://c55kb4qvg6geq"]
22

3-
[ext_resource type="LlamaModel" path="res://models/Phi-3-mini-128k-instruct.Q5_K_M.gguf" id="1_ff70a"]
43
[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_gjsev"]
54
[ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"]
65
[ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"]
76
[ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"]
7+
[ext_resource type="LlamaModel" path="res://models/Phi-3-mini-128k-instruct.Q5_K_M.gguf" id="5_qpeda"]
88

99
[node name="Node" type="Node"]
1010
script = ExtResource("1_sruc3")
@@ -34,6 +34,7 @@ layout_mode = 2
3434
[node name="ScrollContainer" type="ScrollContainer" parent="Panel/MarginContainer/VBoxContainer"]
3535
layout_mode = 2
3636
size_flags_vertical = 3
37+
follow_focus = true
3738

3839
[node name="MessagesContainer" type="VBoxContainer" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer"]
3940
unique_name_in_owner = true
@@ -44,6 +45,7 @@ theme_override_constants/separation = 30
4445

4546
[node name="RichTextLabel2" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer" instance=ExtResource("2_7iip7")]
4647
layout_mode = 2
48+
include_in_prompt = false
4749

4850
[node name="Text" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer/RichTextLabel2" index="1"]
4951
text = "How can I help you?"
@@ -66,7 +68,7 @@ icon = ExtResource("1_gjsev")
6668
expand_icon = true
6769

6870
[node name="LlamaContext" type="LlamaContext" parent="."]
69-
model = ExtResource("1_ff70a")
71+
model = ExtResource("5_qpeda")
7072
unique_name_in_owner = true
7173

7274
[connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"]

llama.cpp

src/llama_context.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ void LlamaContext::__thread_loop() {
9696
UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id));
9797

9898
std::vector<llama_token> request_tokens;
99-
request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true);
99+
request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true, true);
100100

101101
size_t shared_prefix_idx = 0;
102102
auto diff = std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end());
@@ -166,7 +166,7 @@ void LlamaContext::__thread_loop() {
166166
return;
167167
}
168168
llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1);
169-
llama_sampling_accept(sampling_ctx, ctx, new_token_id, true);
169+
llama_sampling_accept(sampling_ctx, ctx, new_token_id, false);
170170

171171
Dictionary response;
172172
response["id"] = req.id;

0 commit comments

Comments
 (0)