Skip to content

Commit

Permalink
Update to MLX v0.13.0
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed May 14, 2024
1 parent c8ab41c commit 766a7c2
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 18 deletions.
2 changes: 1 addition & 1 deletion deps/mlx
Submodule mlx updated 73 files
+2 −2 .circleci/config.yml
+49 −42 CMakeLists.txt
+25 −2 docs/src/install.rst
+1 −0 docs/src/python/nn/layers.rst
+3 −0 docs/src/python/ops.rst
+2 −2 examples/cpp/tutorial.cpp
+7 −2 mlx/CMakeLists.txt
+21 −0 mlx/backend/accelerate/primitives.cpp
+1 −0 mlx/backend/common/CMakeLists.txt
+21 −0 mlx/backend/common/binary.cpp
+347 −0 mlx/backend/common/common.cpp
+465 −1 mlx/backend/common/conv.cpp
+2 −0 mlx/backend/common/default_primitives.cpp
+0 −9 mlx/backend/common/inverse.cpp
+13 −0 mlx/backend/common/ops.h
+9 −336 mlx/backend/common/primitives.cpp
+0 −9 mlx/backend/common/svd.cpp
+2 −2 mlx/backend/metal/compiled.cpp
+74 −9 mlx/backend/metal/conv.cpp
+2 −2 mlx/backend/metal/copy.cpp
+47 −34 mlx/backend/metal/device.cpp
+17 −4 mlx/backend/metal/device.h
+1 −1 mlx/backend/metal/fft.cpp
+3 −3 mlx/backend/metal/indexing.cpp
+4 −4 mlx/backend/metal/kernels/arange.metal
+7 −0 mlx/backend/metal/kernels/binary.h
+1 −0 mlx/backend/metal/kernels/binary.metal
+42 −39 mlx/backend/metal/kernels/reduction/reduce_inst.h
+5 −3 mlx/backend/metal/kernels/steel/gemm/gemm.h
+7 −5 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal
+5 −3 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal
+17 −13 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal
+6 −0 mlx/backend/metal/kernels/unary.h
+1 −0 mlx/backend/metal/kernels/unary.metal
+12 −12 mlx/backend/metal/matmul.cpp
+16 −30 mlx/backend/metal/metal.cpp
+4 −4 mlx/backend/metal/normalization.cpp
+25 −10 mlx/backend/metal/primitives.cpp
+5 −5 mlx/backend/metal/quantized.cpp
+11 −11 mlx/backend/metal/reduce.cpp
+1 −1 mlx/backend/metal/rope.cpp
+2 −2 mlx/backend/metal/scaled_dot_product_attention.cpp
+2 −2 mlx/backend/metal/scan.cpp
+1 −1 mlx/backend/metal/softmax.cpp
+4 −4 mlx/backend/metal/sort.cpp
+9 −0 mlx/backend/no_cpu/CMakeLists.txt
+108 −0 mlx/backend/no_cpu/primitives.cpp
+2 −0 mlx/backend/no_metal/primitives.cpp
+13 −11 mlx/compile.cpp
+1 −1 mlx/io.h
+50 −25 mlx/io/CMakeLists.txt
+2 −1 mlx/io/gguf.cpp
+1 −1 mlx/io/gguf_quants.cpp
+20 −0 mlx/io/no_gguf.cpp
+37 −0 mlx/io/no_safetensors.cpp
+1 −1 mlx/io/safetensors.cpp
+55 −6 mlx/ops.cpp
+15 −0 mlx/ops.h
+99 −2 mlx/primitives.cpp
+33 −0 mlx/primitives.h
+9 −2 mlx/transforms.cpp
+1 −1 python/mlx/nn/layers/__init__.py
+63 −0 python/mlx/nn/layers/convolution.py
+1 −0 python/src/CMakeLists.txt
+9 −1 python/src/array.cpp
+239 −42 python/src/ops.cpp
+1 −1 python/src/transforms.cpp
+80 −0 python/src/utils.cpp
+17 −54 python/src/utils.h
+7 −0 python/tests/test_array.py
+199 −2 python/tests/test_conv.py
+63 −0 python/tests/test_ops.py
+1 −1 setup.py
3 changes: 1 addition & 2 deletions install.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ if (packageJson.version === '0.0.1-dev')

const fs = require('node:fs');
const path = require('node:path');
const stream = require('node:stream');
const util = require('node:util');
const zlib = require('node:zlib');
const {pipeline} = require('node:stream/promises');

const urlPrefix = 'https://github.com/frost-beta/node-mlx/releases/download';

