Update to MLX 0.21.1
zcbenz committed Dec 7, 2024
1 parent 6ce3111 commit 29d038a
Showing 3 changed files with 110 additions and 82 deletions.
25 changes: 17 additions & 8 deletions lib/nn/layers/quantized.ts
@@ -4,6 +4,8 @@ import {Embedding} from './embedding';
import {Linear} from './linear';
import {Module} from './base';

export type ClassPredicate = (p: string, m: Module) => [number, number] | boolean;

/**
* Quantize the sub-modules of a module according to a predicate.
*
@@ -19,20 +21,27 @@ import {Module} from './base';
* @param groupSize - The quantization group size. Default: `64`.
* @param bits - The number of bits per parameter. Default: `4`.
* @param classPredicate - A function which receives the `Module` path and
* `Module` itself and returns `true` if it should be quantized and `false`
* otherwise. If `null`, then all layers that define a
* `toQuantized(group_size, bits)` method are quantized. The path is converted
* to snake_case for convenience. Default: `null`.
* `Module` itself and returns `true` or an array of arguments for `toQuantized`
* if it should be quantized and `false` otherwise. If `null`, then all layers
* that define a `toQuantized(group_size, bits)` method are quantized. The path
* is converted to snake_case for convenience.
*/
export function quantize(model: Module,
groupSize = 64,
bits = 4,
classPredicate = (p: string, m: Module) => 'toQuantized' in m && typeof m.toQuantized === 'function'): void {
classPredicate: ClassPredicate = (p: string, m: Module) => 'toQuantized' in m && typeof m.toQuantized === 'function'): void {
function maybeQuantize(path: string, m: Module): Module {
if (!classPredicate(toSnakeCase(path), m))
const boolOrArgs = classPredicate(toSnakeCase(path), m);
if (!boolOrArgs)
return m;
if ('toQuantized' in m && typeof m.toQuantized === 'function')
return m.toQuantized(groupSize, bits);
if ('toQuantized' in m && typeof m.toQuantized === 'function') {
if (typeof boolOrArgs == 'boolean')
return m.toQuantized(groupSize, bits);
else if (Array.isArray(boolOrArgs))
return m.toQuantized(...boolOrArgs);
else
throw Error('"class_predicate" must return a bool or an array of arguments to pass to "toQuantized"');
}
throw Error(`Unable to quantize model of type ${typeof m}`);
}

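For context, a minimal usage sketch of the updated `classPredicate` contract (not part of the commit; the `@frost-beta/mlx` import path, the `nn.quantize` re-export, and the toy model are assumptions):

```typescript
import {core as mx, nn} from '@frost-beta/mlx';  // import path assumed

// Hypothetical two-layer model; only its structure matters for the example.
class TinyModel extends nn.Module {
  embed = new nn.Embedding(1000, 64);
  out = new nn.Linear(64, 1000);

  forward(x: mx.array): mx.array {
    return this.out.forward(this.embed.forward(x));
  }
}

const model = new TinyModel();

// The predicate may now return a boolean (quantize with the default
// groupSize/bits) or an array of arguments forwarded to `toQuantized(...)`.
nn.quantize(model, 64, 4, (path, m) => {
  if (path === 'out')
    return [32, 8];           // per-layer override: group size 32, 8 bits
  return 'toQuantized' in m;  // default 64-group, 4-bit elsewhere
});
```

Returning a tuple lets a single layer override the global `groupSize`/`bits`, which is what the new `ClassPredicate` type encodes.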
165 changes: 92 additions & 73 deletions lib/optimizers/optimizers.ts
@@ -188,11 +188,11 @@ export abstract class Optimizer {
* w_{t+1} &= w_t - \lambda v_{t+1}
* ```
*
* @param learningRate The learning rate `\lambda`.
* @param momentum The momentum strength `\mu`. Default: ``0``
* @param weightDecay The weight decay (L2 penalty). Default: ``0``
* @param dampening Dampening for momentum `\tau`. Default: ``0``
* @param nesterov Enables Nesterov momentum. Default: ``False``
* @param learningRate - The learning rate `\lambda`.
* @param momentum - The momentum strength `\mu`. Default: `0`.
* @param weightDecay - The weight decay (L2 penalty). Default: `0`.
* @param dampening - Dampening for momentum `\tau`. Default: `0`.
* @param nesterov - Enables Nesterov momentum. Default: `False`.
*/
export class SGD extends Optimizer {
momentum: number;
@@ -262,16 +262,20 @@ export class SGD extends Optimizer {
/**
* The RMSprop optimizer.
*
* Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
* @remarks
*
* Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural
* networks for machine learning
*
* ```math
* v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
* w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
* ```
*
* @param learningRate The learning rate.
* @param alpha The smoothing constant `\alpha`. Default: ``0.99``
* @param eps The term `\epsilon` added to the denominator to improve numerical stability. Default: ``1e-8``
* @param learningRate - The learning rate.
* @param alpha - The smoothing constant `\alpha`. Default: `0.99`.
* @param eps - The term `\epsilon` added to the denominator to improve
* numerical stability. Default: `1e-8`.
*/
export class RMSprop extends Optimizer {
alpha: number;
@@ -338,9 +342,9 @@ export class RMSprop extends Optimizer {
* w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
* ```
*
* @param learningRate The learning rate `\lambda`.
* @param eps The term `\epsilon` added to the
* denominator to improve numerical stability. Default: ``1e-8``
* @param learningRate - The learning rate `\lambda`.
* @param eps - The term `\epsilon` added to the denominator to improve
* numerical stability. Default: `1e-8`.
*/
export class Adagrad extends Optimizer {
eps: number;
@@ -400,9 +404,11 @@ export class Adagrad extends Optimizer {
* w_{t+1} &= w_t - \lambda \Delta w_{t+1}
* ```
*
* @param learningRate The learning rate `\lambda`.
* @param rho The coefficient `\rho` used for computing a running average of squared gradients. Default: ``0.9``
* @param eps The term `\epsilon` added to the denominator to improve numerical stability.Default: `1e-8`
* @param learningRate - The learning rate `\lambda`.
* @param rho - The coefficient `\rho` used for computing a running average of
* squared gradients. Default: `0.9`.
* @param eps - The term `\epsilon` added to the denominator to improve
* numerical stability. Default: `1e-8`.
*/
export class AdaDelta extends Optimizer {
rho: number;
@@ -467,36 +473,37 @@ export class AdaDelta extends Optimizer {
*
* @remarks
*
* Our Adam implementation follows the original paper and omits the bias
* correction in the first and second moment estimates. In detail,
*
* Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization.
* ICLR 2015.
* In detail,
*
* ```math
* m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
* v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
* w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
* ```
*
* @param learningRate The learning rate `\lambda`.
* @param betas The coefficients `(\beta_1, \beta_2)` used for computing running
* averages of the gradient and its square. Default: ``(0.9, 0.999)``
* @param eps The term `\epsilon` added to the
* denominator to improve numerical stability. Default: ``1e-8``
* @param learningRate - The learning rate `\lambda`.
* @param betas - The coefficients `(\beta_1, \beta_2)` used for computing
* running averages of the gradient and its square. Default: `[0.9, 0.999]`
* @param eps - The term `\epsilon` added to the denominator to improve
* numerical stability. Default: ``1e-8``
* @param biasCorrection - If set to `True`, bias correction is applied.
* Default: `False`.
*/
export class Adam extends Optimizer {
betas: number[];
eps: number;
biasCorrection: boolean;

constructor(learningRate: number | ((step: mx.array) => mx.array),
betas: number[] = [0.9, 0.999],
eps: number = 1e-8) {
betas: [number, number] = [0.9, 0.999],
eps = 1e-8,
biasCorrection = false) {
super();

this.maybeSchedule('learningRate', learningRate);
this.betas = betas;
this.eps = eps;
this.biasCorrection = biasCorrection;
}

/**
@@ -518,6 +525,8 @@ export class Adam extends Optimizer {
const lr = this.learningRate.astype(gradient.dtype);
const [b1, b2] = this.betas;
const eps = this.eps;
const biasCorrection = this.biasCorrection;
const step = this.step;

const m = mx.add(mx.multiply(b1, state['m']),
mx.multiply(1 - b1, gradient));
@@ -529,9 +538,22 @@ export class Adam extends Optimizer {
state['m'] = m;
state['v'] = v;

return mx.subtract(parameter,
mx.divide(mx.multiply(lr, m),
mx.add(mx.sqrt(v), eps)));
if (biasCorrection) {
const numerator = mx.multiply(mx.divide(lr,
mx.subtract(1,
mx.power(b1, step))),
m);
const denominator = mx.add(mx.divide(mx.sqrt(v),
mx.sqrt(mx.subtract(1,
mx.power(b2, step)))),
eps);
return mx.subtract(parameter,
mx.divide(numerator, denominator));
} else {
return mx.subtract(parameter,
mx.divide(mx.multiply(lr, m),
mx.add(mx.sqrt(v), eps)));
}
}
}
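Written out, the bias-corrected branch added above applies the standard corrected Adam step (this restatement of the nested `mx` calls is not part of the commit):

```math
w_{t+1} = w_t - \frac{\lambda}{1 - \beta_1^t} \cdot
          \frac{m_{t+1}}{\frac{\sqrt{v_{t+1}}}{\sqrt{1 - \beta_2^t}} + \epsilon}
```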

@@ -540,35 +562,32 @@ export class Adam extends Optimizer {
*
* @remarks
*
* Following the above convention, in contrast with [1], we do not use bias
* correction in the first and second moments for AdamW. We update the weights
* with a weight_decay (`\lambda`) value:
*
* [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
* regularization. ICLR 2019.
* We update the weights with a weight_decay (:math:`\lambda`) value:
*
* ```math
* m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
* v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
* w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
* ```
*
* @param learningRate The learning rate `\alpha`.
* @param betas The coefficients `(\beta_1, \beta_2)` used for computing running
* averages of the gradient and its square. Default: ``(0.9, 0.999)``
* @param eps The term `\epsilon` added to the
* denominator to improve numerical stability. Default: ``1e-8``
* @param weightDecay The weight decay `\lambda`.
* Default: ``0``.
* @param learningRate - The learning rate `\alpha`.
* @param betas - The coefficients `(\beta_1, \beta_2)` used for computing
* running averages of the gradient and its square. Default: `[0.9, 0.999]`.
* @param eps - The term `\epsilon` added to the denominator to improve
* numerical stability. Default: `1e-8`.
* @param weightDecay - The weight decay `\lambda`. Default: `0`.
* @param biasCorrection - If set to `True`, bias correction is applied.
* Default: `False`.
*/
export class AdamW extends Adam {
weightDecay: number;

constructor(learningRate: number | ((step: mx.array) => mx.array),
betas: number[] = [0.9, 0.999],
eps: number = 1e-8,
weightDecay: number = 0.01) {
super(learningRate, betas, eps);
betas: [number, number] = [0.9, 0.999],
eps = 1e-8,
weightDecay = 0.01,
biasCorrection = false) {
super(learningRate, betas, eps, biasCorrection);
this.weightDecay = weightDecay;
}
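A hedged sketch of opting into the new `biasCorrection` flag through `AdamW` (illustrative only; the import path and the `applyGradients` call mirror MLX's Python optimizer API and are assumptions here):

```typescript
import {core as mx, optimizers as optim} from '@frost-beta/mlx';  // import path assumed

// biasCorrection defaults to false, so existing callers keep the old behavior;
// passing true opts into the corrected update sketched earlier.
const opt = new optim.AdamW(
    1e-4,          // learningRate
    [0.9, 0.999],  // betas (now typed as a [number, number] tuple)
    1e-8,          // eps
    0.01,          // weightDecay
    true);         // biasCorrection

// One illustrative step on a flat parameter/gradient dict; applyGradients is
// assumed to return the updated parameters, as in MLX's Python API.
const params = {w: mx.zeros([4, 4])};
const grads = {w: mx.ones([4, 4])};
const updated = opt.applyGradients(grads, params);
```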

@@ -604,16 +623,16 @@ export class AdamW extends Adam {
* w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
* ```
*
* @param learningRate The learning rate `\lambda`.
* @param betas The coefficients `(\beta_1, \beta_2)` used for computing running
* averages of the gradient and its square. Default: ``(0.9, 0.999)``
* @param eps The term `\epsilon` added to the
* denominator to improve numerical stability. Default: ``1e-8``
* @param learningRate - The learning rate `\lambda`.
* @param betas - The coefficients `(\beta_1, \beta_2)` used for computing
* running averages of the gradient and its square. Default: `[0.9, 0.999]`.
* @param eps - The term `\epsilon` added to the denominator to improve
* numerical stability. Default: `1e-8`.
*/
export class Adamax extends Adam {
constructor(learningRate: number | ((step: mx.array) => mx.array),
betas: number[] = [0.9, 0.999],
eps: number = 1e-8) {
betas: [number, number] = [0.9, 0.999],
eps = 1e-8) {
if (eps < 0)
throw new Error(`Epsilon value should be >=0, ${eps} was provided instead`);
super(learningRate, betas, eps);
@@ -677,16 +696,16 @@ export class Adamax extends Adam {
*
* @param learningRate - The learning rate `\eta`.
* @param betas - The coefficients `(\beta_1, \beta_2)` used for computing the
* gradient momentum and update direction. Default: ``(0.9, 0.99)``
* @param weightDecay - The weight decay `\lambda`. Default: ``0.0``
* gradient momentum and update direction. Default: `[0.9, 0.99]`.
* @param weightDecay - The weight decay `\lambda`. Default: `0.0`.
*/
export class Lion extends Optimizer {
betas: number[];
weightDecay: number;

constructor(learningRate: number | ((step: mx.array) => mx.array),
betas: number[] = [0.9, 0.99],
weightDecay: number = 0) {
betas: [number, number] = [0.9, 0.99],
weightDecay = 0) {
super();

this.maybeSchedule('learningRate', learningRate);
@@ -736,18 +755,18 @@ export class Lion extends Optimizer {
* Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
* https://arxiv.org/abs/1804.04235
*
* @param learningRate The learning rate.
* @param eps The first term `\epsilon_1` added to the square of the gradients
* @param learningRate - The learning rate.
* @param eps - The first term `\epsilon_1` added to the square of the gradients
* to improve numerical stability and the second term `\epsilon_2` is used for
* parameter scaling if ``parameterScale`` is set to ``True``. Default:
* ``(1e-30, 1e-3)``.
* @param clipThreshold Clips the unscaled update at `clipThreshold`. Default:
* parameter scaling if `parameterScale` is set to `True`. Default: `[1e-30,
* 1e-3]`.
* @param clipThreshold - Clips the unscaled update at `clipThreshold`. Default:
* `1.0`.
* @param decayRate Coefficient for the running average of the squared gradient.
* Default: `-0.8`.
* @param beta1 If set to a value bigger than zero then first moment will be
* @param decayRate - Coefficient for the running average of the squared
* gradient. Default: `-0.8`.
* @param beta1 - If set to a value bigger than zero then first moment will be
* used. Default: `None`.
* @param weightDecay The weight decay `\lambda`. Default: `0.0`.
* @param weightDecay - The weight decay `\lambda`. Default: `0.0`.
* @param scaleParameter If set to `True` the learning rate will be scaled by
* `\max(\epsilon_1, \text{RMS}(w_{t-1}))`. Default: `True`.
* @param relativeStep If set to `True` the `learningRate` will be ignored and
Expand All @@ -770,10 +789,10 @@ export class Adafactor extends Optimizer {
clipThreshold: number = 1.0,
decayRate: number = -0.8,
beta1: number | null = null,
weightDecay: number = 0.0,
scaleParameter: boolean = true,
relativeStep: boolean = true,
warmupInit: boolean = false) {
weightDecay = 0.0,
scaleParameter = true,
relativeStep = true,
warmupInit = false) {
super();
if (learningRate !== null)
this.maybeSchedule('learningRate', learningRate);
@@ -925,8 +944,8 @@ export class Adafactor extends Optimizer {
* // {"w1": mx.array([...]), "w2": mx.array([...])}
* ```
*
* @param grads A dictionary containing the gradient arrays.
* @param maxNorm The maximum allowed global norm of the gradients.
* @param grads - A dictionary containing the gradient arrays.
* @param maxNorm - The maximum allowed global norm of the gradients.
* @returns The possibly rescaled gradients and the original gradient norm.
*/
export function clipGradNorm(grads: Nested<mx.array>,

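Finally, a small sketch of using `clipGradNorm` before an optimizer update (illustrative only; the import path and the `[gradients, norm]` return shape follow the doc comment above and MLX's Python API, and are assumptions here):

```typescript
import {core as mx, optimizers as optim} from '@frost-beta/mlx';  // import path assumed

const grads = {
  w1: mx.multiply(mx.ones([4]), 10),
  w2: mx.multiply(mx.ones([4]), 10),
};

// Rescale so the global L2 norm of all gradients does not exceed 2.0;
// the second element is the norm measured before rescaling (assumed order).
const [clipped, totalNorm] = optim.clipGradNorm(grads, 2.0);
console.log(totalNorm.item());
```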