diff --git a/build.zig b/build.zig index 711165b..aadce33 100644 --- a/build.zig +++ b/build.zig @@ -7,11 +7,41 @@ const extension_name = "godot-llama-cpp"; pub fn build(b: *std.Build) !void { const target = b.standardTargetOptions(.{}); const optimize = b.standardOptimizeOption(.{}); - const zig_triple = try target.result.linuxTriple(b.allocator); + const triple = try target.result.linuxTriple(b.allocator); - var objs = std.ArrayList(*std.Build.Step.Compile).init(b.allocator); + const lib_godot_cpp = try build_lib_godot_cpp(.{ .b = b, .target = target, .optimize = optimize }); + const lib_llama_cpp = try build_lib_llama_cpp(.{ .b = b, .target = target, .optimize = optimize }); + + const plugin = b.addSharedLibrary(.{ + .name = b.fmt("{s}-{s}-{s}", .{ extension_name, triple, @tagName(optimize) }), + .target = target, + .optimize = optimize, + }); + plugin.addCSourceFiles(.{ .files = try findFilesRecursive(b, "src/", &cfiles_exts) }); + plugin.addIncludePath(.{ .path = "src/" }); + plugin.addIncludePath(.{ .path = "godot_cpp/gdextension/" }); + plugin.addIncludePath(.{ .path = "godot_cpp/include/" }); + plugin.addIncludePath(.{ .path = "godot_cpp/gen/include" }); + plugin.addIncludePath(.{ .path = "llama.cpp" }); + plugin.addIncludePath(.{ .path = "llama.cpp/common" }); + plugin.linkLibrary(lib_llama_cpp); + plugin.linkLibrary(lib_godot_cpp); + + b.lib_dir = "./godot/addons/godot-llama-cpp/lib"; + b.installArtifact(plugin); +} + +const BuildParams = struct { + b: *std.Build, + target: std.Build.ResolvedTarget, + optimize: std.builtin.OptimizeMode, +}; + +fn build_lib_godot_cpp(params: BuildParams) !*std.Build.Step.Compile { + const b = params.b; + const target = params.target; + const optimize = params.optimize; - // godot-cpp const lib_godot = b.addStaticLibrary(.{ .name = "godot-cpp", .target = target, @@ -26,9 +56,7 @@ pub fn build(b: *std.Build) !void { .cwd_dir = b.build_root.handle, }); }, - else => { - return; - }, + else => {}, } }; lib_godot.linkLibCpp(); @@ -40,7 +68,15 @@ pub fn build(b: *std.Build) !void { lib_godot.addCSourceFiles(.{ .files = lib_godot_gen_sources, .flags = &.{ "-std=c++17", "-fno-exceptions" } }); lib_godot.addCSourceFiles(.{ .files = lib_godot_sources, .flags = &.{ "-std=c++17", "-fno-exceptions" } }); - // llama.cpp + return lib_godot; +} + +fn build_lib_llama_cpp(params: BuildParams) !*std.Build.Step.Compile { + const b = params.b; + const target = params.target; + const optimize = params.optimize; + const zig_triple = try target.result.zigTriple(b.allocator); + const commit_hash = try std.ChildProcess.run(.{ .allocator = b.allocator, .argv = &.{ "git", "rev-parse", "HEAD" }, .cwd = b.pathFromRoot("llama.cpp") }); const zig_version = builtin.zig_version_string; try b.build_root.handle.writeFile2(.{ .sub_path = "llama.cpp/common/build-info.cpp", .data = b.fmt( @@ -48,259 +84,102 @@ pub fn build(b: *std.Build) !void { \\char const *LLAMA_COMMIT = "{s}"; \\char const *LLAMA_COMPILER = "Zig {s}"; \\char const *LLAMA_BUILD_TARGET = "{s}"; - \\ , .{ 0, commit_hash.stdout[0 .. 
commit_hash.stdout.len - 1], zig_version, zig_triple }) }); - var flags = std.ArrayList([]const u8).init(b.allocator); - if (target.result.abi != .msvc) try flags.append("-D_GNU_SOURCE"); - if (target.result.os.tag == .macos) try flags.appendSlice(&.{ - "-D_DARWIN_C_SOURCE", - "-DGGML_USE_METAL", - "-DGGML_USE_ACCELERATE", - "-DACCELERATE_USE_LAPACK", - "-DACCELERATE_LAPACK_ILP64", - }) else try flags.append("-DGGML_USE_VULKAN"); - try flags.append("-D_XOPEN_SOURCE=600"); + const lib_llama_cpp = b.addStaticLibrary(.{ .name = "llama.cpp", .target = target, .optimize = optimize }); - var cflags = std.ArrayList([]const u8).init(b.allocator); - try cflags.append("-std=c11"); - try cflags.appendSlice(flags.items); - var cxxflags = std.ArrayList([]const u8).init(b.allocator); - try cxxflags.append("-std=c++11"); - try cxxflags.appendSlice(flags.items); - - const include_paths = [_][]const u8{ "llama.cpp", "llama.cpp/common" }; - const llama = buildObj(.{ - .b = b, - .name = "llama", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/llama.cpp"}, - .include_paths = &include_paths, - .link_lib_cpp = true, - .link_lib_c = false, - .flags = cxxflags.items, - }); - const ggml = buildObj(.{ - .b = b, - .name = "ggml", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/ggml.c"}, - .include_paths = &include_paths, - .link_lib_c = true, - .link_lib_cpp = false, - .flags = cflags.items, - }); - const common = buildObj(.{ - .b = b, - .name = "common", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/common/common.cpp"}, - .include_paths = &include_paths, - .link_lib_cpp = true, - .link_lib_c = false, - .flags = cxxflags.items, - }); - const console = buildObj(.{ - .b = b, - .name = "console", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/common/console.cpp"}, - .include_paths = &include_paths, - .link_lib_cpp = true, - .link_lib_c = false, - .flags = cxxflags.items, - }); - const sampling = buildObj(.{ - .b = b, - .name = "sampling", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/common/sampling.cpp"}, - .include_paths = &include_paths, - .link_lib_cpp = true, - .link_lib_c = false, - .flags = cxxflags.items, - }); - const grammar_parser = buildObj(.{ - .b = b, - .name = "grammar_parser", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/common/grammar-parser.cpp"}, - .include_paths = &include_paths, - .link_lib_cpp = true, - .link_lib_c = false, - .flags = cxxflags.items, - }); - const build_info = buildObj(.{ - .b = b, - .name = "build_info", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/common/build-info.cpp"}, - .include_paths = &include_paths, - .link_lib_cpp = true, - .link_lib_c = false, - .flags = cxxflags.items, - }); - const ggml_alloc = buildObj(.{ - .b = b, - .name = "ggml_alloc", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/ggml-alloc.c"}, - .include_paths = &include_paths, - .link_lib_c = true, - .link_lib_cpp = false, - .flags = cflags.items, - }); - const ggml_backend = buildObj(.{ - .b = b, - .name = "ggml_backend", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/ggml-backend.c"}, - .include_paths = &include_paths, - .link_lib_c = true, - .link_lib_cpp = false, - .flags = cflags.items, - }); - const ggml_quants = buildObj(.{ - .b = b, - .name = "ggml_quants", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/ggml-quants.c"}, - .include_paths = &include_paths, - 
.link_lib_c = true, - .link_lib_cpp = false, - .flags = cflags.items, - }); - const unicode = buildObj(.{ + var objs = std.ArrayList(*std.Build.Step.Compile).init(b.allocator); + var objBuilder = ObjBuilder.init(.{ .b = b, - .name = "unicode", .target = target, .optimize = optimize, - .sources = &.{"llama.cpp/unicode.cpp"}, - .include_paths = &include_paths, - .link_lib_c = false, - .link_lib_cpp = true, - .flags = cxxflags.items, + .include_paths = &.{ "llama.cpp", "llama.cpp/common" }, }); - try objs.appendSlice(&.{ llama, ggml, common, console, sampling, grammar_parser, build_info, ggml_alloc, ggml_backend, ggml_quants, unicode }); - - if (target.result.os.tag == .macos) { - const ggml_metal = buildObj(.{ - .b = b, - .name = "ggml_metal", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/ggml-metal.m"}, - .include_paths = &include_paths, - .link_lib_c = true, - .link_lib_cpp = false, - .flags = cflags.items, - }); - const airCommand = b.addSystemCommand(&.{ "xcrun", "-sdk", "macosx", "metal", "-O3", "-c" }); - airCommand.addFileArg(.{ .path = "llama.cpp/ggml-metal.metal" }); - airCommand.addArg("-o"); - const air = airCommand.addOutputFileArg("ggml-metal.air"); - const libCommand = b.addSystemCommand(&.{ "xcrun", "-sdk", "macosx", "metallib" }); - libCommand.addFileArg(air); - libCommand.addArg("-o"); - const lib = libCommand.addOutputFileArg("default.metallib"); - const libInstall = b.addInstallLibFile(lib, "default.metallib"); - b.getInstallStep().dependOn(&libInstall.step); - try objs.append(ggml_metal); - } else { - const ggml_vulkan = buildObj(.{ - .b = b, - .name = "ggml_vulkan", - .target = target, - .optimize = optimize, - .sources = &.{"llama.cpp/ggml-vulkan.cpp"}, - .include_paths = &include_paths, - .link_lib_cpp = true, - .link_lib_c = false, - .flags = cxxflags.items, - }); - try objs.append(ggml_vulkan); + switch (target.result.os.tag) { + .macos => { + try objBuilder.flags.append("-DGGML_USE_METAL"); + try objs.append(objBuilder.build(.{ .name = "ggml_metal", .sources = &.{"llama.cpp/ggml-metal.m"} })); + + lib_llama_cpp.linkFramework("Foundation"); + lib_llama_cpp.linkFramework("Metal"); + lib_llama_cpp.linkFramework("MetalKit"); + + const expand_metal = b.addExecutable(.{ + .name = "expand_metal", + .target = target, + .root_source_file = .{ .path = "tools/expand_metal.zig" }, + }); + var run_expand_metal = b.addRunArtifact(expand_metal); + run_expand_metal.addArg("--metal-file"); + run_expand_metal.addFileArg(.{ .path = "llama.cpp/ggml-metal.metal" }); + run_expand_metal.addArg("--common-file"); + run_expand_metal.addFileArg(.{ .path = "llama.cpp/ggml-common.h" }); + run_expand_metal.addArg("--output-file"); + const metal_expanded = run_expand_metal.addOutputFileArg("ggml-metal.metal"); + const install_metal = b.addInstallFileWithDir(metal_expanded, .lib, "ggml-metal.metal"); + lib_llama_cpp.step.dependOn(&install_metal.step); + }, + else => {}, } - const extension = b.addSharedLibrary(.{ .name = b.fmt("{s}-{s}-{s}", .{ extension_name, zig_triple, @tagName(optimize) }), .target = target, .optimize = optimize }); - const sources = try findFilesRecursive(b, "src", &cfiles_exts); - extension.addCSourceFiles(.{ .files = sources, .flags = &.{ "-std=c++17", "-fno-exceptions" } }); - extension.addIncludePath(.{ .path = "src" }); - extension.addIncludePath(.{ .path = "godot_cpp/include/" }); - extension.addIncludePath(.{ .path = "godot_cpp/gdextension/" }); - extension.addIncludePath(.{ .path = "godot_cpp/gen/include/" }); - extension.addIncludePath(.{ 
.path = "llama.cpp" }); - extension.addIncludePath(.{ .path = "llama.cpp/common" }); + try objs.appendSlice(&.{ + objBuilder.build(.{ .name = "ggml", .sources = &.{"llama.cpp/ggml.c"} }), + objBuilder.build(.{ .name = "sgemm", .sources = &.{"llama.cpp/sgemm.cpp"} }), + objBuilder.build(.{ .name = "ggml_alloc", .sources = &.{"llama.cpp/ggml-alloc.c"} }), + objBuilder.build(.{ .name = "ggml_backend", .sources = &.{"llama.cpp/ggml-backend.c"} }), + objBuilder.build(.{ .name = "ggml_quants", .sources = &.{"llama.cpp/ggml-quants.c"} }), + objBuilder.build(.{ .name = "llama", .sources = &.{"llama.cpp/llama.cpp"} }), + objBuilder.build(.{ .name = "unicode", .sources = &.{"llama.cpp/unicode.cpp"} }), + objBuilder.build(.{ .name = "unicode_data", .sources = &.{"llama.cpp/unicode-data.cpp"} }), + objBuilder.build(.{ .name = "common", .sources = &.{"llama.cpp/common/common.cpp"} }), + objBuilder.build(.{ .name = "console", .sources = &.{"llama.cpp/common/console.cpp"} }), + objBuilder.build(.{ .name = "sampling", .sources = &.{"llama.cpp/common/sampling.cpp"} }), + objBuilder.build(.{ .name = "grammar_parser", .sources = &.{"llama.cpp/common/grammar-parser.cpp"} }), + objBuilder.build(.{ .name = "json_schema_to_grammar", .sources = &.{"llama.cpp/common/json-schema-to-grammar.cpp"} }), + objBuilder.build(.{ .name = "build_info", .sources = &.{"llama.cpp/common/build-info.cpp"} }), + }); + for (objs.items) |obj| { - extension.addObject(obj); - } - extension.linkLibC(); - extension.linkLibCpp(); - if (target.result.os.tag == .macos) { - extension.linkFramework("Metal"); - extension.linkFramework("MetalKit"); - extension.linkFramework("Foundation"); - extension.linkFramework("Accelerate"); - // b.installFile("llama.cpp/ggml-metal.metal", b.pathJoin(&.{ std.fs.path.basename(b.lib_dir), "ggml-metal.metal" })); - // b.installFile("llama.cpp/ggml-common.h", b.pathJoin(&.{ std.fs.path.basename(b.lib_dir), "ggml-common.h" })); - } else { - if (target.result.os.tag == .windows) { - const vk_path = b.graph.env_map.get("VK_SDK_PATH") orelse @panic("VK_SDK_PATH not set"); - extension.addLibraryPath(.{ .path = b.pathJoin(&.{ vk_path, "Lib" }) }); - extension.linkSystemLibrary("vulkan-1"); - } else { - extension.linkSystemLibrary("vulkan"); - } + lib_llama_cpp.addObject(obj); } - extension.linkLibrary(lib_godot); - b.installArtifact(extension); + return lib_llama_cpp; } -const BuildObjectParams = struct { +const ObjBuilder = struct { b: *std.Build, - name: []const u8, target: std.Build.ResolvedTarget, optimize: std.builtin.OptimizeMode, - sources: []const []const u8, include_paths: []const []const u8, - link_lib_c: bool, - link_lib_cpp: bool, - flags: []const []const u8, -}; - -fn buildObj(params: BuildObjectParams) *std.Build.Step.Compile { - const obj = params.b.addObject(.{ - .name = params.name, - .target = params.target, - .optimize = params.optimize, - }); - if (params.target.result.os.tag == .windows) { - const vk_path = params.b.graph.env_map.get("VK_SDK_PATH") orelse @panic("VK_SDK_PATH not set"); - obj.addIncludePath(.{ .path = params.b.pathJoin(&.{ vk_path, "include" }) }); + flags: std.ArrayList([]const u8), + + fn init(params: struct { + b: *std.Build, + target: std.Build.ResolvedTarget, + optimize: std.builtin.OptimizeMode, + include_paths: []const []const u8, + }) ObjBuilder { + return ObjBuilder{ + .b = params.b, + .target = params.target, + .optimize = params.optimize, + .include_paths = params.include_paths, + .flags = std.ArrayList([]const u8).init(params.b.allocator), + }; } - for 
(params.include_paths) |path| {
-		obj.addIncludePath(.{ .path = path });
-	}
-	if (params.link_lib_c) {
+
+	fn build(self: *ObjBuilder, params: struct { name: []const u8, sources: []const []const u8 }) *std.Build.Step.Compile {
+		const obj = self.b.addObject(.{ .name = params.name, .target = self.target, .optimize = self.optimize });
+		obj.addCSourceFiles(.{ .files = params.sources, .flags = self.flags.items });
+		for (self.include_paths) |path| {
+			obj.addIncludePath(.{ .path = path });
+		}
 		obj.linkLibC();
-	}
-	if (params.link_lib_cpp) {
 		obj.linkLibCpp();
+		return obj;
 	}
-	obj.addCSourceFiles(.{ .files = params.sources, .flags = params.flags });
-	return obj;
-}
+};
 
 fn findFilesRecursive(b: *std.Build, dir_name: []const u8, exts: []const []const u8) ![][]const u8 {
 	var sources = std.ArrayList([]const u8).init(b.allocator);
diff --git a/godot/.gitattributes b/godot/.gitattributes
new file mode 100644
index 0000000..8ad74f7
--- /dev/null
+++ b/godot/.gitattributes
@@ -0,0 +1,2 @@
+# Normalize EOL for all files that Git considers text files.
+* text=auto eol=lf
diff --git a/godot/.gitignore b/godot/.gitignore
new file mode 100644
index 0000000..4709183
--- /dev/null
+++ b/godot/.gitignore
@@ -0,0 +1,2 @@
+# Godot 4+ specific ignores
+.godot/
diff --git a/godot/addons/godot-llama-cpp/autoloads/llama-backend.gd b/godot/addons/godot-llama-cpp/autoloads/llama-backend.gd
deleted file mode 100644
index 83bb52e..0000000
--- a/godot/addons/godot-llama-cpp/autoloads/llama-backend.gd
+++ /dev/null
@@ -1,10 +0,0 @@
-# This script will be autoloaded by the editor plugin
-extends Node
-
-var backend: LlamaBackend = LlamaBackend.new()
-
-func _enter_tree() -> void:
-	backend.init()
-
-func _exit_tree() -> void:
-	backend.deinit()
diff --git a/godot/addons/godot-llama-cpp/chat/chat_formatter.gd b/godot/addons/godot-llama-cpp/chat/chat_formatter.gd
new file mode 100644
index 0000000..84d8369
--- /dev/null
+++ b/godot/addons/godot-llama-cpp/chat/chat_formatter.gd
@@ -0,0 +1,56 @@
+class_name ChatFormatter
+
+static func apply(format: String, messages: Array) -> String:
+	match format:
+		"llama3":
+			return format_llama3(messages)
+		"phi3":
+			return format_phi3(messages)
+		"mistral":
+			return format_mistral(messages)
+		_:
+			printerr("Unknown chat format: ", format)
+			return ""
+
+static func format_llama3(messages: Array) -> String:
+	var res = ""
+
+	for i in range(messages.size()):
+		match messages[i]:
+			{"text": var text, "sender": var sender}:
+				res += """<|start_header_id|>%s<|end_header_id|>
+
+%s<|eot_id|>
+""" % [sender, text]
+			_:
+				printerr("Invalid message at index ", i)
+
+	res += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+	return res
+
+static func format_phi3(messages: Array) -> String:
+	var res = ""
+
+	for i in range(messages.size()):
+		match messages[i]:
+			{"text": var text, "sender": var sender}:
+				res += "<|%s|>\n%s<|end|>\n" % [sender, text]
+			_:
+				printerr("Invalid message at index ", i)
+	res += "<|assistant|>\n"
+	return res
+
+static func format_mistral(messages: Array) -> String:
+	var res = ""
+
+	for i in range(messages.size()):
+		match messages[i]:
+			{"text": var text, "sender": var sender}:
+				if sender == "user":
+					res += "[INST] %s [/INST]" % text
+				else:
+					res += text
+			_:
+				printerr("Invalid message at index ", i)
+
+	return res
diff --git a/godot/addons/godot-llama-cpp/plugin.cfg b/godot/addons/godot-llama-cpp/plugin.cfg
index 36e0b68..6335055 100644
--- a/godot/addons/godot-llama-cpp/plugin.cfg
+++ b/godot/addons/godot-llama-cpp/plugin.cfg
@@ -4,4 +4,4 @@ 
name="godot-llama-cpp" description="Run large language models in Godot. Powered by llama.cpp." author="hazelnutcloud" version="0.0.1" -script="godot-llama-cpp.gd" +script="plugin.gd" diff --git a/godot/addons/godot-llama-cpp/godot-llama-cpp.gd b/godot/addons/godot-llama-cpp/plugin.gd similarity index 50% rename from godot/addons/godot-llama-cpp/godot-llama-cpp.gd rename to godot/addons/godot-llama-cpp/plugin.gd index ca63dbc..5cee76e 100644 --- a/godot/addons/godot-llama-cpp/godot-llama-cpp.gd +++ b/godot/addons/godot-llama-cpp/plugin.gd @@ -4,9 +4,9 @@ extends EditorPlugin func _enter_tree(): # Initialization of the plugin goes here. - add_autoload_singleton("__LlamaBackend", "res://addons/godot-llama-cpp/autoloads/llama-backend.gd") + pass func _exit_tree(): # Clean-up of the plugin goes here. - remove_autoload_singleton("__LlamaBackend") + pass diff --git a/godot/addons/godot-llama-cpp/godot-llama-cpp.gdextension b/godot/addons/godot-llama-cpp/plugin.gdextension similarity index 95% rename from godot/addons/godot-llama-cpp/godot-llama-cpp.gdextension rename to godot/addons/godot-llama-cpp/plugin.gdextension index 4171072..9f2561c 100644 --- a/godot/addons/godot-llama-cpp/godot-llama-cpp.gdextension +++ b/godot/addons/godot-llama-cpp/plugin.gdextension @@ -5,8 +5,8 @@ compatibility_minimum = "4.2" [libraries] -macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-Debug.dylib" -macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseFast.dylib" +macos.debug = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib" +macos.release = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp-aarch64-macos-none-ReleaseSafe.dylib" windows.debug.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_debug.x86_32.dll" windows.release.x86_32 = "res://addons/godot-llama-cpp/lib/libgodot-llama-cpp.windows.template_release.x86_32.dll" windows.debug.x86_64 = "res://addons/godot-llama-cpp/lib/godot-llama-cpp-x86_64-windows-gnu-Debug.dll" diff --git a/godot/autoloads/llama.tscn b/godot/autoloads/llama.tscn deleted file mode 100644 index 9115818..0000000 --- a/godot/autoloads/llama.tscn +++ /dev/null @@ -1,6 +0,0 @@ -[gd_scene load_steps=2 format=3 uid="uid://bxobxniygk7jm"] - -[ext_resource type="LlamaModel" path="res://models/OGNO-7B-Q4_K_M.gguf" id="1_vd8h8"] - -[node name="LlamaContext" type="LlamaContext"] -model = ExtResource("1_vd8h8") diff --git a/godot/examples/simple/TextEdit.gd b/godot/examples/simple/TextEdit.gd new file mode 100644 index 0000000..b8d4bc0 --- /dev/null +++ b/godot/examples/simple/TextEdit.gd @@ -0,0 +1,20 @@ +extends TextEdit + +signal submit(input: String) + +func _gui_input(event: InputEvent) -> void: + if event is InputEventKey: + var keycode = event.get_keycode_with_modifiers() + if keycode == KEY_ENTER and event.is_pressed(): + handle_submit() + accept_event() + if keycode == KEY_ENTER | KEY_MASK_SHIFT and event.is_pressed(): + insert_text_at_caret("\n") + accept_event() + +func _on_button_pressed() -> void: + handle_submit() + +func handle_submit() -> void: + submit.emit(text) + text = "" diff --git a/godot/examples/simple/form.gd b/godot/examples/simple/form.gd new file mode 100644 index 0000000..092aa73 --- /dev/null +++ b/godot/examples/simple/form.gd @@ -0,0 +1,6 @@ +extends HBoxContainer + +@onready var text_edit = %TextEdit + +func _on_button_pressed() -> void: + text_edit.handle_submit() diff --git a/godot/examples/simple/message.gd 
b/godot/examples/simple/message.gd
new file mode 100644
index 0000000..63b4f31
--- /dev/null
+++ b/godot/examples/simple/message.gd
@@ -0,0 +1,23 @@
+class_name Message
+extends HBoxContainer
+
+@onready var text_container = %Text
+@onready var icon = %Panel
+@export_enum("user", "assistant") var sender: String
+@export var include_in_prompt: bool = true
+var text:
+	get:
+		return text_container.text
+	set(value):
+		text_container.text = value
+
+var completion_id: int = -1
+var pending: bool = false
+var errored: bool = false
+
+func set_text(new_text: String):
+	text_container.text = new_text
+
+func append_text(new_text: String):
+	text_container.text += new_text
+
diff --git a/godot/examples/simple/message.tscn b/godot/examples/simple/message.tscn
new file mode 100644
index 0000000..36b68e7
--- /dev/null
+++ b/godot/examples/simple/message.tscn
@@ -0,0 +1,37 @@
+[gd_scene load_steps=5 format=3 uid="uid://t862t0v8ht2q"]
+
+[ext_resource type="Script" path="res://examples/simple/message.gd" id="1_pko33"]
+[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="2_dvc7y"]
+
+[sub_resource type="StyleBoxTexture" id="StyleBoxTexture_t8bgj"]
+texture = ExtResource("2_dvc7y")
+
+[sub_resource type="Theme" id="Theme_bw3pb"]
+Panel/styles/panel = SubResource("StyleBoxTexture_t8bgj")
+
+[node name="RichTextLabel" type="HBoxContainer"]
+anchors_preset = 15
+anchor_right = 1.0
+anchor_bottom = 1.0
+grow_horizontal = 2
+grow_vertical = 2
+size_flags_horizontal = 3
+theme_override_constants/separation = 20
+script = ExtResource("1_pko33")
+sender = "assistant"
+
+[node name="Panel" type="Panel" parent="."]
+unique_name_in_owner = true
+custom_minimum_size = Vector2(80, 80)
+layout_mode = 2
+size_flags_vertical = 0
+theme = SubResource("Theme_bw3pb")
+
+[node name="Text" type="RichTextLabel" parent="."]
+unique_name_in_owner = true
+layout_mode = 2
+size_flags_horizontal = 3
+focus_mode = 2
+text = "..."
+fit_content = true
+selection_enabled = true
diff --git a/godot/examples/simple/simple.gd b/godot/examples/simple/simple.gd
new file mode 100644
index 0000000..4519f5d
--- /dev/null
+++ b/godot/examples/simple/simple.gd
@@ -0,0 +1,53 @@
+extends Node
+
+const message = preload("res://examples/simple/message.tscn")
+
+@onready var messages_container = %MessagesContainer
+@onready var llama_context = %LlamaContext
+
+func _on_text_edit_submit(input: String) -> void:
+	handle_input(input)
+
+func handle_input(input: String) -> void:
+	#var messages = [{ "sender": "system", "text": "You are a pirate chatbot who always responds in pirate speak!" }]
+
+	#var messages = [{ "sender": "system", "text": "You are a helpful chatbot assistant!" 
}] + var messages = [] + messages.append_array(messages_container.get_children().filter(func(msg: Message): return msg.include_in_prompt).map( + func(msg: Message) -> Dictionary: + return { "text": msg.text, "sender": msg.sender } + )) + messages.append({"text": input, "sender": "user"}) + var prompt = ChatFormatter.apply("llama3", messages) + print("prompt: ", prompt) + + var completion_id = llama_context.request_completion(prompt) + + var user_message: Message = message.instantiate() + messages_container.add_child(user_message) + user_message.set_text(input) + user_message.sender = "user" + user_message.completion_id = completion_id + + var ai_message: Message = message.instantiate() + messages_container.add_child(ai_message) + ai_message.sender = "assistant" + ai_message.completion_id = completion_id + ai_message.pending = true + ai_message.grab_focus() + + + +func _on_llama_context_completion_generated(chunk: Dictionary) -> void: + var completion_id = chunk.id + for msg: Message in messages_container.get_children(): + if msg.completion_id != completion_id or msg.sender != "assistant": + continue + if chunk.has("error"): + msg.errored = true + elif chunk.has("text"): + if msg.pending: + msg.pending = false + msg.set_text(chunk["text"]) + else: + msg.append_text(chunk["text"]) diff --git a/godot/examples/simple/simple.tscn b/godot/examples/simple/simple.tscn new file mode 100644 index 0000000..50b193d --- /dev/null +++ b/godot/examples/simple/simple.tscn @@ -0,0 +1,79 @@ +[gd_scene load_steps=6 format=3 uid="uid://c55kb4qvg6geq"] + +[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_gjsev"] +[ext_resource type="Script" path="res://examples/simple/simple.gd" id="1_sruc3"] +[ext_resource type="PackedScene" uid="uid://t862t0v8ht2q" path="res://examples/simple/message.tscn" id="2_7iip7"] +[ext_resource type="Script" path="res://examples/simple/TextEdit.gd" id="2_7usqw"] +[ext_resource type="LlamaModel" path="res://models/meta-llama-3-8b-instruct.Q5_K_M.gguf" id="5_qov1l"] + +[node name="Node" type="Node"] +script = ExtResource("1_sruc3") + +[node name="Panel" type="Panel" parent="."] +anchors_preset = 15 +anchor_right = 1.0 +anchor_bottom = 1.0 +grow_horizontal = 2 +grow_vertical = 2 + +[node name="MarginContainer" type="MarginContainer" parent="Panel"] +layout_mode = 1 +anchors_preset = 15 +anchor_right = 1.0 +anchor_bottom = 1.0 +grow_horizontal = 2 +grow_vertical = 2 +theme_override_constants/margin_left = 10 +theme_override_constants/margin_top = 10 +theme_override_constants/margin_right = 10 +theme_override_constants/margin_bottom = 10 + +[node name="VBoxContainer" type="VBoxContainer" parent="Panel/MarginContainer"] +layout_mode = 2 + +[node name="ScrollContainer" type="ScrollContainer" parent="Panel/MarginContainer/VBoxContainer"] +layout_mode = 2 +size_flags_vertical = 3 +follow_focus = true + +[node name="MessagesContainer" type="VBoxContainer" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer"] +unique_name_in_owner = true +layout_mode = 2 +size_flags_horizontal = 3 +size_flags_vertical = 3 +theme_override_constants/separation = 30 + +[node name="RichTextLabel2" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer" instance=ExtResource("2_7iip7")] +layout_mode = 2 +include_in_prompt = false + +[node name="Text" parent="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer/RichTextLabel2" index="1"] +text = "How can I help you?" 
+ +[node name="HBoxContainer" type="HBoxContainer" parent="Panel/MarginContainer/VBoxContainer"] +layout_mode = 2 + +[node name="TextEdit" type="TextEdit" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"] +custom_minimum_size = Vector2(2.08165e-12, 100) +layout_mode = 2 +size_flags_horizontal = 3 +placeholder_text = "Ask me anything..." +wrap_mode = 1 +script = ExtResource("2_7usqw") + +[node name="Button" type="Button" parent="Panel/MarginContainer/VBoxContainer/HBoxContainer"] +custom_minimum_size = Vector2(100, 2.08165e-12) +layout_mode = 2 +icon = ExtResource("1_gjsev") +expand_icon = true + +[node name="LlamaContext" type="LlamaContext" parent="."] +model = ExtResource("5_qov1l") +temperature = 0.9 +unique_name_in_owner = true + +[connection signal="submit" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" to="." method="_on_text_edit_submit"] +[connection signal="pressed" from="Panel/MarginContainer/VBoxContainer/HBoxContainer/Button" to="Panel/MarginContainer/VBoxContainer/HBoxContainer/TextEdit" method="_on_button_pressed"] +[connection signal="completion_generated" from="LlamaContext" to="." method="_on_llama_context_completion_generated"] + +[editable path="Panel/MarginContainer/VBoxContainer/ScrollContainer/MessagesContainer/RichTextLabel2"] diff --git a/godot/icon.svg b/godot/icon.svg new file mode 100644 index 0000000..3fe4f4a --- /dev/null +++ b/godot/icon.svg @@ -0,0 +1 @@ + diff --git a/godot/icon.svg.import b/godot/icon.svg.import new file mode 100644 index 0000000..5a5ae61 --- /dev/null +++ b/godot/icon.svg.import @@ -0,0 +1,37 @@ +[remap] + +importer="texture" +type="CompressedTexture2D" +uid="uid://beeg0oqle7bnk" +path="res://.godot/imported/icon.svg-218a8f2b3041327d8a5756f3a245f83b.ctex" +metadata={ +"vram_texture": false +} + +[deps] + +source_file="res://icon.svg" +dest_files=["res://.godot/imported/icon.svg-218a8f2b3041327d8a5756f3a245f83b.ctex"] + +[params] + +compress/mode=0 +compress/high_quality=false +compress/lossy_quality=0.7 +compress/hdr_compression=1 +compress/normal_map=0 +compress/channel_pack=0 +mipmaps/generate=false +mipmaps/limit=-1 +roughness/mode=0 +roughness/src_normal="" +process/fix_alpha_border=true +process/premult_alpha=false +process/normal_map_invert_y=false +process/hdr_as_srgb=false +process/hdr_clamp_exposure=false +process/size_limit=0 +detect_3d/compress_to=1 +svg/scale=1.0 +editor/scale_with_editor_scale=false +editor/convert_colors_with_editor_theme=false diff --git a/godot/main.gd b/godot/main.gd deleted file mode 100644 index 6c258e3..0000000 --- a/godot/main.gd +++ /dev/null @@ -1,27 +0,0 @@ -extends Node - -@onready var input: TextEdit = %Input -@onready var submit_button: Button = %SubmitButton -@onready var output: Label = %Output - -func _on_button_pressed(): - handle_submit() - -func handle_submit(): - print(input.text) - Llama.request_completion(input.text) - - input.clear() - input.editable = false - submit_button.disabled = true - output.text = "..." 
- - var completion = await Llama.completion_generated - output.text = "" - while !completion[1]: - print(completion[0]) - output.text += completion[0] - completion = await Llama.completion_generated - - input.editable = true - submit_button.disabled = false diff --git a/godot/main.tscn b/godot/main.tscn deleted file mode 100644 index 7bda86d..0000000 --- a/godot/main.tscn +++ /dev/null @@ -1,103 +0,0 @@ -[gd_scene load_steps=4 format=3 uid="uid://7oo8yj56scb1"] - -[ext_resource type="Texture2D" uid="uid://dplw232htshgc" path="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" id="1_ojdoj"] -[ext_resource type="Script" path="res://main.gd" id="1_vvrqe"] - -[sub_resource type="StyleBoxFlat" id="StyleBoxFlat_3e37a"] -corner_radius_top_left = 5 -corner_radius_top_right = 5 -corner_radius_bottom_right = 5 -corner_radius_bottom_left = 5 - -[node name="Main" type="Node"] -script = ExtResource("1_vvrqe") - -[node name="Background" type="ColorRect" parent="."] -anchors_preset = 15 -anchor_right = 1.0 -anchor_bottom = 1.0 -grow_horizontal = 2 -grow_vertical = 2 -color = Color(0.980392, 0.952941, 0.929412, 1) - -[node name="CenterContainer" type="CenterContainer" parent="."] -anchors_preset = 8 -anchor_left = 0.5 -anchor_top = 0.5 -anchor_right = 0.5 -anchor_bottom = 0.5 -offset_left = -400.0 -offset_top = -479.0 -offset_right = 400.0 -offset_bottom = 479.0 -grow_horizontal = 2 -grow_vertical = 2 - -[node name="VBoxContainer" type="VBoxContainer" parent="CenterContainer"] -custom_minimum_size = Vector2(500, 0) -layout_mode = 2 -theme_override_constants/separation = 10 -alignment = 1 - -[node name="Name" type="Label" parent="CenterContainer/VBoxContainer"] -layout_mode = 2 -theme_override_colors/font_color = Color(0.101961, 0.0823529, 0.0627451, 1) -theme_override_font_sizes/font_size = 32 -text = "godot-llama-cpp" -horizontal_alignment = 1 - -[node name="MarginContainer" type="MarginContainer" parent="CenterContainer/VBoxContainer"] -layout_mode = 2 -theme_override_constants/margin_left = 100 -theme_override_constants/margin_right = 100 - -[node name="TextureRect" type="TextureRect" parent="CenterContainer/VBoxContainer/MarginContainer"] -layout_mode = 2 -texture = ExtResource("1_ojdoj") -expand_mode = 4 - -[node name="ScrollContainer" type="ScrollContainer" parent="CenterContainer/VBoxContainer"] -custom_minimum_size = Vector2(2.08165e-12, 150) -layout_mode = 2 -horizontal_scroll_mode = 0 - -[node name="Panel" type="PanelContainer" parent="CenterContainer/VBoxContainer/ScrollContainer"] -custom_minimum_size = Vector2(2.08165e-12, 2.08165e-12) -layout_mode = 2 -size_flags_horizontal = 3 -size_flags_vertical = 3 -theme_override_styles/panel = SubResource("StyleBoxFlat_3e37a") - -[node name="MarginContainer" type="MarginContainer" parent="CenterContainer/VBoxContainer/ScrollContainer/Panel"] -layout_mode = 2 -theme_override_constants/margin_left = 20 -theme_override_constants/margin_right = 20 - -[node name="Output" type="Label" parent="CenterContainer/VBoxContainer/ScrollContainer/Panel/MarginContainer"] -unique_name_in_owner = true -custom_minimum_size = Vector2(200, 2.08165e-12) -layout_mode = 2 -theme_override_colors/font_color = Color(0.101961, 0.0823529, 0.0627451, 1) -text = "Ask me anything!" 
-autowrap_mode = 3 - -[node name="Form" type="HBoxContainer" parent="CenterContainer/VBoxContainer"] -custom_minimum_size = Vector2(500, 60) -layout_mode = 2 -size_flags_horizontal = 4 -alignment = 1 - -[node name="Input" type="TextEdit" parent="CenterContainer/VBoxContainer/Form"] -unique_name_in_owner = true -layout_mode = 2 -size_flags_horizontal = 3 -size_flags_stretch_ratio = 3.0 -placeholder_text = "Why do cows moo?" - -[node name="SubmitButton" type="Button" parent="CenterContainer/VBoxContainer/Form"] -unique_name_in_owner = true -layout_mode = 2 -size_flags_horizontal = 3 -text = "Submit" - -[connection signal="pressed" from="CenterContainer/VBoxContainer/Form/SubmitButton" to="." method="_on_button_pressed"] diff --git a/godot/project.godot b/godot/project.godot index e5ff82f..4730d03 100644 --- a/godot/project.godot +++ b/godot/project.godot @@ -11,35 +11,14 @@ config_version=5 [application] config/name="godot-llama-cpp" -run/main_scene="res://main.tscn" +run/main_scene="res://examples/simple/simple.tscn" config/features=PackedStringArray("4.2", "Forward Plus") -config/icon="res://addons/godot-llama-cpp/assets/godot-llama-cpp-1024x1024.svg" - -[autoload] - -__LlamaBackend="*res://addons/godot-llama-cpp/autoloads/llama-backend.gd" -Llama="*res://autoloads/llama.tscn" - -[display] - -window/size/viewport_width=1280 -window/size/viewport_height=720 +config/icon="res://icon.svg" [editor_plugins] enabled=PackedStringArray("res://addons/godot-llama-cpp/plugin.cfg") -[input] - -submit_form={ -"deadzone": 0.5, -"events": [Object(InputEventKey,"resource_local_to_scene":false,"resource_name":"","device":-1,"window_id":0,"alt_pressed":false,"shift_pressed":false,"ctrl_pressed":false,"meta_pressed":false,"pressed":false,"keycode":0,"physical_keycode":4194309,"key_label":0,"unicode":0,"echo":false,"script":null) -] -} - -[rendering] +[gui] -anti_aliasing/quality/msaa_2d=3 -anti_aliasing/quality/msaa_3d=3 -anti_aliasing/quality/screen_space_aa=1 -anti_aliasing/quality/use_taa=true +theme/default_theme_scale=2.0 diff --git a/godot_cpp b/godot_cpp index 51c752c..b28098e 160000 --- a/godot_cpp +++ b/godot_cpp @@ -1 +1 @@ -Subproject commit 51c752c46b44769d3b6c661526c364a18ea64781 +Subproject commit b28098e76b84e8831b8ac68d490f4bca44678b2a diff --git a/llama.cpp b/llama.cpp index 4755afd..9afdffe 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit 4755afd1cbd40d93c017e5b98c39796f52345314 +Subproject commit 9afdffe70ebf3166d429b4434783bb0b7f97bdeb diff --git a/src/llama_backend.cpp b/src/llama_backend.cpp deleted file mode 100644 index 0e2c0a5..0000000 --- a/src/llama_backend.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "llama.h" -#include "llama_backend.h" -#include - -using namespace godot; - -void LlamaBackend::init() { - llama_backend_init(); - llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISABLED); -} - -void LlamaBackend::deinit() { - llama_backend_free(); -} - -void LlamaBackend::_bind_methods() { - ClassDB::bind_method(D_METHOD("init"), &LlamaBackend::init); - ClassDB::bind_method(D_METHOD("deinit"), &LlamaBackend::deinit); -} \ No newline at end of file diff --git a/src/llama_backend.h b/src/llama_backend.h deleted file mode 100644 index 8b0628f..0000000 --- a/src/llama_backend.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef LLAMA_BACKEND_H -#define LLAMA_BACKEND_H - -#include - -namespace godot { -class LlamaBackend : public RefCounted { - GDCLASS(LlamaBackend, RefCounted) - -protected: - static void _bind_methods(); - -public: - void init(); - void deinit(); -}; -} 
//namespace godot - -#endif \ No newline at end of file diff --git a/src/llama_context.cpp b/src/llama_context.cpp index 78ad255..0fcfc2f 100644 --- a/src/llama_context.cpp +++ b/src/llama_context.cpp @@ -2,10 +2,12 @@ #include "common.h" #include "llama.h" #include "llama_model.h" +#include #include #include #include #include +#include #include using namespace godot; @@ -15,31 +17,41 @@ void LlamaContext::_bind_methods() { ClassDB::bind_method(D_METHOD("get_model"), &LlamaContext::get_model); ClassDB::add_property("LlamaContext", PropertyInfo(Variant::OBJECT, "model", PROPERTY_HINT_RESOURCE_TYPE, "LlamaModel"), "set_model", "get_model"); - ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed); - ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed); - ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed"); + ClassDB::bind_method(D_METHOD("get_seed"), &LlamaContext::get_seed); + ClassDB::bind_method(D_METHOD("set_seed", "seed"), &LlamaContext::set_seed); + ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "seed"), "set_seed", "get_seed"); - ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx); - ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx); - ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), "set_n_ctx", "get_n_ctx"); + ClassDB::bind_method(D_METHOD("get_temperature"), &LlamaContext::get_temperature); + ClassDB::bind_method(D_METHOD("set_temperature", "temperature"), &LlamaContext::set_temperature); + ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "temperature"), "set_temperature", "get_temperature"); - ClassDB::bind_method(D_METHOD("get_n_threads"), &LlamaContext::get_n_threads); - ClassDB::bind_method(D_METHOD("set_n_threads", "n_threads"), &LlamaContext::set_n_threads); - ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads"), "set_n_threads", "get_n_threads"); + ClassDB::bind_method(D_METHOD("get_top_p"), &LlamaContext::get_top_p); + ClassDB::bind_method(D_METHOD("set_top_p", "top_p"), &LlamaContext::set_top_p); + ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "top_p"), "set_top_p", "get_top_p"); - ClassDB::bind_method(D_METHOD("get_n_threads_batch"), &LlamaContext::get_n_threads_batch); - ClassDB::bind_method(D_METHOD("set_n_threads_batch", "n_threads_batch"), &LlamaContext::set_n_threads_batch); - ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_threads_batch"), "set_n_threads_batch", "get_n_threads_batch"); + ClassDB::bind_method(D_METHOD("get_frequency_penalty"), &LlamaContext::get_frequency_penalty); + ClassDB::bind_method(D_METHOD("set_frequency_penalty", "frequency_penalty"), &LlamaContext::set_frequency_penalty); + ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "frequency_penalty"), "set_frequency_penalty", "get_frequency_penalty"); + + ClassDB::bind_method(D_METHOD("get_presence_penalty"), &LlamaContext::get_presence_penalty); + ClassDB::bind_method(D_METHOD("set_presence_penalty", "presence_penalty"), &LlamaContext::set_presence_penalty); + ClassDB::add_property("LlamaContext", PropertyInfo(Variant::FLOAT, "presence_penalty"), "set_presence_penalty", "get_presence_penalty"); + + ClassDB::bind_method(D_METHOD("get_n_ctx"), &LlamaContext::get_n_ctx); + ClassDB::bind_method(D_METHOD("set_n_ctx", "n_ctx"), &LlamaContext::set_n_ctx); + ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_ctx"), 
"set_n_ctx", "get_n_ctx"); + + ClassDB::bind_method(D_METHOD("get_n_len"), &LlamaContext::get_n_len); + ClassDB::bind_method(D_METHOD("set_n_len", "n_len"), &LlamaContext::set_n_len); + ClassDB::add_property("LlamaContext", PropertyInfo(Variant::INT, "n_len"), "set_n_len", "get_n_len"); ClassDB::bind_method(D_METHOD("request_completion", "prompt"), &LlamaContext::request_completion); - ClassDB::bind_method(D_METHOD("_fulfill_completion", "prompt"), &LlamaContext::_fulfill_completion); + ClassDB::bind_method(D_METHOD("__thread_loop"), &LlamaContext::__thread_loop); - ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::STRING, "completion"), PropertyInfo(Variant::BOOL, "is_final"))); + ADD_SIGNAL(MethodInfo("completion_generated", PropertyInfo(Variant::DICTIONARY, "chunk"))); } LlamaContext::LlamaContext() { - batch = llama_batch_init(4096, 0, 1); - ctx_params = llama_context_default_params(); ctx_params.seed = -1; ctx_params.n_ctx = 4096; @@ -60,109 +72,186 @@ void LlamaContext::_ready() { return; } + mutex.instantiate(); + semaphore.instantiate(); + thread.instantiate(); + + llama_backend_init(); + llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_DISABLED); + ctx = llama_new_context_with_model(model->model, ctx_params); if (ctx == NULL) { UtilityFunctions::printerr(vformat("%s: Failed to initialize llama context, null ctx", __func__)); return; } + + sampling_ctx = llama_sampling_init(sampling_params); + UtilityFunctions::print(vformat("%s: Context initialized", __func__)); -} -PackedStringArray LlamaContext::_get_configuration_warnings() const { - PackedStringArray warnings; - if (model == NULL) { - warnings.push_back("Model resource property not defined"); - } - return warnings; + thread->start(callable_mp(this, &LlamaContext::__thread_loop)); } -Variant LlamaContext::request_completion(const String &prompt) { - UtilityFunctions::print(vformat("%s: Requesting completion for prompt: %s", __func__, prompt)); - if (task_id) { - WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id); - } - task_id = WorkerThreadPool::get_singleton()->add_task(Callable(this, "_fulfill_completion").bind(prompt)); - return OK; -} +void LlamaContext::__thread_loop() { + while (true) { + semaphore->wait(); -void LlamaContext::_fulfill_completion(const String &prompt) { - UtilityFunctions::print(vformat("%s: Fulfilling completion for prompt: %s", __func__, prompt)); - std::vector tokens_list; - tokens_list = ::llama_tokenize(ctx, std::string(prompt.utf8().get_data()), true); + mutex->lock(); + if (exit_thread) { + mutex->unlock(); + break; + } + if (completion_requests.size() == 0) { + mutex->unlock(); + continue; + } + completion_request req = completion_requests.get(0); + completion_requests.remove_at(0); + mutex->unlock(); - const int n_len = 128; - const int n_ctx = llama_n_ctx(ctx); - const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); - if (n_kv_req > n_ctx) { - UtilityFunctions::printerr(vformat("%s: n_kv_req > n_ctx, the required KV cache size is not big enough\neither reduce n_len or increase n_ctx", __func__)); - return; - } + UtilityFunctions::print(vformat("%s: Running completion for prompt id: %d", __func__, req.id)); - for (size_t i = 0; i < tokens_list.size(); i++) { - llama_batch_add(batch, tokens_list[i], i, { 0 }, false); - } + std::vector request_tokens; + request_tokens = ::llama_tokenize(ctx, req.prompt.utf8().get_data(), true, true); - batch.logits[batch.n_tokens - 1] = true; + size_t shared_prefix_idx = 0; + auto diff = 
std::mismatch(context_tokens.begin(), context_tokens.end(), request_tokens.begin(), request_tokens.end()); + if (diff.first != context_tokens.end()) { + shared_prefix_idx = std::distance(context_tokens.begin(), diff.first); + } else { + shared_prefix_idx = std::min(context_tokens.size(), request_tokens.size()); + } - llama_kv_cache_clear(ctx); + bool rm_success = llama_kv_cache_seq_rm(ctx, -1, shared_prefix_idx, -1); + if (!rm_success) { + UtilityFunctions::printerr(vformat("%s: Failed to remove tokens from kv cache", __func__)); + Dictionary response; + response["id"] = req.id; + response["error"] = "Failed to remove tokens from kv cache"; + call_thread_safe("emit_signal", "completion_generated", response); + continue; + } + context_tokens.erase(context_tokens.begin() + shared_prefix_idx, context_tokens.end()); + request_tokens.erase(request_tokens.begin(), request_tokens.begin() + shared_prefix_idx); - int decode_res = llama_decode(ctx, batch); - if (decode_res != 0) { - UtilityFunctions::printerr(vformat("%s: Failed to decode prompt with error code: %d", __func__, decode_res)); - return; - } + uint batch_size = std::min(ctx_params.n_batch, (uint)request_tokens.size()); + + llama_batch batch = llama_batch_init(batch_size, 0, 1); + + // chunk request_tokens into sequences of size batch_size + std::vector> sequences; + for (size_t i = 0; i < request_tokens.size(); i += batch_size) { + sequences.push_back(std::vector(request_tokens.begin() + i, request_tokens.begin() + std::min(i + batch_size, request_tokens.size()))); + } + + printf("Request tokens: \n"); + for (auto sequence : sequences) { + for (auto token : sequence) { + printf("%s", llama_token_to_piece(ctx, token).c_str()); + } + } + printf("\n"); - int n_cur = batch.n_tokens; - int n_decode = 0; - llama_model *llama_model = model->model; + int curr_token_pos = context_tokens.size(); + bool decode_failed = false; - while (n_cur <= n_len) { - // sample the next token - { - auto n_vocab = llama_n_vocab(llama_model); - auto *logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + for (size_t i = 0; i < sequences.size(); i++) { + llama_batch_clear(batch); - std::vector candidates; - candidates.reserve(n_vocab); + std::vector sequence = sequences[i]; - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); + for (size_t j = 0; j < sequence.size(); j++) { + llama_batch_add(batch, sequence[j], j + curr_token_pos, { 0 }, false); + curr_token_pos++; } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + if (i == sequences.size() - 1) { + batch.logits[batch.n_tokens - 1] = true; + } - // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + if (llama_decode(ctx, batch) != 0) { + decode_failed = true; + break; + } + } - // is it an end of stream? 
- if (new_token_id == llama_token_eos(llama_model) || n_cur == n_len) { - call_thread_safe("emit_signal", "completion_generated", "\n", true); + if (decode_failed) { + Dictionary response; + response["id"] = req.id; + response["error"] = "llama_decode() failed"; + call_thread_safe("emit_signal", "completion_generated", response); + continue; + } + + context_tokens.insert(context_tokens.end(), request_tokens.begin(), request_tokens.end()); + + while (true) { + if (exit_thread) { + return; + } + llama_token new_token_id = llama_sampling_sample(sampling_ctx, ctx, NULL, batch.n_tokens - 1); + llama_sampling_accept(sampling_ctx, ctx, new_token_id, false); + Dictionary response; + response["id"] = req.id; + + context_tokens.push_back(new_token_id); + + if (llama_token_is_eog(model->model, new_token_id) || curr_token_pos == n_len) { + response["done"] = true; + call_thread_safe("emit_signal", "completion_generated", response); break; } - call_thread_safe("emit_signal", "completion_generated", vformat("%s", llama_token_to_piece(ctx, new_token_id).c_str()), false); + response["text"] = llama_token_to_piece(ctx, new_token_id).c_str(); + response["done"] = false; + call_thread_safe("emit_signal", "completion_generated", response); - // prepare the next batch llama_batch_clear(batch); - // push this new token for next evaluation - llama_batch_add(batch, new_token_id, n_cur, { 0 }, true); + llama_batch_add(batch, new_token_id, curr_token_pos, { 0 }, true); - n_decode += 1; - } + curr_token_pos++; - n_cur += 1; + if (llama_decode(ctx, batch) != 0) { + decode_failed = true; + break; + } + } - // evaluate the current batch with the transformer model - int decode_res = llama_decode(ctx, batch); - if (decode_res != 0) { - UtilityFunctions::printerr(vformat("%s: Failed to decode batch with error code: %d", __func__, decode_res)); - break; + if (decode_failed) { + Dictionary response; + response["id"] = req.id; + response["error"] = "llama_decode() failed"; + call_thread_safe("emit_signal", "completion_generated", response); + continue; } } } +PackedStringArray LlamaContext::_get_configuration_warnings() const { + PackedStringArray warnings; + if (model == NULL) { + warnings.push_back("Model resource property not defined"); + } + return warnings; +} + +int LlamaContext::request_completion(const String &prompt) { + int id = request_id++; + + UtilityFunctions::print(vformat("%s: Requesting completion for prompt id: %d", __func__, id)); + + mutex->lock(); + completion_request req = { id, prompt }; + completion_requests.append(req); + mutex->unlock(); + + semaphore->post(); + + return id; +} + void LlamaContext::set_model(const Ref p_model) { model = p_model; } @@ -184,28 +273,58 @@ void LlamaContext::set_n_ctx(int n_ctx) { ctx_params.n_ctx = n_ctx; } -int LlamaContext::get_n_threads() { - return ctx_params.n_threads; +int LlamaContext::get_n_len() { + return n_len; } -void LlamaContext::set_n_threads(int n_threads) { - ctx_params.n_threads = n_threads; +void LlamaContext::set_n_len(int n_len) { + this->n_len = n_len; } -int LlamaContext::get_n_threads_batch() { - return ctx_params.n_threads_batch; +float LlamaContext::get_temperature() { + return sampling_params.temp; } -void LlamaContext::set_n_threads_batch(int n_threads_batch) { - ctx_params.n_threads_batch = n_threads_batch; +void LlamaContext::set_temperature(float temperature) { + sampling_params.temp = temperature; } -LlamaContext::~LlamaContext() { - if (ctx) { - llama_free(ctx); +float LlamaContext::get_top_p() { + return sampling_params.top_p; +} +void 
LlamaContext::set_top_p(float top_p) {
+	sampling_params.top_p = top_p;
+}
+
+float LlamaContext::get_frequency_penalty() {
+	return sampling_params.penalty_freq;
+}
+void LlamaContext::set_frequency_penalty(float frequency_penalty) {
+	sampling_params.penalty_freq = frequency_penalty;
+}
+
+float LlamaContext::get_presence_penalty() {
+	return sampling_params.penalty_present;
+}
+void LlamaContext::set_presence_penalty(float presence_penalty) {
+	sampling_params.penalty_present = presence_penalty;
+}
+
+void LlamaContext::_exit_tree() {
+	if (Engine::get_singleton()->is_editor_hint()) {
+		return;
 	}
-	llama_batch_free(batch);
 
+	mutex->lock();
+	exit_thread = true;
+	mutex->unlock();
+
+	semaphore->post();
 
-	if (task_id) {
-		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_id);
+	thread->wait_to_finish();
+
+	if (ctx) {
+		llama_free(ctx);
 	}
+
+	llama_sampling_free(sampling_ctx);
+	llama_backend_free();
 }
\ No newline at end of file
diff --git a/src/llama_context.h b/src/llama_context.h
index 5db3789..07bfd3e 100644
--- a/src/llama_context.h
+++ b/src/llama_context.h
@@ -2,19 +2,38 @@
 #define LLAMA_CONTEXT_H
 
 #include "llama.h"
+#include "common.h"
 #include "llama_model.h"
+#include <godot_cpp/classes/mutex.hpp>
 #include <godot_cpp/classes/node.hpp>
-
+#include <godot_cpp/classes/semaphore.hpp>
+#include <godot_cpp/classes/thread.hpp>
+#include <vector>
 
 namespace godot {
+
+struct completion_request {
+	int id;
+	String prompt;
+};
+
 class LlamaContext : public Node {
 	GDCLASS(LlamaContext, Node)
 
 private:
 	Ref<LlamaModel> model;
 	llama_context *ctx = nullptr;
+	llama_sampling_context *sampling_ctx = nullptr;
 	llama_context_params ctx_params;
-	llama_batch batch;
-	int task_id;
+	llama_sampling_params sampling_params;
+	int n_len = 1024;
+	int request_id = 0;
+	Vector<completion_request> completion_requests;
+
+	Ref<Thread> thread;
+	Ref<Semaphore> semaphore;
+	Ref<Mutex> mutex;
+	std::vector<llama_token> context_tokens;
+	bool exit_thread = false;
 
 protected:
 	static void _bind_methods();
@@ -23,22 +42,28 @@ class LlamaContext : public Node {
 	void set_model(const Ref<LlamaModel> model);
 	Ref<LlamaModel> get_model();
 
-	Variant request_completion(const String &prompt);
-	void _fulfill_completion(const String &prompt);
+	int request_completion(const String &prompt);
+	void __thread_loop();
 
-	int get_seed();
-	void set_seed(int seed);
-	int get_n_ctx();
-	void set_n_ctx(int n_ctx);
-	int get_n_threads();
-	void set_n_threads(int n_threads);
-	int get_n_threads_batch();
-	void set_n_threads_batch(int n_threads_batch);
+	int get_seed();
+	void set_seed(int seed);
+	int get_n_ctx();
+	void set_n_ctx(int n_ctx);
+	int get_n_len();
+	void set_n_len(int n_len);
+	float get_temperature();
+	void set_temperature(float temperature);
+	float get_top_p();
+	void set_top_p(float top_p);
+	float get_frequency_penalty();
+	void set_frequency_penalty(float frequency_penalty);
+	float get_presence_penalty();
+	void set_presence_penalty(float presence_penalty);
 
-	virtual PackedStringArray _get_configuration_warnings() const override;
+	virtual PackedStringArray _get_configuration_warnings() const override;
 	virtual void _ready() override;
-	LlamaContext();
-	~LlamaContext();
+	virtual void _exit_tree() override;
+	LlamaContext();
 };
 } //namespace godot
diff --git a/src/llama_model.cpp b/src/llama_model.cpp
index 92cb9db..96eea4f 100644
--- a/src/llama_model.cpp
+++ b/src/llama_model.cpp
@@ -1,5 +1,6 @@
 #include "llama_model.h"
 #include "llama.h"
+#include <godot_cpp/classes/project_settings.hpp>
 #include <godot_cpp/core/class_db.hpp>
 #include <godot_cpp/variant/utility_functions.hpp>
 
@@ -22,14 +23,16 @@ void LlamaModel::load_model(const String &path) {
 		llama_free_model(model);
 	}
 
-	model = llama_load_model_from_file(path.utf8().get_data(), model_params);
+	String absPath = ProjectSettings::get_singleton()->globalize_path(path);
+
+	model = llama_load_model_from_file(absPath.utf8().get_data(), model_params);
 
 	if (model == NULL) {
-		UtilityFunctions::printerr(vformat("%s: Unable to load model from %s", __func__, path));
+		UtilityFunctions::printerr(vformat("%s: Unable to load model from %s", __func__, absPath));
 		return;
 	}
 
-	UtilityFunctions::print(vformat("%s: Model loaded from %s", __func__, path));
+	UtilityFunctions::print(vformat("%s: Model loaded from %s", __func__, absPath));
 }
 
 int LlamaModel::get_n_gpu_layers() {
diff --git a/src/llama_model_loader.cpp b/src/llama_model_loader.cpp
index 77b8a94..940f1c3 100644
--- a/src/llama_model_loader.cpp
+++ b/src/llama_model_loader.cpp
@@ -2,7 +2,6 @@
 #include "llama_model.h"
 #include
 #include
-#include <godot_cpp/classes/project_settings.hpp>
 #include
 
 using namespace godot;
@@ -24,9 +23,7 @@ Variant godot::LlamaModelLoader::_load(const String &path, const String &origina
 		return { model };
 	}
 
-	String absPath = ProjectSettings::get_singleton()->globalize_path(path);
-
-	model->load_model(absPath);
+	model->load_model(path);
 
 	return { model };
 }
diff --git a/src/register_types.cpp b/src/register_types.cpp
index 0cf4cf0..17ec90b 100644
--- a/src/register_types.cpp
+++ b/src/register_types.cpp
@@ -7,7 +7,6 @@
 #include "llama_model.h"
 #include "llama_model_loader.h"
 #include "llama_context.h"
-#include "llama_backend.h"
 
 using namespace godot;
 
@@ -24,7 +23,6 @@ void initialize_types(ModuleInitializationLevel p_level)
 	ClassDB::register_class<LlamaModelLoader>();
 	ClassDB::register_class<LlamaContext>();
-	ClassDB::register_class<LlamaBackend>();
 }
diff --git a/tools/expand_metal.zig b/tools/expand_metal.zig
new file mode 100644
index 0000000..1e4e9f4
--- /dev/null
+++ b/tools/expand_metal.zig
@@ -0,0 +1,80 @@
+const std = @import("std");
+
+const usage =
+    \\Usage: ./expand_metal [options]
+    \\
+    \\Options:
+    \\  --metal-file ggml-metal.metal
+    \\  --common-file ggml-common.h
+    \\  --output-file ggml-metal-embed.metal
+    \\
+;
+
+pub fn main() !void {
+    var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
+    defer arena_state.deinit();
+    const arena = arena_state.allocator();
+
+    const args = try std.process.argsAlloc(arena);
+
+    var opt_metal_file_path: ?[]const u8 = null;
+    var opt_common_file_path: ?[]const u8 = null;
+    var opt_output_file_path: ?[]const u8 = null;
+
+    {
+        var i: usize = 1;
+        while (i < args.len) : (i += 1) {
+            const arg = args[i];
+            if (std.mem.eql(u8, "-h", arg) or std.mem.eql(u8, "--help", arg)) {
+                try std.io.getStdOut().writeAll(usage);
+                return std.process.cleanExit();
+            } else if (std.mem.eql(u8, "--metal-file", arg)) {
+                i += 1;
+                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+                if (opt_metal_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+                opt_metal_file_path = args[i];
+            } else if (std.mem.eql(u8, "--common-file", arg)) {
+                i += 1;
+                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+                if (opt_common_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+                opt_common_file_path = args[i];
+            } else if (std.mem.eql(u8, "--output-file", arg)) {
+                i += 1;
+                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
+                if (opt_output_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
+                opt_output_file_path = args[i];
+            } else {
+                std.debug.panic("unrecognized arg: '{s}'", .{arg});
+            }
+        }
+    }
+
+    const metal_file_path = opt_metal_file_path orelse std.debug.panic("missing --metal-file", .{});
+    const common_file_path = opt_common_file_path orelse std.debug.panic("missing --common-file", .{});
+    const output_file_path = opt_output_file_path orelse std.debug.panic("missing --output-file", .{});
+
+    const cwd = std.fs.cwd();
+
+    var metal_file = try cwd.openFile(metal_file_path, .{});
+    defer metal_file.close();
+
+    var common_file = try cwd.openFile(common_file_path, .{});
+    defer common_file.close();
+
+    const metal_size = (try metal_file.stat()).size;
+    const metal_contents = try arena.alloc(u8, metal_size);
+    defer arena.free(metal_contents);
+    _ = try metal_file.readAll(metal_contents);
+
+    const common_size = (try common_file.stat()).size;
+    const common_contents = try arena.alloc(u8, common_size);
+    defer arena.free(common_contents);
+    _ = try common_file.readAll(common_contents);
+
+    const output = try std.mem.replaceOwned(u8, arena, metal_contents, "#include \"ggml-common.h\"", common_contents);
+    defer arena.free(output);
+
+    const output_file = try cwd.createFile(output_file_path, .{});
+    defer output_file.close();
+    try output_file.writeAll(output);
+}