diff --git a/build.zig b/build.zig
index 711165b..aadce33 100644
--- a/build.zig
+++ b/build.zig
@@ -7,11 +7,41 @@ const extension_name = "godot-llama-cpp";
pub fn build(b: *std.Build) !void {
const target = b.standardTargetOptions(.{});
const optimize = b.standardOptimizeOption(.{});
- const zig_triple = try target.result.linuxTriple(b.allocator);
+ const triple = try target.result.linuxTriple(b.allocator);
- var objs = std.ArrayList(*std.Build.Step.Compile).init(b.allocator);
+ const lib_godot_cpp = try build_lib_godot_cpp(.{ .b = b, .target = target, .optimize = optimize });
+ const lib_llama_cpp = try build_lib_llama_cpp(.{ .b = b, .target = target, .optimize = optimize });
+
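+ // The library name embeds the target triple and optimize mode, matching the per-platform paths listed in plugin.gdextension.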
+ const plugin = b.addSharedLibrary(.{
+ .name = b.fmt("{s}-{s}-{s}", .{ extension_name, triple, @tagName(optimize) }),
+ .target = target,
+ .optimize = optimize,
+ });
+ plugin.addCSourceFiles(.{ .files = try findFilesRecursive(b, "src/", &cfiles_exts) });
+ plugin.addIncludePath(.{ .path = "src/" });
+ plugin.addIncludePath(.{ .path = "godot_cpp/gdextension/" });
+ plugin.addIncludePath(.{ .path = "godot_cpp/include/" });
+ plugin.addIncludePath(.{ .path = "godot_cpp/gen/include" });
+ plugin.addIncludePath(.{ .path = "llama.cpp" });
+ plugin.addIncludePath(.{ .path = "llama.cpp/common" });
+ plugin.linkLibrary(lib_llama_cpp);
+ plugin.linkLibrary(lib_godot_cpp);
+
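+ // Install straight into the Godot project's addon directory so the res://addons/... paths in plugin.gdextension resolve.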
+ b.lib_dir = "./godot/addons/godot-llama-cpp/lib";
+ b.installArtifact(plugin);
+}
+
+const BuildParams = struct {
+ b: *std.Build,
+ target: std.Build.ResolvedTarget,
+ optimize: std.builtin.OptimizeMode,
+};
+
+fn build_lib_godot_cpp(params: BuildParams) !*std.Build.Step.Compile {
+ const b = params.b;
+ const target = params.target;
+ const optimize = params.optimize;
- // godot-cpp
const lib_godot = b.addStaticLibrary(.{
.name = "godot-cpp",
.target = target,
@@ -26,9 +56,7 @@ pub fn build(b: *std.Build) !void {
.cwd_dir = b.build_root.handle,
});
},
- else => {
- return;
- },
+ else => {},
}
};
lib_godot.linkLibCpp();
@@ -40,7 +68,15 @@ pub fn build(b: *std.Build) !void {
lib_godot.addCSourceFiles(.{ .files = lib_godot_gen_sources, .flags = &.{ "-std=c++17", "-fno-exceptions" } });
lib_godot.addCSourceFiles(.{ .files = lib_godot_sources, .flags = &.{ "-std=c++17", "-fno-exceptions" } });
- // llama.cpp
+ return lib_godot;
+}
+
+fn build_lib_llama_cpp(params: BuildParams) !*std.Build.Step.Compile {
+ const b = params.b;
+ const target = params.target;
+ const optimize = params.optimize;
+ const zig_triple = try target.result.zigTriple(b.allocator);
+
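+ // Regenerate llama.cpp's build-info.cpp so the built library reports the exact commit, Zig version, and target it was compiled with.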
const commit_hash = try std.ChildProcess.run(.{ .allocator = b.allocator, .argv = &.{ "git", "rev-parse", "HEAD" }, .cwd = b.pathFromRoot("llama.cpp") });
const zig_version = builtin.zig_version_string;
try b.build_root.handle.writeFile2(.{ .sub_path = "llama.cpp/common/build-info.cpp", .data = b.fmt(
@@ -48,259 +84,102 @@ pub fn build(b: *std.Build) !void {
\\char const *LLAMA_COMMIT = "{s}";
\\char const *LLAMA_COMPILER = "Zig {s}";
\\char const *LLAMA_BUILD_TARGET = "{s}";
- \\
, .{ 0, commit_hash.stdout[0 .. commit_hash.stdout.len - 1], zig_version, zig_triple }) });
- var flags = std.ArrayList([]const u8).init(b.allocator);
- if (target.result.abi != .msvc) try flags.append("-D_GNU_SOURCE");
- if (target.result.os.tag == .macos) try flags.appendSlice(&.{
- "-D_DARWIN_C_SOURCE",
- "-DGGML_USE_METAL",
- "-DGGML_USE_ACCELERATE",
- "-DACCELERATE_USE_LAPACK",
- "-DACCELERATE_LAPACK_ILP64",
- }) else try flags.append("-DGGML_USE_VULKAN");
- try flags.append("-D_XOPEN_SOURCE=600");
+ const lib_llama_cpp = b.addStaticLibrary(.{ .name = "llama.cpp", .target = target, .optimize = optimize });
- var cflags = std.ArrayList([]const u8).init(b.allocator);
- try cflags.append("-std=c11");
- try cflags.appendSlice(flags.items);
- var cxxflags = std.ArrayList([]const u8).init(b.allocator);
- try cxxflags.append("-std=c++11");
- try cxxflags.appendSlice(flags.items);
-
- const include_paths = [_][]const u8{ "llama.cpp", "llama.cpp/common" };
- const llama = buildObj(.{
- .b = b,
- .name = "llama",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/llama.cpp"},
- .include_paths = &include_paths,
- .link_lib_cpp = true,
- .link_lib_c = false,
- .flags = cxxflags.items,
- });
- const ggml = buildObj(.{
- .b = b,
- .name = "ggml",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/ggml.c"},
- .include_paths = &include_paths,
- .link_lib_c = true,
- .link_lib_cpp = false,
- .flags = cflags.items,
- });
- const common = buildObj(.{
- .b = b,
- .name = "common",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/common/common.cpp"},
- .include_paths = &include_paths,
- .link_lib_cpp = true,
- .link_lib_c = false,
- .flags = cxxflags.items,
- });
- const console = buildObj(.{
- .b = b,
- .name = "console",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/common/console.cpp"},
- .include_paths = &include_paths,
- .link_lib_cpp = true,
- .link_lib_c = false,
- .flags = cxxflags.items,
- });
- const sampling = buildObj(.{
- .b = b,
- .name = "sampling",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/common/sampling.cpp"},
- .include_paths = &include_paths,
- .link_lib_cpp = true,
- .link_lib_c = false,
- .flags = cxxflags.items,
- });
- const grammar_parser = buildObj(.{
- .b = b,
- .name = "grammar_parser",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/common/grammar-parser.cpp"},
- .include_paths = &include_paths,
- .link_lib_cpp = true,
- .link_lib_c = false,
- .flags = cxxflags.items,
- });
- const build_info = buildObj(.{
- .b = b,
- .name = "build_info",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/common/build-info.cpp"},
- .include_paths = &include_paths,
- .link_lib_cpp = true,
- .link_lib_c = false,
- .flags = cxxflags.items,
- });
- const ggml_alloc = buildObj(.{
- .b = b,
- .name = "ggml_alloc",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/ggml-alloc.c"},
- .include_paths = &include_paths,
- .link_lib_c = true,
- .link_lib_cpp = false,
- .flags = cflags.items,
- });
- const ggml_backend = buildObj(.{
- .b = b,
- .name = "ggml_backend",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/ggml-backend.c"},
- .include_paths = &include_paths,
- .link_lib_c = true,
- .link_lib_cpp = false,
- .flags = cflags.items,
- });
- const ggml_quants = buildObj(.{
- .b = b,
- .name = "ggml_quants",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/ggml-quants.c"},
- .include_paths = &include_paths,
- .link_lib_c = true,
- .link_lib_cpp = false,
- .flags = cflags.items,
- });
- const unicode = buildObj(.{
+ var objs = std.ArrayList(*std.Build.Step.Compile).init(b.allocator);
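+ // ObjBuilder (below) bundles the target, optimize mode, flags, and include paths shared by every llama.cpp object file.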
+ var objBuilder = ObjBuilder.init(.{
.b = b,
- .name = "unicode",
.target = target,
.optimize = optimize,
- .sources = &.{"llama.cpp/unicode.cpp"},
- .include_paths = &include_paths,
- .link_lib_c = false,
- .link_lib_cpp = true,
- .flags = cxxflags.items,
+ .include_paths = &.{ "llama.cpp", "llama.cpp/common" },
});
- try objs.appendSlice(&.{ llama, ggml, common, console, sampling, grammar_parser, build_info, ggml_alloc, ggml_backend, ggml_quants, unicode });
-
- if (target.result.os.tag == .macos) {
- const ggml_metal = buildObj(.{
- .b = b,
- .name = "ggml_metal",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/ggml-metal.m"},
- .include_paths = &include_paths,
- .link_lib_c = true,
- .link_lib_cpp = false,
- .flags = cflags.items,
- });
- const airCommand = b.addSystemCommand(&.{ "xcrun", "-sdk", "macosx", "metal", "-O3", "-c" });
- airCommand.addFileArg(.{ .path = "llama.cpp/ggml-metal.metal" });
- airCommand.addArg("-o");
- const air = airCommand.addOutputFileArg("ggml-metal.air");
- const libCommand = b.addSystemCommand(&.{ "xcrun", "-sdk", "macosx", "metallib" });
- libCommand.addFileArg(air);
- libCommand.addArg("-o");
- const lib = libCommand.addOutputFileArg("default.metallib");
- const libInstall = b.addInstallLibFile(lib, "default.metallib");
- b.getInstallStep().dependOn(&libInstall.step);
- try objs.append(ggml_metal);
- } else {
- const ggml_vulkan = buildObj(.{
- .b = b,
- .name = "ggml_vulkan",
- .target = target,
- .optimize = optimize,
- .sources = &.{"llama.cpp/ggml-vulkan.cpp"},
- .include_paths = &include_paths,
- .link_lib_cpp = true,
- .link_lib_c = false,
- .flags = cxxflags.items,
- });
- try objs.append(ggml_vulkan);
+ switch (target.result.os.tag) {
+ .macos => {
+ try objBuilder.flags.append("-DGGML_USE_METAL");
+ try objs.append(objBuilder.build(.{ .name = "ggml_metal", .sources = &.{"llama.cpp/ggml-metal.m"} }));
+
+ lib_llama_cpp.linkFramework("Foundation");
+ lib_llama_cpp.linkFramework("Metal");
+ lib_llama_cpp.linkFramework("MetalKit");
+
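+ // ggml-metal.metal #includes ggml-common.h, which can't be resolved when the shader is compiled at runtime, so tools/expand_metal.zig inlines it and the expanded shader is installed alongside the library.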
+ const expand_metal = b.addExecutable(.{
+ .name = "expand_metal",
+ .target = target,
+ .root_source_file = .{ .path = "tools/expand_metal.zig" },
+ });
+ var run_expand_metal = b.addRunArtifact(expand_metal);
+ run_expand_metal.addArg("--metal-file");
+ run_expand_metal.addFileArg(.{ .path = "llama.cpp/ggml-metal.metal" });
+ run_expand_metal.addArg("--common-file");
+ run_expand_metal.addFileArg(.{ .path = "llama.cpp/ggml-common.h" });
+ run_expand_metal.addArg("--output-file");
+ const metal_expanded = run_expand_metal.addOutputFileArg("ggml-metal.metal");
+ const install_metal = b.addInstallFileWithDir(metal_expanded, .lib, "ggml-metal.metal");
+ lib_llama_cpp.step.dependOn(&install_metal.step);
+ },
+ else => {},
}
- const extension = b.addSharedLibrary(.{ .name = b.fmt("{s}-{s}-{s}", .{ extension_name, zig_triple, @tagName(optimize) }), .target = target, .optimize = optimize });
- const sources = try findFilesRecursive(b, "src", &cfiles_exts);
- extension.addCSourceFiles(.{ .files = sources, .flags = &.{ "-std=c++17", "-fno-exceptions" } });
- extension.addIncludePath(.{ .path = "src" });
- extension.addIncludePath(.{ .path = "godot_cpp/include/" });
- extension.addIncludePath(.{ .path = "godot_cpp/gdextension/" });
- extension.addIncludePath(.{ .path = "godot_cpp/gen/include/" });
- extension.addIncludePath(.{ .path = "llama.cpp" });
- extension.addIncludePath(.{ .path = "llama.cpp/common" });
+ try objs.appendSlice(&.{
+ objBuilder.build(.{ .name = "ggml", .sources = &.{"llama.cpp/ggml.c"} }),
+ objBuilder.build(.{ .name = "sgemm", .sources = &.{"llama.cpp/sgemm.cpp"} }),
+ objBuilder.build(.{ .name = "ggml_alloc", .sources = &.{"llama.cpp/ggml-alloc.c"} }),
+ objBuilder.build(.{ .name = "ggml_backend", .sources = &.{"llama.cpp/ggml-backend.c"} }),
+ objBuilder.build(.{ .name = "ggml_quants", .sources = &.{"llama.cpp/ggml-quants.c"} }),
+ objBuilder.build(.{ .name = "llama", .sources = &.{"llama.cpp/llama.cpp"} }),
+ objBuilder.build(.{ .name = "unicode", .sources = &.{"llama.cpp/unicode.cpp"} }),
+ objBuilder.build(.{ .name = "unicode_data", .sources = &.{"llama.cpp/unicode-data.cpp"} }),
+ objBuilder.build(.{ .name = "common", .sources = &.{"llama.cpp/common/common.cpp"} }),
+ objBuilder.build(.{ .name = "console", .sources = &.{"llama.cpp/common/console.cpp"} }),
+ objBuilder.build(.{ .name = "sampling", .sources = &.{"llama.cpp/common/sampling.cpp"} }),
+ objBuilder.build(.{ .name = "grammar_parser", .sources = &.{"llama.cpp/common/grammar-parser.cpp"} }),
+ objBuilder.build(.{ .name = "json_schema_to_grammar", .sources = &.{"llama.cpp/common/json-schema-to-grammar.cpp"} }),
+ objBuilder.build(.{ .name = "build_info", .sources = &.{"llama.cpp/common/build-info.cpp"} }),
+ });
+
for (objs.items) |obj| {
- extension.addObject(obj);
- }
- extension.linkLibC();
- extension.linkLibCpp();
- if (target.result.os.tag == .macos) {
- extension.linkFramework("Metal");
- extension.linkFramework("MetalKit");
- extension.linkFramework("Foundation");
- extension.linkFramework("Accelerate");
- // b.installFile("llama.cpp/ggml-metal.metal", b.pathJoin(&.{ std.fs.path.basename(b.lib_dir), "ggml-metal.metal" }));
- // b.installFile("llama.cpp/ggml-common.h", b.pathJoin(&.{ std.fs.path.basename(b.lib_dir), "ggml-common.h" }));
- } else {
- if (target.result.os.tag == .windows) {
- const vk_path = b.graph.env_map.get("VK_SDK_PATH") orelse @panic("VK_SDK_PATH not set");
- extension.addLibraryPath(.{ .path = b.pathJoin(&.{ vk_path, "Lib" }) });
- extension.linkSystemLibrary("vulkan-1");
- } else {
- extension.linkSystemLibrary("vulkan");
- }
+ lib_llama_cpp.addObject(obj);
}
- extension.linkLibrary(lib_godot);
- b.installArtifact(extension);
+ return lib_llama_cpp;
}
-const BuildObjectParams = struct {
+const ObjBuilder = struct {
b: *std.Build,
- name: []const u8,
target: std.Build.ResolvedTarget,
optimize: std.builtin.OptimizeMode,
- sources: []const []const u8,
include_paths: []const []const u8,
- link_lib_c: bool,
- link_lib_cpp: bool,
- flags: []const []const u8,
-};
-
-fn buildObj(params: BuildObjectParams) *std.Build.Step.Compile {
- const obj = params.b.addObject(.{
- .name = params.name,
- .target = params.target,
- .optimize = params.optimize,
- });
- if (params.target.result.os.tag == .windows) {
- const vk_path = params.b.graph.env_map.get("VK_SDK_PATH") orelse @panic("VK_SDK_PATH not set");
- obj.addIncludePath(.{ .path = params.b.pathJoin(&.{ vk_path, "include" }) });
+ flags: std.ArrayList([]const u8),
+
+ fn init(params: struct {
+ b: *std.Build,
+ target: std.Build.ResolvedTarget,
+ optimize: std.builtin.OptimizeMode,
+ include_paths: []const []const u8,
+ }) ObjBuilder {
+ return ObjBuilder{
+ .b = params.b,
+ .target = params.target,
+ .optimize = params.optimize,
+ .include_paths = params.include_paths,
+ .flags = std.ArrayList([]const u8).init(params.b.allocator),
+ };
}
- for (params.include_paths) |path| {
- obj.addIncludePath(.{ .path = path });
- }
- if (params.link_lib_c) {
+
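+ /// Compile the given sources into a single object file using the shared flags and include paths.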
+ fn build(self: *ObjBuilder, params: struct { name: []const u8, sources: []const []const u8 }) *std.Build.Step.Compile {
+ const obj = self.b.addObject(.{ .name = params.name, .target = self.target, .optimize = self.optimize });
+ obj.addCSourceFiles(.{ .files = params.sources, .flags = self.flags.items });
+ for (self.include_paths) |path| {
+ obj.addIncludePath(.{ .path = path });
+ }
obj.linkLibC();
- }
- if (params.link_lib_cpp) {
obj.linkLibCpp();
+ return obj;
}
- obj.addCSourceFiles(.{ .files = params.sources, .flags = params.flags });
- return obj;
-}
+};
fn findFilesRecursive(b: *std.Build, dir_name: []const u8, exts: []const []const u8) ![][]const u8 {
var sources = std.ArrayList([]const u8).init(b.allocator);
diff --git a/godot/.gitattributes b/godot/.gitattributes
new file mode 100644
index 0000000..8ad74f7
--- /dev/null
+++ b/godot/.gitattributes
@@ -0,0 +1,2 @@
+# Normalize EOL for all files that Git considers text files.
+* text=auto eol=lf
diff --git a/godot/.gitignore b/godot/.gitignore
new file mode 100644
index 0000000..4709183
--- /dev/null
+++ b/godot/.gitignore
@@ -0,0 +1,2 @@
+# Godot 4+ specific ignores
+.godot/
diff --git a/godot/addons/godot-llama-cpp/autoloads/llama-backend.gd b/godot/addons/godot-llama-cpp/autoloads/llama-backend.gd
deleted file mode 100644
index 83bb52e..0000000
--- a/godot/addons/godot-llama-cpp/autoloads/llama-backend.gd
+++ /dev/null
@@ -1,10 +0,0 @@
-# This script will be autoloaded by the editor plugin
-extends Node
-
-var backend: LlamaBackend = LlamaBackend.new()
-
-func _enter_tree() -> void:
- backend.init()
-
-func _exit_tree() -> void:
- backend.deinit()
diff --git a/godot/addons/godot-llama-cpp/chat/chat_formatter.gd b/godot/addons/godot-llama-cpp/chat/chat_formatter.gd
new file mode 100644
index 0000000..84d8369
--- /dev/null
+++ b/godot/addons/godot-llama-cpp/chat/chat_formatter.gd
@@ -0,0 +1,56 @@
+class_name ChatFormatter
+
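+# Formats an array of {"text": ..., "sender": ...} dictionaries into a single prompt string for the given chat template.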
+static func apply(format: String, messages: Array) -> String:
+ match format:
+ "llama3":
+ return format_llama3(messages)
+ "phi3":
+ return format_phi3(messages)
+ "mistral":
+ return format_mistral(messages)
+ _:
+ printerr("Unknown chat format: ", format)
+ return ""
+
+static func format_llama3(messages: Array) -> String:
+ var res = ""
+
+ for i in range(messages.size()):
+ match messages[i]:
+ {"text": var text, "sender": var sender}:
+ res += """<|start_header_id|>%s<|end_header_id|>
+
+%s<|eot_id|>
+""" % [sender, text]
+ _:
+ printerr("Invalid message at index ", i)
+
+ res += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ return res
+
+static func format_phi3(messages: Array) -> String:
+ var res = ""
+
+ for i in range(messages.size()):
+ match messages[i]:
+ {"text": var text, "sender": var sender}:
+ res += "<|%s|>\n%s<|end|>\n" % [sender, text]
+ _:
+ printerr("Invalid message at index ", i)
+ res += "<|assistant|>\n"
+ return res
+
+static func format_mistral(messages: Array) -> String:
+ var res = ""
+
+ for i in range(messages.size()):
+ match messages[i]:
+ {"text": var text, "sender": var sender}:
+ if sender == "user":
+ res += "[INST] %s [/INST]" % text
+ else:
+ res += "%s" % text
+ _:
+ printerr("Invalid message at index ", i)
+
+ return res
diff --git a/godot/addons/godot-llama-cpp/plugin.cfg b/godot/addons/godot-llama-cpp/plugin.cfg
index 36e0b68..6335055 100644
--- a/godot/addons/godot-llama-cpp/plugin.cfg
+++ b/godot/addons/godot-llama-cpp/plugin.cfg
@@ -4,4 +4,4 @@ name="godot-llama-cpp"
description="Run large language models in Godot. Powered by llama.cpp."
author="hazelnutcloud"
version="0.0.1"
-script="godot-llama-cpp.gd"
+script="plugin.gd"
diff --git a/godot/addons/godot-llama-cpp/godot-llama-cpp.gd b/godot/addons/godot-llama-cpp/plugin.gd
similarity index 50%
rename from godot/addons/godot-llama-cpp/godot-llama-cpp.gd
rename to godot/addons/godot-llama-cpp/plugin.gd
index ca63dbc..5cee76e 100644
--- a/godot/addons/godot-llama-cpp/godot-llama-cpp.gd
+++ b/godot/addons/godot-llama-cpp/plugin.gd
@@ -4,9 +4,9 @@ extends EditorPlugin
func _enter_tree():
# Initialization of the plugin goes here.
- add_autoload_singleton("__LlamaBackend", "res://addons/godot-llama-cpp/autoloads/llama-backend.gd")
+ pass
func _exit_tree():
# Clean-up of the plugin goes here.
- remove_autoload_singleton("__LlamaBackend")
+ pass
diff --git a/godot/addons/godot-llama-cpp/godot-llama-cpp.gdextension b/godot/addons/godot-llama-cpp/plugin.gdextension
similarity index 95%
rename from godot/addons/godot-llama-cpp/godot-llama-cpp.gdextension
rename to godot/addons/godot-llama-cpp/plugin.gdextension
index 4171072..9f2561c 100644
--- a/godot/addons/godot-llama-cpp/godot-llama-cpp.gdextension
+++ b/godot/addons/godot-llama-cpp/plugin.gdextension
@@ -5,8 +5,8 @@ compatibility_minimum = "4.2"
[libraries]
-macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-Debug.dylib"
-macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseFast.dylib"
+macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
+macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib"
windows.debug.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_debug.x86_32.dll"
windows.release.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_release.x86_32.dll"
windows.debug.x86_64 = "res://addons/godot-llama-cpp/lib/godot-llama-cpp-x86_64-windows-gnu-Debug.dll"
diff --git a/godot/autoloads/llama.tscn b/godot/autoloads/llama.tscn
deleted file mode 100644
index 9115818..0000000
--- a/godot/autoloads/llama.tscn
+++ /dev/null
@@ -1,6 +0,0 @@
-[gd_scene load_steps=2 format=3 uid="uid://bxobxniygk7jm"]
-
-[ext_resource type="LlamaModel" path="res://models/OGNO-7B-Q4_K_M.gguf" id="1_vd8h8"]
-
-[node name="LlamaContext" type="LlamaContext"]
-model = ExtResource("1_vd8h8")
diff --git a/godot/examples/simple/TextEdit.gd b/godot/examples/simple/TextEdit.gd
new file mode 100644
index 0000000..b8d4bc0
--- /dev/null
+++ b/godot/examples/simple/TextEdit.gd
@@ -0,0 +1,20 @@
+extends TextEdit
+
+signal submit(input: String)
+
+func _gui_input(event: InputEvent) -> void:
+ if event is InputEventKey:
+ var keycode = event.get_keycode_with_modifiers()
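+ # Plain Enter submits the prompt; Shift+Enter inserts a newline instead.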
+ if keycode == KEY_ENTER and event.is_pressed():
+ handle_submit()
+ accept_event()
+ if keycode == KEY_ENTER | KEY_MASK_SHIFT and event.is_pressed():
+ insert_text_at_caret("\n")
+ accept_event()
+
+func _on_button_pressed() -> void:
+ handle_submit()
+
+func handle_submit() -> void:
+ submit.emit(text)
+ text = ""
diff --git a/godot/examples/simple/form.gd b/godot/examples/simple/form.gd
new file mode 100644
index 0000000..092aa73
--- /dev/null
+++ b/godot/examples/simple/form.gd
@@ -0,0 +1,6 @@
+extends HBoxContainer
+
+@onready var text_edit = %TextEdit
+
+func _on_button_pressed() -> void:
+ text_edit.handle_submit()
diff --git a/godot/examples/simple/message.gd b/godot/examples/simple/message.gd
new file mode 100644
index 0000000..63b4f31
--- /dev/null
+++ b/godot/examples/simple/message.gd
@@ -0,0 +1,23 @@
+class_name Message
+extends Node
+
+@onready var text_container = %Text
+@onready var icon = %Panel
+@export_enum("user", "assistant") var sender: String
+@export var include_in_prompt: bool = true
+var text:
+ get:
+ return text_container.text
+ set(value):
+ text_container.text = value
+
+var completion_id: int = -1
+var pending: bool = false
+var errored: bool = false
+
+func set_text(new_text: String):
+ text_container.text = new_text
+
+func append_text(new_text: String):
+ text_container.text += new_text
+
diff --git a/godot/examples/simple/message.tscn b/godot/examples/simple/message.tscn
new file mode 100644
index 0000000..36b68e7
--- /dev/null
+++ b/godot/examples/simple/message.tscn
@@ -0,0 +1,37 @@
+[gd_scene load_steps=5 format=3 uid="uid://t862t0v8ht2q"]
+
+[ext_resource type="Script" path="res://examples/simple/message.gd" id="1_pko33"]
+[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="2_dvc7y"]
+
+[sub_resource type="StyleBoxTexture" id="StyleBoxTexture_t8bgj"]
+texture = ExtResource("2_dvc7y")
+
+[sub_resource type="Theme" id="Theme_bw3pb"]
+Panel/styles/panel = SubResource("StyleBoxTexture_t8bgj")
+
+[node name="RichTextLabel" type="HBoxContainer"]
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+size_flags_horizontal = 3
+theme_override_constants/separation = 20
+script = ExtResource("1_pko33")
+sender = "assistant"
+
+[node name="Panel" type="Panel" parent="."]
+unique_name_in_owner = true
+custom_minimum_size = Vector2(80, 80)
+layout_mode = 2
+size_flags_vertical = 0
+theme = SubResource("Theme_bw3pb")
+
+[node name="Text" type="RichTextLabel" parent="."]
+unique_name_in_owner = true
+layout_mode = 2
+size_flags_horizontal = 3
+focus_mode = 2
+text = "..."
+fit_content = true
+selection_enabled = true
diff --git a/godot/examples/simple/simple.gd b/godot/examples/simple/simple.gd
new file mode 100644
index 0000000..4519f5d
--- /dev/null
+++ b/godot/examples/simple/simple.gd
@@ -0,0 +1,53 @@
+extends Node
+
+const message = preload("res://examples/simple/message.tscn")
+
+@onready var messages_container = %MessagesContainer
+@onready var llama_context = %LlamaContext
+
+func _on_text_edit_submit(input: String) -> void:
+ handle_input(input)
+
+func handle_input(input: String) -> void:
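+ # Rebuild the visible chat history (plus the new input) into a llama3-formatted prompt each turn.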
+ #var messages = [{ "sender": "system", "text": "You are a pirate chatbot who always responds in pirate speak!" }]
+
+ #var messages = [{ "sender": "system", "text": "You are a helpful chatbot assistant!" }]
+ var messages = []
+ messages.append_array(messages_container.get_children().filter(func(msg: Message): return msg.include_in_prompt).map(
+ func(msg: Message) -> Dictionary:
+ return { "text": msg.text, "sender": msg.sender }
+ ))
+ messages.append({"text": input, "sender": "user"})
+ var prompt = ChatFormatter.apply("llama3", messages)
+ print("prompt: ", prompt)
+
+ var completion_id = llama_context.request_completion(prompt)
+
+ var user_message: Message = message.instantiate()
+ messages_container.add_child(user_message)
+ user_message.set_text(input)
+ user_message.sender = "user"
+ user_message.completion_id = completion_id
+
+ var ai_message: Message = message.instantiate()
+ messages_container.add_child(ai_message)
+ ai_message.sender = "assistant"
+ ai_message.completion_id = completion_id
+ ai_message.pending = true
+ ai_message.grab_focus()
+
+
+
+func _on_llama_context_completion_generated(chunk: Dictionary) -> void:
+ var completion_id = chunk.id
+ for msg: Message in messages_container.get_children():
+ if msg.completion_id != completion_id or msg.sender != "assistant":
+ continue
+ if chunk.has("error"):
+ msg.errored = true
+ elif chunk.has("text"):
+ if msg.pending:
+ msg.pending = false
+ msg.set_text(chunk["text"])
+ else:
+ msg.append_text(chunk["text"])
diff --git a/godot/examples/simple/simple.tscn b/godot/examples/simple/simple.tscn
new file mode 100644
index 0000000..50b193d
--- /dev/null
+++ b/godot/examples/simple/simple.tscn
@@ -0,0 +1,79 @@
+[gd_scene load_steps=6 format=3 uid="uid://c55kb4qvg6geq"]
+
+[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_gjsev"]
+[ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"]
+[ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"]
+[ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"]
+[ext_resource type="LlamaModel" path="res://models/meta-llama-3-8b-instruct.Q5_K_M.gguf" id="5_qov1l"]
+
+[node name="Node" type="Node"]
+script = ExtResource("1_sruc3")
+
+[node name="Panel" type="Panel" parent="."]
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+
+[node name="MarginContainer" type="MarginContainer" parent="Panel"]
+layout_mode = 1
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+theme_override_constants/margin_left = 10
+theme_override_constants/margin_top = 10
+theme_override_constants/margin_right = 10
+theme_override_constants/margin_bottom = 10
+
+[node name="VBoxContainer" type="VBoxContainer" parent="Panel/MarginContainer"]
+layout_mode = 2
+
+[node name="ScrollContainer" type="ScrollContainer" parent="Panel/MarginContainer/VBoxContainer"]
+layout_mode = 2
+size_flags_vertical = 3
+follow_focus = true
+
+[node name="MessagesContainer" type="VBoxContainer" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer"]
+unique_name_in_owner = true
+layout_mode = 2
+size_flags_horizontal = 3
+size_flags_vertical = 3
+theme_override_constants/separation = 30
+
+[node name="RichTextLabel2" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer" instance=ExtResource("2_7iip7")]
+layout_mode = 2
+include_in_prompt = false
+
+[node name="Text" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer/RichTextLabel2" index="1"]
+text = "How can I help you?"
+
+[node name="HBoxContainer" type="HBoxContainer" parent="Panel/MarginContainer/VBoxContainer"]
+layout_mode = 2
+
+[node name="TextEdit" type="TextEdit" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"]
+custom_minimum_size = Vector2(2.08165e-12, 100)
+layout_mode = 2
+size_flags_horizontal = 3
+placeholder_text = "Ask me anything..."
+wrap_mode = 1
+script = ExtResource("2_7usqw")
+
+[node name="Button" type="Button" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"]
+custom_minimum_size = Vector2(100, 2.08165e-12)
+layout_mode = 2
+icon = ExtResource("1_gjsev")
+expand_icon = true
+
+[node name="LlamaContext" type="LlamaContext" parent="."]
+model = ExtResource("5_qov1l")
+temperature = 0.9
+unique_name_in_owner = true
+
+[connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"]
+[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" method="_on_button_pressed"]
+[connection signal="completion_generated" from="LlamaContext" to="." method="_on_llama_context_completion_generated"]
+
+[editable path="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer/RichTextLabel2"]
diff --git a/godot/icon.svg b/godot/icon.svg
new file mode 100644
index 0000000..3fe4f4a
--- /dev/null
+++ b/godot/icon.svg
@@ -0,0 +1 @@
+
diff --git a/godot/icon.svg.import b/godot/icon.svg.import
new file mode 100644
index 0000000..5a5ae61
--- /dev/null
+++ b/godot/icon.svg.import
@@ -0,0 +1,37 @@
+[remap]
+
+importer="texture"
+type="CompressedTexture2D"
+uid="uid://beeg0oqle7bnk"
+path="res://.godot/imported/icon.svg-218a8f2b3041327d8a5756f3a245f83b.ctex"
+metadata={
+"vram_texture": false
+}
+
+[deps]
+
+source_file="res://icon.svg"
+dest_files=["res://.godot/imported/icon.svg-218a8f2b3041327d8a5756f3a245f83b.ctex"]
+
+[params]
+
+compress/mode=0
+compress/high_quality=false
+compress/lossy_quality=0.7
+compress/hdr_compression=1
+compress/normal_map=0
+compress/channel_pack=0
+mipmaps/generate=false
+mipmaps/limit=-1
+roughness/mode=0
+roughness/src_normal=""
+process/fix_alpha_border=true
+process/premult_alpha=false
+process/normal_map_invert_y=false
+process/hdr_as_srgb=false
+process/hdr_clamp_exposure=false
+process/size_limit=0
+detect_3d/compress_to=1
+svg/scale=1.0
+editor/scale_with_editor_scale=false
+editor/convert_colors_with_editor_theme=false
diff --git a/godot/main.gd b/godot/main.gd
deleted file mode 100644
index 6c258e3..0000000
--- a/godot/main.gd
+++ /dev/null
@@ -1,27 +0,0 @@
-extends Node
-
-@onready var input: TextEdit = %Input
-@onready var submit_button: Button = %SubmitButton
-@onready var output: Label = %Output
-
-func _on_button_pressed():
- handle_submit()
-
-func handle_submit():
- print(input.text)
- Llama.request_completion(input.text)
-
- input.clear()
- input.editable = false
- submit_button.disabled = true
- output.text = "..."
-
- var completion = await Llama.completion_generated
- output.text = ""
- while !completion[1]:
- print(completion[0])
- output.text += completion[0]
- completion = await Llama.completion_generated
-
- input.editable = true
- submit_button.disabled = false
diff --git a/godot/main.tscn b/godot/main.tscn
deleted file mode 100644
index 7bda86d..0000000
--- a/godot/main.tscn
+++ /dev/null
@@ -1,103 +0,0 @@
-[gd_scene load_steps=4 format=3 uid="uid://7oo8yj56scb1"]
-
-[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_ojdoj"]
-[ext_resource type="Script" path="res://main.gd" id="1_vvrqe"]
-
-[sub_resource type="StyleBoxFlat" id="StyleBoxFlat_3e37a"]
-corner_radius_top_left = 5
-corner_radius_top_right = 5
-corner_radius_bottom_right = 5
-corner_radius_bottom_left = 5
-
-[node name="Main" type="Node"]
-script = ExtResource("1_vvrqe")
-
-[node name="Background" type="ColorRect" parent="."]
-anchors_preset = 15
-anchor_right = 1.0
-anchor_bottom = 1.0
-grow_horizontal = 2
-grow_vertical = 2
-color = Color(0.980392, 0.952941, 0.929412, 1)
-
-[node name="CenterContainer" type="CenterContainer" parent="."]
-anchors_preset = 8
-anchor_left = 0.5
-anchor_top = 0.5
-anchor_right = 0.5
-anchor_bottom = 0.5
-offset_left = -400.0
-offset_top = -479.0
-offset_right = 400.0
-offset_bottom = 479.0
-grow_horizontal = 2
-grow_vertical = 2
-
-[node name="VBoxContainer" type="VBoxContainer" parent="CenterContainer"]
-custom_minimum_size = Vector2(500, 0)
-layout_mode = 2
-theme_override_constants/separation = 10
-alignment = 1
-
-[node name="Name" type="Label" parent="CenterContainer/VBoxContainer"]
-layout_mode = 2
-theme_override_colors/font_color = Color(0.101961, 0.0823529, 0.0627451, 1)
-theme_override_font_sizes/font_size = 32
-text = "godot-llama-cpp"
-horizontal_alignment = 1
-
-[node name="MarginContainer" type="MarginContainer" parent="CenterContainer/VBoxContainer"]
-layout_mode = 2
-theme_override_constants/margin_left = 100
-theme_override_constants/margin_right = 100
-
-[node name="TextureRect" type="TextureRect" parent="CenterContainer/VBoxContainer/MarginContainer"]
-layout_mode = 2
-texture = ExtResource("1_ojdoj")
-expand_mode = 4
-
-[node name="ScrollContainer" type="ScrollContainer" parent="CenterContainer/VBoxContainer"]
-custom_minimum_size = Vector2(2.08165e-12, 150)
-layout_mode = 2
-horizontal_scroll_mode = 0
-
-[node name="Panel" type="PanelContainer" parent="CenterContainer/VBoxContainer/ScrollContainer"]
-custom_minimum_size = Vector2(2.08165e-12, 2.08165e-12)
-layout_mode = 2
-size_flags_horizontal = 3
-size_flags_vertical = 3
-theme_override_styles/panel = SubResource("StyleBoxFlat_3e37a")
-
-[node name="MarginContainer" type="MarginContainer" parent="CenterContainer/VBoxContainer/ScrollContainer/Panel"]
-layout_mode = 2
-theme_override_constants/margin_left = 20
-theme_override_constants/margin_right = 20
-
-[node name="Output" type="Label" parent="CenterContainer/VBoxContainer/ScrollContainer/Panel/MarginContainer"]
-unique_name_in_owner = true
-custom_minimum_size = Vector2(200, 2.08165e-12)
-layout_mode = 2
-theme_override_colors/font_color = Color(0.101961, 0.0823529, 0.0627451, 1)
-text = "Ask me anything!"
-autowrap_mode = 3
-
-[node name="Form" type="HBoxContainer" parent="CenterContainer/VBoxContainer"]
-custom_minimum_size = Vector2(500, 60)
-layout_mode = 2
-size_flags_horizontal = 4
-alignment = 1
-
-[node name="Input" type="TextEdit" parent="CenterContainer/VBoxContainer/Form"]
-unique_name_in_owner = true
-layout_mode = 2
-size_flags_horizontal = 3
-size_flags_stretch_ratio = 3.0
-placeholder_text = "Why do cows moo?"
-
-[node name="SubmitButton" type="Button" parent="CenterContainer/VBoxContainer/Form"]
-unique_name_in_owner = true
-layout_mode = 2
-size_flags_horizontal = 3
-text = "Submit"
-
-[connection signal="pressed" from="CenterContainer/VBoxContainer/Form/SubmitButton" to="." method="_on_button_pressed"]
diff --git a/godot/project.godot b/godot/project.godot
index e5ff82f..4730d03 100644
--- a/godot/project.godot
+++ b/godot/project.godot
@@ -11,35 +11,14 @@ config_version=5
[application]
config/name="godot-llama-cpp"
-run/main_scene="res://main.tscn"
+run/main_scene="res://examples/simple/simple.tscn"
config/features=PackedStringArray("4.2", "Forward Plus")
-config/icon="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg"
-
-[autoload]
-
-__LlamaBackend="*res://addons/godot-llama-cpp/autoloads/llama-backend.gd"
-Llama="*res://autoloads/llama.tscn"
-
-[display]
-
-window/size/viewport_width=1280
-window/size/viewport_height=720
+config/icon="res://icon.svg"
[editor_plugins]
enabled=PackedStringArray("res://addons/godot-llama-cpp/plugin.cfg")
-[input]
-
-submit_form={
-"deadzone": 0.5,
-"events": [Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":4194309,"key_label":0,"unicode":0,"echo":false,"script":null)
-]
-}
-
-[rendering]
+[gui]
-anti_aliasing/quality/msaa_2d=3
-anti_aliasing/quality/msaa_3d=3
-anti_aliasing/quality/screen_space_aa=1
-anti_aliasing/quality/use_taa=true
+theme/default_theme_scale=2.0
diff --git a/godot_cpp b/godot_cpp
index 51c752c..b28098e 160000
--- a/godot_cpp
+++ b/godot_cpp
@@ -1 +1 @@
-Subproject commit 51c752c46b44769d3b6c661526c364a18ea64781
+Subproject commit b28098e76b84e8831b8ac68d490f4bca44678b2a
diff --git a/llama.cpp b/llama.cpp
index 4755afd..9afdffe 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit 4755afd1cbd40d93c017e5b98c39796f52345314
+Subproject commit 9afdffe70ebf3166d429b4434783bb0b7f97bdeb
diff --git a/src/llama_backend.cpp b/src/llama_backend.cpp
deleted file mode 100644
index 0e2c0a5..0000000
--- a/src/llama_backend.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "llama.h"
-#include "llama_backend.h"
-#include <godot_cpp/core/class_db.hpp>
-
-using namespace godot;
-
-void LlamaBackend::init() {
- llama_backend_init();
- llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISABLED);
-}
-
-void LlamaBackend::deinit() {
- llama_backend_free();
-}
-
-void LlamaBackend::_bind_methods() {
- ClassDB::bind_method(D_METHOD("init"), &LlamaBackend::init);
- ClassDB::bind_method(D_METHOD("deinit"), &LlamaBackend::deinit);
-}
\ No newline at end of file
diff --git a/src/llama_backend.h b/src/llama_backend.h
deleted file mode 100644
index 8b0628f..0000000
--- a/src/llama_backend.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#ifndef LLAMA_BACKEND_H
-#define LLAMA_BACKEND_H
-
-#include <godot_cpp/classes/ref_counted.hpp>
-
-namespace godot {
-class LlamaBackend : public RefCounted {
- GDCLASS(LlamaBackend, RefCounted)
-
-protected:
- static void _bind_methods();
-
-public:
- void init();
- void deinit();
-};
-} //namespace godot
-
-#endif
\ No newline at end of file
diff --git a/src/llama_context.cpp b/src/llama_context.cpp
index 78ad255..0fcfc2f 100644
--- a/src/llama_context.cpp
+++ b/src/llama_context.cpp
@@ -2,10 +2,12 @@
#include "common.h"
#include "llama.h"
#include "llama_model.h"
+#include <algorithm>
#include
#include
#include
#include
+#include
#include
using namespace godot;
@@ -15,31 +17,41 @@ void LlamaContext::_bind_methods() {
ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model);
ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model");
- ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
- ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
- ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
+ ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed);
+ ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed);
+ ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed");
- ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
- ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
- ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+ ClassDB::bind_method(D_METHOD("get_temperature"), &LlamaContext::get_temperature);
+ ClassDB::bind_method(D_METHOD("set_temperature", "temperature"), &LlamaContext::set_temperature);
+ ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "temperature"), "set_temperature", "get_temperature");
- ClassDB::bind_method(D_METHOD("get_n_threads"), &LlamaContext::get_n_threads);
- ClassDB::bind_method(D_METHOD("set_n_threads", "n_threads"), &LlamaContext::set_n_threads);
- ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads"), "set_n_threads", "get_n_threads");
+ ClassDB::bind_method(D_METHOD("get_top_p"), &LlamaContext::get_top_p);
+ ClassDB::bind_method(D_METHOD("set_top_p", "top_p"), &LlamaContext::set_top_p);
+ ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "top_p"), "set_top_p", "get_top_p");
- ClassDB::bind_method(D_METHOD("get_n_threads_batch"), &LlamaContext::get_n_threads_batch);
- ClassDB::bind_method(D_METHOD("set_n_threads_batch", "n_threads_batch"), &LlamaContext::set_n_threads_batch);
- ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads_batch"), "set_n_threads_batch", "get_n_threads_batch");
+ ClassDB::bind_method(D_METHOD("get_frequency_penalty"), &LlamaContext::get_frequency_penalty);
+ ClassDB::bind_method(D_METHOD("set_frequency_penalty", "frequency_penalty"), &LlamaContext::set_frequency_penalty);
+ ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "frequency_penalty"), "set_frequency_penalty", "get_frequency_penalty");
+
+ ClassDB::bind_method(D_METHOD("get_presence_penalty"), &LlamaContext::get_presence_penalty);
+ ClassDB::bind_method(D_METHOD("set_presence_penalty", "presence_penalty"), &LlamaContext::set_presence_penalty);
+ ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "presence_penalty"), "set_presence_penalty", "get_presence_penalty");
+
+ ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx);
+ ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx);
+ ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx");
+
+ ClassDB::bind_method(D_METHOD("get_n_len"), &LlamaContext::get_n_len);
+ ClassDB::bind_method(D_METHOD("set_n_len", "n_len"), &LlamaContext::set_n_len);
+ ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_len"), "set_n_len", "get_n_len");
ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion);
- ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion);
+ ClassDB::bind_method(D_METHOD("__thread_loop"), &LlamaContext::__thread_loop);
- ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final")));
+ ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::DICTIONARY, "chunk")));
}
LlamaContext::LlamaContext() {
- batch = llama_batch_init(4096, 0, 1);
-
ctx_params = llama_context_default_params();
ctx_params.seed = -1;
ctx_params.n_ctx = 4096;
@@ -60,109 +72,186 @@ void LlamaContext::_ready() {
return;
}
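+ // Completions are serviced on a dedicated worker thread; a semaphore wakes it whenever a request is queued.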
+ mutex.instantiate();
+ semaphore.instantiate();
+ thread.instantiate();
+
+ llama_backend_init();
+ llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISABLED);
+
ctx = llama_new_context_with_model(model->model, ctx_params);
if (ctx == NULL) {
UtilityFunctions::printerr(vformat("%s: Failed to initialize llama context, null ctx", __func__));
return;
}
+
+ sampling_ctx = llama_sampling_init(sampling_params);
+
UtilityFunctions::print(vformat("%s: Context initialized", __func__));
-}
-PackedStringArray LlamaContext::_get_configuration_warnings() const {
- PackedStringArray warnings;
- if (model == NULL) {
- warnings.push_back("Model resource property not defined");
- }
- return warnings;
+ thread->start(callable_mp(this, &LlamaContext::__thread_loop));
}
-Variant LlamaContext::request_completion(const String &prompt) {
- UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt));
- if (task_id) {
- WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
- }
- task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt));
- return OK;
-}
+void LlamaContext::__thread_loop() {
+ while (true) {
+ semaphore->wait();
-void LlamaContext::_fulfill_completion(const String &prompt) {
- UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt));
- std::vector<llama_token> tokens_list;
- tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true);
+ mutex->lock();
+ if (exit_thread) {
+ mutex->unlock();
+ break;
+ }
+ if (completion_requests.size() == 0) {
+ mutex->unlock();
+ continue;
+ }
+ completion_request req = completion_requests.get(0);
+ completion_requests.remove_at(0);
+ mutex->unlock();
- const int n_len = 128;
- const int n_ctx = llama_n_ctx(ctx);
- const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
- if (n_kv_req > n_ctx) {
- UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__));
- return;
- }
+ UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id));
- for (size_t i = 0; i < tokens_list.size(); i++) {
- llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
- }
+ std::vector<llama_token> request_tokens;
+ request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true, true);
- batch.logits[batch.n_tokens - 1] = true;
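+ // Prompt caching: find the longest token prefix shared with the current context, evict everything past it from the KV cache, and decode only the new suffix.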
+ size_t shared_prefix_idx = 0;
+ auto diff = std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end());
+ if (diff.first != context_tokens.end()) {
+ shared_prefix_idx = std::distance(context_tokens.begin(), diff.first);
+ } else {
+ shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size());
+ }
- llama_kv_cache_clear(ctx);
+ bool rm_success = llama_kv_cache_seq_rm(ctx, -1, shared_prefix_idx, -1);
+ if (!rm_success) {
+ UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__));
+ Dictionary response;
+ response["id"] = req.id;
+ response["error"] = "Failed to remove tokens from kv cache";
+ call_thread_safe("emit_signal", "completion_generated", response);
+ continue;
+ }
+ context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end());
+ request_tokens.erase(request_tokens.begin(), request_tokens.begin() + shared_prefix_idx);
- int decode_res = llama_decode(ctx, batch);
- if (decode_res != 0) {
- UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res));
- return;
- }
+ uint batch_size = std::min(ctx_params.n_batch, (uint)request_tokens.size());
+
+ llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+ // chunk request_tokens into sequences of size batch_size
+ std::vector<std::vector<llama_token>> sequences;
+ for (size_t i = 0; i < request_tokens.size(); i += batch_size) {
+ sequences.push_back(std::vector<llama_token>(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size())));
+ }
+
+ printf("Request tokens: \n");
+ for (auto sequence : sequences) {
+ for (auto token : sequence) {
+ printf("%s", llama_token_to_piece(ctx, token).c_str());
+ }
+ }
+ printf("\n");
- int n_cur = batch.n_tokens;
- int n_decode = 0;
- llama_model *llama_model = model->model;
+ int curr_token_pos = context_tokens.size();
+ bool decode_failed = false;
- while (n_cur <= n_len) {
- // sample the next token
- {
- auto n_vocab = llama_n_vocab(llama_model);
- auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+ for (size_t i = 0; i < sequences.size(); i++) {
+ llama_batch_clear(batch);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ std::vector<llama_token> sequence = sequences[i];
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+ for (size_t j = 0; j < sequence.size(); j++) {
+ llama_batch_add(batch, sequence[j], j + curr_token_pos, { 0 }, false);
+ curr_token_pos++;
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
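+ // Only the final token of the last chunk needs logits; that's where sampling begins.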
+ if (i == sequences.size() - 1) {
+ batch.logits[batch.n_tokens - 1] = true;
+ }
- // sample the most likely token
- const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+ if (llama_decode(ctx, batch) != 0) {
+ decode_failed = true;
+ break;
+ }
+ }
- // is it an end of stream?
- if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) {
- call_thread_safe("emit_signal", "completion_generated", "\n", true);
+ if (decode_failed) {
+ Dictionary response;
+ response["id"] = req.id;
+ response["error"] = "llama_decode() failed";
+ call_thread_safe("emit_signal", "completion_generated", response);
+ continue;
+ }
+
+ context_tokens.insert(context_tokens.end(), request_tokens.begin(), request_tokens.end());
+
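+ // Generation loop: sample a token, stream it out via the signal, then decode it to extend the context.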
+ while (true) {
+ if (exit_thread) {
+ return;
+ }
+ llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1);
+ llama_sampling_accept(sampling_ctx, ctx, new_token_id, false);
+ Dictionary response;
+ response["id"] = req.id;
+
+ context_tokens.push_back(new_token_id);
+
+ if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) {
+ response["done"] = true;
+ call_thread_safe("emit_signal", "completion_generated", response);
break;
}
- call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false);
+ response["text"] = llama_token_to_piece(ctx, new_token_id).c_str();
+ response["done"] = false;
+ call_thread_safe("emit_signal", "completion_generated", response);
- // prepare the next batch
llama_batch_clear(batch);
- // push this new token for next evaluation
- llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
+ llama_batch_add(batch, new_token_id, curr_token_pos, { 0 }, true);
- n_decode += 1;
- }
+ curr_token_pos++;
- n_cur += 1;
+ if (llama_decode(ctx, batch) != 0) {
+ decode_failed = true;
+ break;
+ }
+ }
- // evaluate the current batch with the transformer model
- int decode_res = llama_decode(ctx, batch);
- if (decode_res != 0) {
- UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res));
- break;
+ if (decode_failed) {
+ Dictionary response;
+ response["id"] = req.id;
+ response["error"] = "llama_decode() failed";
+ call_thread_safe("emit_signal", "completion_generated", response);
+ continue;
}
}
}
+PackedStringArray LlamaContext::_get_configuration_warnings() const {
+ PackedStringArray warnings;
+ if (model == NULL) {
+ warnings.push_back("Model resource property not defined");
+ }
+ return warnings;
+}
+
+int LlamaContext::request_completion(const String &prompt) {
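+ // Thread-safe enqueue; returns an id immediately, and result chunks arrive through the completion_generated signal.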
+ int id = request_id++;
+
+ UtilityFunctions::print(vformat("%s: Requesting completion for prompt id: %d", __func__, id));
+
+ mutex->lock();
+ completion_request req = { id, prompt };
+ completion_requests.append(req);
+ mutex->unlock();
+
+ semaphore->post();
+
+ return id;
+}
+
void LlamaContext::set_model(const Ref<LlamaModel> p_model) {
model = p_model;
}
@@ -184,28 +273,58 @@ void LlamaContext::set_n_ctx(int n_ctx) {
ctx_params.n_ctx = n_ctx;
}
-int LlamaContext::get_n_threads() {
- return ctx_params.n_threads;
+int LlamaContext::get_n_len() {
+ return n_len;
}
-void LlamaContext::set_n_threads(int n_threads) {
- ctx_params.n_threads = n_threads;
+void LlamaContext::set_n_len(int n_len) {
+ this->n_len = n_len;
}
-int LlamaContext::get_n_threads_batch() {
- return ctx_params.n_threads_batch;
+float LlamaContext::get_temperature() {
+ return sampling_params.temp;
}
-void LlamaContext::set_n_threads_batch(int n_threads_batch) {
- ctx_params.n_threads_batch = n_threads_batch;
+void LlamaContext::set_temperature(float temperature) {
+ sampling_params.temp = temperature;
}
-LlamaContext::~LlamaContext() {
- if (ctx) {
- llama_free(ctx);
+float LlamaContext::get_top_p() {
+ return sampling_params.top_p;
+}
+void LlamaContext::set_top_p(float top_p) {
+ sampling_params.top_p = top_p;
+}
+
+float LlamaContext::get_frequency_penalty() {
+ return sampling_params.penalty_freq;
+}
+void LlamaContext::set_frequency_penalty(float frequency_penalty) {
+ sampling_params.penalty_freq = frequency_penalty;
+}
+
+float LlamaContext::get_presence_penalty() {
+ return sampling_params.penalty_present;
+}
+void LlamaContext::set_presence_penalty(float presence_penalty) {
+ sampling_params.penalty_present = presence_penalty;
+}
+
+void LlamaContext::_exit_tree() {
+ if (Engine::get_singleton()->is_editor_hint()) {
+ return;
}
- llama_batch_free(batch);
+ mutex->lock();
+ exit_thread = true;
+ mutex->unlock();
+
+ semaphore->post();
- if (task_id) {
- WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
+ thread->wait_to_finish();
+
+ if (ctx) {
+ llama_free(ctx);
}
+
+ llama_sampling_free(sampling_ctx);
+ llama_backend_free();
}
\ No newline at end of file
diff --git a/src/llama_context.h b/src/llama_context.h
index 5db3789..07bfd3e 100644
--- a/src/llama_context.h
+++ b/src/llama_context.h
@@ -2,19 +2,38 @@
#define LLAMA_CONTEXT_H
#include "llama.h"
+#include "common.h"
#include "llama_model.h"
+#include <godot_cpp/classes/mutex.hpp>
#include
-
+#include <godot_cpp/classes/semaphore.hpp>
+#include <godot_cpp/classes/thread.hpp>
+#include <vector>
namespace godot {
+
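+// A queued prompt; its id is echoed in every completion_generated chunk so callers can match responses.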
+struct completion_request {
+ int id;
+ String prompt;
+};
+
class LlamaContext : public Node {
GDCLASS(LlamaContext, Node)
private:
Ref<LlamaModel> model;
llama_context *ctx = nullptr;
+ llama_sampling_context *sampling_ctx = nullptr;
llama_context_params ctx_params;
- llama_batch batch;
- int task_id;
+ llama_sampling_params sampling_params;
+ int n_len = 1024;
+ int request_id = 0;
+ Vector<completion_request> completion_requests;
+
+ Ref<Thread> thread;
+ Ref<Semaphore> semaphore;
+ Ref<Mutex> mutex;
+ std::vector<llama_token> context_tokens;
+ bool exit_thread = false;
protected:
static void _bind_methods();
@@ -23,22 +42,28 @@ class LlamaContext : public Node {
void set_model(const Ref<LlamaModel> model);
Ref<LlamaModel> get_model();
- Variant request_completion(const String &prompt);
- void _fulfill_completion(const String &prompt);
+ int request_completion(const String &prompt);
+ void __thread_loop();
- int get_seed();
- void set_seed(int seed);
- int get_n_ctx();
- void set_n_ctx(int n_ctx);
- int get_n_threads();
- void set_n_threads(int n_threads);
- int get_n_threads_batch();
- void set_n_threads_batch(int n_threads_batch);
+ int get_seed();
+ void set_seed(int seed);
+ int get_n_ctx();
+ void set_n_ctx(int n_ctx);
+ int get_n_len();
+ void set_n_len(int n_len);
+ float get_temperature();
+ void set_temperature(float temperature);
+ float get_top_p();
+ void set_top_p(float top_p);
+ float get_frequency_penalty();
+ void set_frequency_penalty(float frequency_penalty);
+ float get_presence_penalty();
+ void set_presence_penalty(float presence_penalty);
- virtual PackedStringArray _get_configuration_warnings() const override;
+ virtual PackedStringArray _get_configuration_warnings() const override;
virtual void _ready() override;
- LlamaContext();
- ~LlamaContext();
+ virtual void _exit_tree() override;
+ LlamaContext();
};
} //namespace godot
diff --git a/src/llama_model.cpp b/src/llama_model.cpp
index 92cb9db..96eea4f 100644
--- a/src/llama_model.cpp
+++ b/src/llama_model.cpp
@@ -1,5 +1,6 @@
#include "llama_model.h"
#include "llama.h"
+#include <godot_cpp/classes/project_settings.hpp>
#include
#include
@@ -22,14 +23,16 @@ void LlamaModel::load_model(const String &path) {
llama_free_model(model);
}
- model = llama_load_model_from_file(path.utf8().get_data(), model_params);
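+ // llama.cpp opens the file directly, so Godot's res:// path must be globalized to an absolute filesystem path first.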
+ String absPath = ProjectSettings::get_singleton()->globalize_path(path);
+
+ model = llama_load_model_from_file(absPath.utf8().get_data(), model_params);
if (model == NULL) {
- UtilityFunctions::printerr(vformat("%s: Unable to load model from %s", __func__, path));
+ UtilityFunctions::printerr(vformat("%s: Unable to load model from %s", __func__, absPath));
return;
}
- UtilityFunctions::print(vformat("%s: Model loaded from %s", __func__, path));
+ UtilityFunctions::print(vformat("%s: Model loaded from %s", __func__, absPath));
}
int LlamaModel::get_n_gpu_layers() {
diff --git a/src/llama_model_loader.cpp b/src/llama_model_loader.cpp
index 77b8a94..940f1c3 100644
--- a/src/llama_model_loader.cpp
+++ b/src/llama_model_loader.cpp
@@ -2,7 +2,6 @@
#include "llama_model.h"
#include
#include
-#include <godot_cpp/classes/project_settings.hpp>
#include
using namespace godot;
@@ -24,9 +23,7 @@ Variant godot::LlamaModelLoader::_load(const String &path, const String &origina
return { model };
}
- String absPath = ProjectSettings::get_singleton()->globalize_path(path);
-
- model->load_model(absPath);
+ model->load_model(path);
return { model };
}
diff --git a/src/register_types.cpp b/src/register_types.cpp
index 0cf4cf0..17ec90b 100644
--- a/src/register_types.cpp
+++ b/src/register_types.cpp
@@ -7,7 +7,6 @@
#include "llama_model.h"
#include "llama_model_loader.h"
#include "llama_context.h"
-#include "llama_backend.h"
using namespace godot;
@@ -24,7 +23,6 @@ void initialize_types(ModuleInitializationLevel p_level)
ClassDB::register_class();
ClassDB::register_class();
- ClassDB::register_class<LlamaBackend>();
}
void uninitialize_types(ModuleInitializationLevel p_level) {
diff --git a/tools/expand_metal.zig b/tools/expand_metal.zig
new file mode 100644
index 0000000..1e4e9f4
--- /dev/null
+++ b/tools/expand_metal.zig
@@ -0,0 +1,79 @@
+const std = @import("std");
+
+const usage =
+ \\Usage: ./expand_metal [options]
+ \\
+ \\Options:
+ \\ --metal-file ggml-metal.metal
+ \\ --common-file ggml-common.h
+ \\ --output-file ggml-metal-embed.metal
+ \\
+;
+
+pub fn main() !void {
+ var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
+ defer arena_state.deinit();
+ const arena = arena_state.allocator();
+
+ const args = try std.process.argsAlloc(arena);
+
+ var opt_metal_file_path: ?[]const u8 = null;
+ var opt_common_file_path: ?[]const u8 = null;
+ var opt_output_file_path: ?[]const u8 = null;
+
+ {
+ var i: usize = 1;
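+ // args[0] is the program path; options start at index 1.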
+ while (i < args.len) : (i += 1) {
+ const arg = args[i];
+ if (std.mem.eql(u8, "-h", arg) or std.mem.eql(u8, "--help", arg)) {
+ try std.io.getStdOut().writeAll(usage);
+ return std.process.cleanExit();
+ } else if (std.mem.eql(u8, "--metal-file", arg)) {
+ i += 1;
+ if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+ if (opt_metal_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+ opt_metal_file_path = args[i];
+ } else if (std.mem.eql(u8, "--common-file", arg)) {
+ i += 1;
+ if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+ if (opt_common_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+ opt_common_file_path = args[i];
+ } else if (std.mem.eql(u8, "--output-file", arg)) {
+ i += 1;
+ if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+ if (opt_output_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+ opt_output_file_path = args[i];
+ } else {
+ std.debug.panic("unrecognized arg: '{s}'", .{arg});
+ }
+ }
+ }
+
+ const metal_file_path = opt_metal_file_path orelse std.debug.panic("missing --metal-file", .{});
+ const common_file_path = opt_common_file_path orelse std.debug.panic("missing --common-file", .{});
+ const output_file_path = opt_output_file_path orelse std.debug.panic("missing --output-file", .{});
+
+ const cwd = std.fs.cwd();
+
+ var metal_file = try cwd.openFile(metal_file_path, .{});
+ defer metal_file.close();
+
+ var common_file = try cwd.openFile(common_file_path, .{});
+ defer common_file.close();
+
+ const metal_size = (try metal_file.stat()).size;
+ const metal_contents = try arena.alloc(u8, metal_size);
+ defer arena.free(metal_contents);
+ _ = try metal_file.readAll(metal_contents);
+
+ const common_size = (try common_file.stat()).size;
+ const common_contents = try arena.alloc(u8, common_size);
+ defer arena.free(common_contents);
+ _ = try common_file.readAll(common_contents);
+
+ const output = try std.mem.replaceOwned(u8, arena, metal_contents, "#include \"ggml-common.h\"", common_contents);
+ defer arena.free(output);
+
+ const output_file = try cwd.createFile(output_file_path, .{});
+ try output_file.writeAll(output);
+}