Skip to content

Commit

Permalink
-1 is not a valid index
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed May 22, 2024
1 parent b23bfdb commit 342b092
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/nn/layers/transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ export class MultiHeadAttention extends Module {
values = values.reshape(B, S, numHeads, -1).transpose(0, 2, 1, 3);

// Dimensions are [batch x numHeads x sequence x hiddenDim].
const scale = mx.array(Math.sqrt(1 / queries.shape[-1]), queries.dtype);
const scale = Math.sqrt(1 / queries.shape[queries.shape.length - 1]);
let scores = mx.matmul(mx.multiply(queries, scale), keys);
if (mask)
scores = mx.add(scores, mask.astype(scores.dtype));
Expand Down Expand Up @@ -417,4 +417,4 @@ export class Transformer extends Module {
const memory = this.encoder.forward(src, srcMask);
return this.decoder.forward(tgt, memory, tgtMask, memoryMask);
}
}
}

0 comments on commit 342b092

Please sign in to comment.