Skip to content

Commit 13eb0ad

Browse files
committed
add metal build
1 parent 6960f3b commit 13eb0ad

File tree

2 files changed

+124
-10
lines changed

2 files changed

+124
-10
lines changed

build.zig

+45-10
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,44 @@ fn build_lib_llama_cpp(params: BuildParams) !*std.Build.Step.Compile {
8484
\\char const *LLAMA_COMMIT = "{s}";
8585
\\char const *LLAMA_COMPILER = "Zig {s}";
8686
\\char const *LLAMA_BUILD_TARGET = "{s}";
87-
\\
8887
, .{ 0, commit_hash.stdout[0 .. commit_hash.stdout.len - 1], zig_version, zig_triple }) });
8988

89+
const lib_llama_cpp = b.addStaticLibrary(.{ .name = "llama.cpp", .target = target, .optimize = optimize });
90+
9091
var objs = std.ArrayList(*std.Build.Step.Compile).init(b.allocator);
91-
var objBuilder = ObjBuilder.init(.{ .b = b, .target = target, .optimize = optimize, .include_paths = &.{
92-
"llama.cpp",
93-
"llama.cpp/common",
94-
} });
92+
var objBuilder = ObjBuilder.init(.{
93+
.b = b,
94+
.target = target,
95+
.optimize = optimize,
96+
.include_paths = &.{ "llama.cpp", "llama.cpp/common" },
97+
});
98+
99+
switch (target.result.os.tag) {
100+
.macos => {
101+
try objBuilder.flags.append("-DGGML_USE_METAL");
102+
try objs.append(objBuilder.build(.{ .name = "ggml_metal", .sources = &.{"llama.cpp/ggml-metal.m"} }));
103+
104+
lib_llama_cpp.linkFramework("Foundation");
105+
lib_llama_cpp.linkFramework("Metal");
106+
lib_llama_cpp.linkFramework("MetalKit");
107+
108+
const expand_metal = b.addExecutable(.{
109+
.name = "expand_metal",
110+
.target = target,
111+
.root_source_file = .{ .path = "tools/expand_metal.zig" },
112+
});
113+
var run_expand_metal = b.addRunArtifact(expand_metal);
114+
run_expand_metal.addArg("--metal-file");
115+
run_expand_metal.addFileArg(.{ .path = "llama.cpp/ggml-metal.metal" });
116+
run_expand_metal.addArg("--common-file");
117+
run_expand_metal.addFileArg(.{ .path = "llama.cpp/ggml-common.h" });
118+
run_expand_metal.addArg("--output-file");
119+
const metal_expanded = run_expand_metal.addOutputFileArg("ggml-metal.metal");
120+
const install_metal = b.addInstallFileWithDir(metal_expanded, .lib, "ggml-metal.metal");
121+
lib_llama_cpp.step.dependOn(&install_metal.step);
122+
},
123+
else => {},
124+
}
95125