Expand Down Expand Up @@ -39,7 +39,6 @@ async function download(url, filename) {
if (!response.ok)
throw new Error(`Failed to download ${url}, status: ${response.status}`);

const pipeline = util.promisify(stream.pipeline);
const gunzip = zlib.createGunzip();
await pipeline(response.body, gunzip, fs.createWriteStream(filename));
}
1 change: 1 addition & 0 deletions lib/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ export namespace core {
function convolve(input: ScalarOrArray, weight: ScalarOrArray, mode?: string, s?: StreamOrDevice): array;
function conv1d(input: ScalarOrArray, weight: ScalarOrArray, stride: number, padding: number, dilation: number, groups: number, s?: StreamOrDevice): array;
function conv2d(input: ScalarOrArray, weight: ScalarOrArray, stride?: number | number[], padding?: number | number[], dilation?: number | number[], groups?: number, s?: StreamOrDevice): array;
function conv3d(input: ScalarOrArray, weight: ScalarOrArray, stride?: number | number[], padding?: number | number[], dilation?: number | number[], groups?: number, s?: StreamOrDevice): array;
function convGeneral(input: ScalarOrArray, weight?: ScalarOrArray, stride?: number | number[], padding?: number | number[] | [number[], number[]], kernelDilation?: number | number[], inputDilation?: number | number[], groups?: number, flip?: boolean, s?: StreamOrDevice): array;
function cos(array: ScalarOrArray, s?: StreamOrDevice): array;
function cosh(array: ScalarOrArray, s?: StreamOrDevice): array;
Expand Down
61 changes: 60 additions & 1 deletion lib/nn/layers/convolution.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ export class Conv2d extends Module {

override toStringExtra(): string {
return `${this.weight.shape[3]}, ${this.weight.shape[0]}, ` +
`kernelSize=${this.weight.shape.slice(1, 3)}, stride=${this.stride}, ` +
`kernelSize=${this.weight.shape.slice(1, 2)}, stride=${this.stride}, ` +
`padding=${this.padding}, dilation=${this.dilation}, ` +
`bias=${!!this.bias}`;
}
Expand All @@ -124,3 +124,62 @@ export class Conv2d extends Module {
return y;
}
}

/**
 * Applies a 3-dimensional convolution over the multi-channel input image.
 *
 * @remarks
 *
 * The channels are expected to be last i.e. the input shape should be `NDHWC`
 * where:
 * - `N` is the batch dimension
 * - `D` is the input image depth
 * - `H` is the input image height
 * - `W` is the input image width
 * - `C` is the number of input channels
 *
 * @param inChannels - The number of input channels.
 * @param outChannels - The number of output channels.
 * @param kernelSize - The size of the convolution filters.
 * @param stride - The size of the stride when applying the filter. Default: 1.
 * @param padding - How many positions to 0-pad the input with. Default: 0.
 * @param bias - If `true` add a learnable bias to the output. Default: `true`
 */
export class Conv3d extends Module {
  stride: number[];
  padding: number[];
  weight: mx.array;
  bias?: mx.array;

  constructor(inChannels: number,
              outChannels: number,
              kernelSize: number | number[],
              stride: number | number[] = [1, 1, 1],
              padding: number | number[] = [0, 0, 0],
              bias = true) {
    super();
    this.stride = Array.isArray(stride) ? stride : [stride, stride, stride];
    this.padding = Array.isArray(padding) ? padding : [padding, padding, padding];

    // A scalar kernel size applies to all three spatial dimensions (D, H, W).
    // Note: a 2-element expansion here would leave kernelSize[2] undefined,
    // producing a NaN scale and a 4-D (instead of 5-D) weight.
    kernelSize = Array.isArray(kernelSize) ? kernelSize
                                           : [kernelSize, kernelSize, kernelSize];
    // Standard uniform fan-in initialization; weight layout is
    // [outChannels, D, H, W, inChannels] to match mx.conv3d's NDHWC input.
    const scale = Math.sqrt(1 / (inChannels * kernelSize[0] * kernelSize[1] * kernelSize[2]));
    this.weight = mx.random.uniform(-scale, scale, [outChannels, ...kernelSize, inChannels]);
    if (bias) {
      this.bias = mx.zeros([outChannels]);
    }
  }

  override toStringExtra(): string {
    // Weight shape is [out, D, H, W, in]: inChannels lives at index 4 and the
    // kernel dimensions span indices 1..3 (slice end is exclusive, hence 4).
    return `${this.weight.shape[4]}, ${this.weight.shape[0]}, ` +
           `kernelSize=${this.weight.shape.slice(1, 4)}, stride=${this.stride}, ` +
           `padding=${this.padding}, bias=${!!this.bias}`;
  }

  override forward(x: mx.array): mx.array {
    const y = mx.conv3d(x, this.weight, this.stride, this.padding);
    if (this.bias)
      return mx.add(y, this.bias);
    else
      return y;
  }
}
20 changes: 9 additions & 11 deletions lib/nn/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ import {NestedDict} from '../utils';

/**
* Transform the passed function `func` to a function that computes the
* gradients of `func` with respect to the model's trainable parameters and also its
* value.
* gradients of `func` with respect to the model's trainable parameters and also
* its value.
*
* @param model The model whose trainable parameters to compute
* gradients for
* @param func The scalar function to compute gradients for
* @returns A callable that returns the value of `func` and the gradients with respect to the
* trainable parameters of `model`
* @param model The model whose trainable parameters to compute gradients for.
* @param func The scalar function to compute gradients for.
* @returns A callable that returns the value of `func` and the gradients with
* respect to the trainable parameters of `model`.
*/
export function valueAndGrad<T extends any[], U>(model: Module,
func: (...args: T) => U) {
Expand All @@ -35,12 +34,11 @@ export function valueAndGrad<T extends any[], U>(model: Module,
* with respect to the trainable parameters of the module (and the callable's
* inputs).
*
* @param mod The module for whose parameters to perform gradient
* checkpointing.
* @param mod The module for whose parameters to perform gradient checkpointing.
* @param func The function to checkpoint. If not provided, it defaults to the
* provided module.
* @returns A function that saves the inputs and outputs during the forward
* pass and recomputes all intermediate states during the backward pass.
* @returns A function that saves the inputs and outputs during the forward pass
* and recomputes all intermediate states during the backward pass.
*/
export function checkpoint<M extends Module, T extends any[]>(
mod: M,
Expand Down
1 change: 1 addition & 0 deletions src/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ void Type<mx::array>::Define(napi_env env,
"round", MemberFunction(&ops::Round),
"diagonal", MemberFunction(&ops::Diagonal),
"diag", MemberFunction(&ops::Diag),
"conj", MemberFunction(&mx::conjugate),
"index", MemberFunction(&Index),
"indexPut_", MemberFunction(&IndexPut),
SymbolIterator(env), MemberFunction(&CreateArrayIterator));
Expand Down
38 changes: 36 additions & 2 deletions src/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,38 @@ mx::array Conv2d(
else if (auto p = std::get_if<std::pair<int, int>>(&dilation); p)
dilation_pair = std::move(*p);

return conv2d(input, weight, stride_pair, padding_pair, dilation_pair,
groups.value_or(1), s);
return mx::conv2d(input, weight, stride_pair, padding_pair, dilation_pair,
groups.value_or(1), s);
}

namespace {

// Expands a stride/padding/dilation argument into a (d, h, w) triple:
//  - monostate (argument omitted) -> {fallback, fallback, fallback}
//  - a single int n               -> {n, n, n}
//  - an explicit triple           -> returned unchanged
std::tuple<int, int, int> ToTriple(
    const std::variant<std::monostate, int, std::tuple<int, int, int>>& value,
    int fallback) {
  if (const auto* i = std::get_if<int>(&value))
    return {*i, *i, *i};
  if (const auto* t = std::get_if<std::tuple<int, int, int>>(&value))
    return *t;
  return {fallback, fallback, fallback};
}

}  // namespace

// Binding for mx::conv3d: normalizes the optional scalar-or-triple stride,
// padding and dilation arguments before forwarding to MLX.
//
// input  - the input array (channels-last, NDHWC).
// weight - the filter weights.
// stride/padding/dilation - omitted, a single int broadcast to all three
//   spatial dimensions, or an explicit (d, h, w) triple. Defaults:
//   stride/dilation 1, padding 0.
// groups - number of feature groups; defaults to 1.
// s      - stream or device to run the operation on.
mx::array Conv3d(
    const mx::array& input,
    const mx::array& weight,
    std::variant<std::monostate, int, std::tuple<int, int, int>> stride,
    std::variant<std::monostate, int, std::tuple<int, int, int>> padding,
    std::variant<std::monostate, int, std::tuple<int, int, int>> dilation,
    std::optional<int> groups,
    mx::StreamOrDevice s) {
  return mx::conv3d(input, weight,
                    ToTriple(stride, 1),
                    ToTriple(padding, 0),
                    ToTriple(dilation, 1),
                    groups.value_or(1), s);
}

mx::array ConvGeneral(
Expand Down Expand Up @@ -611,6 +641,7 @@ void InitOps(napi_env env, napi_value exports) {
"arcsin", &mx::arcsin,
"arccos", &mx::arccos,
"arctan", &mx::arctan,
"arctan2", &mx::arctan2,
"sinh", &mx::sinh,
"cosh", &mx::cosh,
"tanh", &mx::tanh,
Expand Down Expand Up @@ -684,9 +715,12 @@ void InitOps(napi_env env, napi_value exports) {
"cumprod", CumOpWrapper(&mx::cumprod),
"cummax", CumOpWrapper(&mx::cummax),
"cummin", CumOpWrapper(&mx::cummin),
"conj", &mx::conjugate,
"conjugate", &mx::conjugate,
"convolve", &ops::Convolve,
"conv1d", &mx::conv1d,
"conv2d", &ops::Conv2d,
"conv3d", &ops::Conv3d,
"convGeneral", &ops::ConvGeneral,
"where", &mx::where,
"round", &ops::Round,
Expand Down
3 changes: 2 additions & 1 deletion src/transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ ValueAndGradImpl(const char* error_tag,
ki::ThrowError(js_func.Env(),
error_tag, " Can't compute the gradient of argument "
"index ", argnums.back(), " because the function is "
"called with only ", args.Length(), " arguments.");
"called with only ", args.Length(),
" positional arguments.");
return {nullptr, nullptr};
}
// Collect the arrays located at |argnums|.
Expand Down

0 comments on commit 766a7c2

Please sign in to comment.