diff --git a/deps/mlx b/deps/mlx
index 2162315..2263e4b 160000
--- a/deps/mlx
+++ b/deps/mlx
@@ -1 +1 @@
-Subproject commit 21623156a32910b9db7c913f91612dcde0282caf
+Subproject commit 2263e4b279fe959f25615b54ae2fb300f22aa78a
diff --git a/install.js b/install.js
index 63e6a86..40d3307 100755
--- a/install.js
+++ b/install.js
@@ -8,9 +8,9 @@ if (packageJson.version === '0.0.1-dev')
 
 const fs = require('node:fs');
 const path = require('node:path');
-const stream = require('node:stream');
 const util = require('node:util');
 const zlib = require('node:zlib');
+const {pipeline} = require('node:stream/promises');
 
 const urlPrefix = 'https://github.com/frost-beta/node-mlx/releases/download';
 
@@ -39,7 +39,6 @@ async function download(url, filename) {
   if (!response.ok)
     throw new Error(`Failed to download ${url}, status: ${response.status}`);
 
-  const pipeline = util.promisify(stream.pipeline);
   const gunzip = zlib.createGunzip();
   await pipeline(response.body, gunzip, fs.createWriteStream(filename));
 }
diff --git a/lib/index.d.ts b/lib/index.d.ts
index 7fafb48..b59d517 100644
--- a/lib/index.d.ts
+++ b/lib/index.d.ts
@@ -204,6 +204,7 @@ export namespace core {
   function convolve(input: ScalarOrArray, weight: ScalarOrArray, mode?: string, s?: StreamOrDevice): array;
   function conv1d(input: ScalarOrArray, weight: ScalarOrArray, stride: number, padding: number, dilation: number, groups: number, s?: StreamOrDevice): array;
   function conv2d(input: ScalarOrArray, weight: ScalarOrArray, stride?: number | number[], padding?: number | number[], dilation?: number | number[], groups?: number, s?: StreamOrDevice): array;
+  function conv3d(input: ScalarOrArray, weight: ScalarOrArray, stride?: number | number[], padding?: number | number[], dilation?: number | number[], groups?: number, s?: StreamOrDevice): array;
   function convGeneral(input: ScalarOrArray, weight?: ScalarOrArray, stride?: number | number[], padding?: number | number[] | [number[], number[]], kernelDilation?: number | number[], inputDilation?: number | number[], groups?: number, flip?: boolean, s?: StreamOrDevice): array;
   function cos(array: ScalarOrArray, s?: StreamOrDevice): array;
   function cosh(array: ScalarOrArray, s?: StreamOrDevice): array;
diff --git a/lib/nn/layers/convolution.ts b/lib/nn/layers/convolution.ts
index a879fe9..64f45dd 100644
--- a/lib/nn/layers/convolution.ts
+++ b/lib/nn/layers/convolution.ts
@@ -111,7 +111,7 @@ export class Conv2d extends Module {
 
   override toStringExtra(): string {
     return `${this.weight.shape[3]}, ${this.weight.shape[0]}, ` +
-           `kernelSize=${this.weight.shape.slice(1, 2)}, stride=${this.stride}, ` +
+           `kernelSize=${this.weight.shape.slice(1, 3)}, stride=${this.stride}, ` +
            `padding=${this.padding}, dilation=${this.dilation}, ` +
            `bias=${!!this.bias}`;
   }
@@ -124,3 +124,62 @@ export class Conv2d extends Module {
     return y;
   }
 }
+
+/**
+ * Applies a 3-dimensional convolution over the multi-channel input image.
+ *
+ * @remarks
+ *
+ * The channels are expected to be last i.e. the input shape should be `NDHWC`
+ * where:
+ * - `N` is the batch dimension
+ * - `D` is the input image depth
+ * - `H` is the input image height
+ * - `W` is the input image width
+ * - `C` is the number of input channels
+ *
+ * @param inChannels - The number of input channels.
+ * @param outChannels - The number of output channels.
+ * @param kernelSize - The size of the convolution filters.
+ * @param stride - The size of the stride when applying the filter. Default: 1.
+ * @param padding - How many positions to 0-pad the input with. Default: 0.
+ * @param bias - If `true` add a learnable bias to the output. Default: `true`.
+ */
+export class Conv3d extends Module {
+  stride: number[];
+  padding: number[];
+  weight: mx.array;
+  bias?: mx.array;
+
+  constructor(inChannels: number,
+              outChannels: number,
+              kernelSize: number | number[],
+              stride: number | number[] = [1, 1, 1],
+              padding: number | number[] = [0, 0, 0],
+              bias = true) {
+    super();
+    this.stride = Array.isArray(stride) ? stride : [stride, stride, stride];
+    this.padding = Array.isArray(padding) ? padding : [padding, padding, padding];
+
+    kernelSize = Array.isArray(kernelSize) ? kernelSize : [kernelSize, kernelSize, kernelSize];
+    const scale = Math.sqrt(1 / (inChannels * kernelSize[0] * kernelSize[1] * kernelSize[2]));
+    this.weight = mx.random.uniform(-scale, scale, [outChannels, ...kernelSize, inChannels]);
+    if (bias) {
+      this.bias = mx.zeros([outChannels]);
+    }
+  }
+
+  override toStringExtra(): string {
+    return `${this.weight.shape[4]}, ${this.weight.shape[0]}, ` +
+           `kernelSize=${this.weight.shape.slice(1, 4)}, stride=${this.stride}, ` +
+           `padding=${this.padding}, bias=${!!this.bias}`;
+  }
+
+  override forward(x: mx.array): mx.array {
+    const y = mx.conv3d(x, this.weight, this.stride, this.padding);
+    if (this.bias)
+      return mx.add(y, this.bias);
+    else
+      return y;
+  }
+}
diff --git a/lib/nn/utils.ts b/lib/nn/utils.ts
index 566f2ba..f25a15d 100644
--- a/lib/nn/utils.ts
+++ b/lib/nn/utils.ts
@@ -4,14 +4,13 @@ import {NestedDict} from '../utils';
 
 /**
  * Transform the passed function `func` to a function that computes the
- * gradients of `func` with respect to the model's trainable parameters and also its
- * value.
+ * gradients of `func` with respect to the model's trainable parameters and also
+ * its value.
  *
- * @param model The model whose trainable parameters to compute
- * gradients for
- * @param func The scalar function to compute gradients for
- * @returns A callable that returns the value of `func` and the gradients with respect to the
- * trainable parameters of `model`
+ * @param model The model whose trainable parameters to compute gradients for.
+ * @param func The scalar function to compute gradients for.
+ * @returns A callable that returns the value of `func` and the gradients with
+ * respect to the trainable parameters of `model`.
  */
 export function valueAndGrad<T extends any[], U>(model: Module,
                                                  func: (...args: T) => U) {
@@ -35,12 +34,11 @@ export function valueAndGrad<T extends any[], U>(model: Module,
  * with respect to the trainable parameters of the module (and the callable's
  * inputs).
  *
- * @param mod The module for whose parameters to perform gradient
- * checkpointing.
+ * @param mod The module for whose parameters to perform gradient checkpointing.
  * @param func The function to checkpoint. If not provided, it defaults to the
  * provided module.
- * @returns A function that saves the inputs and outputs during the forward
- * pass and recomputes all intermediate states during the backward pass.
+ * @returns A function that saves the inputs and outputs during the forward pass
+ * and recomputes all intermediate states during the backward pass.
  */
 export function checkpoint<M extends Module, T extends any[], U>(
     mod: M,
diff --git a/src/array.cc b/src/array.cc
index 63fa9a5..d2f9eed 100644
--- a/src/array.cc
+++ b/src/array.cc
@@ -520,6 +520,7 @@ void Type<mx::array>::Define(napi_env env,
           "round", MemberFunction(&ops::Round),
           "diagonal", MemberFunction(&ops::Diagonal),
           "diag", MemberFunction(&ops::Diag),
+          "conj", MemberFunction(&mx::conjugate),
           "index", MemberFunction(&Index),
           "indexPut_", MemberFunction(&IndexPut),
           SymbolIterator(env), MemberFunction(&CreateArrayIterator));
diff --git a/src/ops.cc b/src/ops.cc
index d4f6bf2..25fd46f 100644
--- a/src/ops.cc
+++ b/src/ops.cc
@@ -448,8 +448,38 @@ mx::array Conv2d(
   else if (auto p = std::get_if<std::pair<int, int>>(&dilation); p)
     dilation_pair = std::move(*p);
 
-  return conv2d(input, weight, stride_pair, padding_pair, dilation_pair,
-                groups.value_or(1), s);
+  return mx::conv2d(input, weight, stride_pair, padding_pair, dilation_pair,
+                    groups.value_or(1), s);
+}
+
+mx::array Conv3d(
+    const mx::array& input,
+    const mx::array& weight,
+    std::variant<int, std::tuple<int, int, int>> stride,
+    std::variant<int, std::tuple<int, int, int>> padding,
+    std::variant<int, std::tuple<int, int, int>> dilation,
+    std::optional<int> groups,
+    mx::StreamOrDevice s) {
+  std::tuple<int, int, int> stride_tuple = {1, 1, 1};
+  if (auto i = std::get_if<int>(&stride); i)
+    stride_tuple = {*i, *i, *i};
+  else if (auto p = std::get_if<std::tuple<int, int, int>>(&stride); p)
+    stride_tuple = std::move(*p);
+
+  std::tuple<int, int, int> padding_tuple = {0, 0, 0};
+  if (auto i = std::get_if<int>(&padding); i)
+    padding_tuple = {*i, *i, *i};
+  else if (auto p = std::get_if<std::tuple<int, int, int>>(&padding); p)
+    padding_tuple = std::move(*p);
+
+  std::tuple<int, int, int> dilation_tuple = {1, 1, 1};
+  if (auto i = std::get_if<int>(&dilation); i)
+    dilation_tuple = {*i, *i, *i};
+  else if (auto p = std::get_if<std::tuple<int, int, int>>(&dilation); p)
+    dilation_tuple = std::move(*p);
+
+  return mx::conv3d(input, weight, stride_tuple, padding_tuple, dilation_tuple,
+                    groups.value_or(1), s);
 }
 
 mx::array ConvGeneral(
@@ -611,6 +641,7 @@ void InitOps(napi_env env, napi_value exports) {
          "arcsin", &mx::arcsin,
          "arccos", &mx::arccos,
          "arctan", &mx::arctan,
+         "arctan2", &mx::arctan2,
          "sinh", &mx::sinh,
          "cosh", &mx::cosh,
          "tanh", &mx::tanh,
@@ -684,9 +715,12 @@ void InitOps(napi_env env, napi_value exports) {
          "cumprod", CumOpWrapper(&mx::cumprod),
          "cummax", CumOpWrapper(&mx::cummax),
          "cummin", CumOpWrapper(&mx::cummin),
+         "conj", &mx::conjugate,
+         "conjugate", &mx::conjugate,
          "convolve", &ops::Convolve,
          "conv1d", &mx::conv1d,
          "conv2d", &ops::Conv2d,
+         "conv3d", &ops::Conv3d,
          "convGeneral", &ops::ConvGeneral,
          "where", &mx::where,
          "round", &ops::Round,
diff --git a/src/transforms.cc b/src/transforms.cc
index 09e1b64..4f193f7 100644
--- a/src/transforms.cc
+++ b/src/transforms.cc
@@ -125,7 +125,8 @@ ValueAndGradImpl(const char* error_tag,
     ki::ThrowError(js_func.Env(), error_tag,
                    " Can't compute the gradient of argument "
                    "index ", argnums.back(), " because the function is "
-                   "called with only ", args.Length(), " arguments.");
+                   "called with only ", args.Length(),
+                   " positional arguments.");
     return {nullptr, nullptr};
   }
   // Collect the arrays located at |argnums|.
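
The diff wires up conv3d end to end (core op, TypeScript typings, and the nn.Conv3d layer). A minimal usage sketch follows; it assumes the package is consumed as `@frost-beta/mlx` with the usual `{core as mx, nn}` import, and the tensor shapes are illustrative rather than taken from the repository's tests.

// Sketch: exercising the new Conv3d layer and mx.conv3d op (channels-last NDHWC input).
import {core as mx, nn} from '@frost-beta/mlx';

// Batch of 1, depth 16, height 16, width 16, 4 input channels.
const x = mx.random.normal([1, 16, 16, 16, 4]);

// 4 -> 8 channels with a 3x3x3 kernel, stride 1, no padding.
const conv = new nn.Conv3d(4, 8, 3);
const y = conv.forward(x);
console.log(y.shape);  // [1, 14, 14, 14, 8]

// The low-level op takes an explicit weight of shape [O, kD, kH, kW, I].
const w = mx.random.normal([8, 3, 3, 3, 4]);
const z = mx.conv3d(x, w, 1, 0);
console.log(z.shape);  // [1, 14, 14, 14, 8]

// The other bindings added here are reachable the same way, e.g.
// mx.arctan2(a, b), mx.conjugate(a), or the a.conj() array method.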