96126
try objs.appendSlice(&.{
97127
objBuilder.build(.{ .name = "ggml", .sources = &.{"llama.cpp/ggml.c"} }),
@@ -110,8 +140,6 @@ fn build_lib_llama_cpp(params: BuildParams) !*std.Build.Step.Compile {
110140
objBuilder.build(.{ .name = "build_info", .sources = &.{"llama.cpp/common/build-info.cpp"} }),
111141
});
112142

113-
const lib_llama_cpp = b.addStaticLibrary(.{ .name = "llama.cpp", .target = target, .optimize = optimize });
114-
115143
for (objs.items) |obj| {
116144
lib_llama_cpp.addObject(obj);
117145
}
@@ -124,19 +152,26 @@ const ObjBuilder = struct {
124152
target: std.Build.ResolvedTarget,
125153
optimize: std.builtin.OptimizeMode,
126154
include_paths: []const []const u8,
127-
128-
fn init(params: struct { b: *std.Build, target: std.Build.ResolvedTarget, optimize: std.builtin.OptimizeMode, include_paths: []const []const u8 }) ObjBuilder {
155+
flags: std.ArrayList([]const u8),
156+
157+
/// Creates an `ObjBuilder` from the given build handle, target, optimization mode
/// and include paths. The `flags` list starts out empty and is backed by
/// `params.b.allocator`.
fn init(params: struct {
    b: *std.Build,
    target: std.Build.ResolvedTarget,
    optimize: std.builtin.OptimizeMode,
    include_paths: []const []const u8,
}) ObjBuilder {
    const build_handle = params.b;
    const empty_flags = std.ArrayList([]const u8).init(build_handle.allocator);

    return .{
        .b = build_handle,
        .target = params.target,
        .optimize = params.optimize,
        .include_paths = params.include_paths,
        .flags = empty_flags,
    };
}
136171

137172
fn build(self: *ObjBuilder, params: struct { name: []const u8, sources: []const []const u8 }) *std.Build.Step.Compile {
138173
const obj = self.b.addObject(.{ .name = params.name, .target = self.target, .optimize = self.optimize });
139-
obj.addCSourceFiles(.{ .files = params.sources });
174+
obj.addCSourceFiles(.{ .files = params.sources, .flags = self.flags.items });
140175
for (self.include_paths) |path| {
141176
obj.addIncludePath(.{ .path = path });
142177
}

tools/expand_metal.zig

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
const std = @import("std");
2+
3+
/// Help text written to stdout when the tool is invoked with `-h` or `--help`.
const usage =
    \\Usage: ./expand_metal [options]
    \\
    \\Options:
    \\ --metal-file ggml-metal.metal
    \\ --common-file ggml-common.h
    \\ --output-file ggml-metal-embed.metal
    \\
;
12+
13+
/// Entry point: inlines the common header into the Metal shader source.
///
/// Reads the files named by `--metal-file` and `--common-file`, replaces every
/// occurrence of `#include "ggml-common.h"` in the Metal source with the contents
/// of the common file, and writes the result to the path given by `--output-file`.
/// All paths are resolved relative to the current working directory.
///
/// `-h` / `--help` prints the usage text and exits. A malformed command line
/// (unknown option, duplicated option, option without a value, or a missing
/// required option) panics with a descriptive message. I/O and allocation
/// failures are returned as errors.
pub fn main() !void {
    // Everything allocated below lives until process exit, so a single arena
    // released by `deinit` is enough; no individual frees are needed.
    var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    const args = try std.process.argsAlloc(arena);

    var opt_metal_file_path: ?[]const u8 = null;
    var opt_common_file_path: ?[]const u8 = null;
    var opt_output_file_path: ?[]const u8 = null;

    {
        // args[0] is the program name; options start at index 1.
        var i: usize = 1;
        while (i < args.len) : (i += 1) {
            const arg = args[i];
            if (std.mem.eql(u8, "-h", arg) or std.mem.eql(u8, "--help", arg)) {
                try std.io.getStdOut().writeAll(usage);
                return std.process.cleanExit();
            } else if (std.mem.eql(u8, "--metal-file", arg)) {
                i += 1;
                // `>=`, not `>`: when the flag is the last argument, `i` equals
                // `args.len` here and `args[i]` would be out of bounds.
                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
                if (opt_metal_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
                opt_metal_file_path = args[i];
            } else if (std.mem.eql(u8, "--common-file", arg)) {
                i += 1;
                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
                if (opt_common_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
                opt_common_file_path = args[i];
            } else if (std.mem.eql(u8, "--output-file", arg)) {
                i += 1;
                if (i >= args.len) std.debug.panic("expected arg after '{s}'", .{arg});
                if (opt_output_file_path != null) std.debug.panic("duplicated {s} argument", .{arg});
                opt_output_file_path = args[i];
            } else {
                std.debug.panic("unrecognized arg: '{s}'", .{arg});
            }
        }
    }

    // Each message names the option that is actually missing.
    const metal_file_path = opt_metal_file_path orelse std.debug.panic("missing --metal-file", .{});
    const common_file_path = opt_common_file_path orelse std.debug.panic("missing --common-file", .{});
    const output_file_path = opt_output_file_path orelse std.debug.panic("missing --output-file", .{});

    const cwd = std.fs.cwd();

    const metal_contents = try readFileContents(arena, cwd, metal_file_path);
    const common_contents = try readFileContents(arena, cwd, common_file_path);

    const output = try std.mem.replaceOwned(u8, arena, metal_contents, "#include \"ggml-common.h\"", common_contents);

    const output_file = try cwd.createFile(output_file_path, .{});
    defer output_file.close();
    try output_file.writeAll(output);
}

/// Reads the whole file at `path` (relative to `dir`) into a buffer allocated from `arena`.
///
/// The returned slice covers only the bytes actually read, which may be shorter
/// than the allocation if the file shrank between `stat` and the read. It is
/// therefore meant to be released by tearing down the arena, not freed individually.
/// Returns `error.FileTooBig` when the file size does not fit in `usize`.
fn readFileContents(arena: std.mem.Allocator, dir: std.fs.Dir, path: []const u8) ![]u8 {
    const file = try dir.openFile(path, .{});
    defer file.close();

    const size = std.math.cast(usize, (try file.stat()).size) orelse return error.FileTooBig;
    const buffer = try arena.alloc(u8, size);
    const bytes_read = try file.readAll(buffer);
    return buffer[0..bytes_read];
}

0 commit comments

Comments
 (0)