title | author | published-on | date
---|---|---|---
Gather Multiaxis Operator | Dwayne Robinson | 2025-02-19 | 2025-02-19
"One operator to gather them all, to bring them together, and in the darkness bind them."
```webidl
partial interface MLGraphBuilder
{
    MLOperand gatherMultiaxis(MLOperand input, MLOperand indices, sequence<unsigned long> axes);
};
```
ML libraries have a confusing mix of gather/scatter operators, and it always takes me a few minutes to recall the brain-bending differences between every `gather*` variant out there, even just in ONNX (`Gather`, `GatherElements`, `GatherND`), let alone other ML libraries. Many are underdocumented too (e.g. TOSA gather and StableHLO gather). Is there a more fundamental expression of the gathering operation, one that is more generic while also being simpler to document and implement?
It always bothered me after implementing `DML_OPERATOR_GATHER`, `DML_OPERATOR_GATHER_ELEMENTS`, and `DML_OPERATOR_GATHER_ND` (for the corresponding ONNX operators) that there wasn't a more elegant DML operator to encompass them all at the API level, because at the GPU implementation level, every operator used the same shader after normalizing the tensor ranks/strides to be rank-compatible and broadcastable (which made the implementation much simpler and reusable). So after some massaging...

1. Set `input` and `indices` tensor ranks consistently, padding with 1's where needed.
2. Pass `axes` explicitly (like how `reduce*` and `resample` take axes) instead of letting them be partially inferred from shapes.
3. Use existing broadcasting definitions, like those from elementwise operators.

...you then have one operator that can implement each of them, and you don't need to re-remember the divergences between them, nor need hacks like an extra `batch_dims` parameter.
Gather operators can be grouped as follows:
Category | Library Names | Notes |
---|---|---|
Single axis element gather | ONNX `GatherElements`<br>PyTorch `gather`<br>PyTorch `take_along_dim`<br>`numpy.take_along_axis`<br>CoreML `gather_along_axis` | All tensors have the same rank, and all dimensions in input and indices have the same size except the active axis.<br>`input.shape = [leading dimensions..., input axis dimension, trailing dimensions...]`<br>`indices.shape = [leading dimensions..., output axis dimension, trailing dimensions...]`<br>`output.shape = [leading dimensions..., output axis dimension, trailing dimensions...]`<br>`output.shape[axis] == indices.shape[axis]`<br>`input.rank == indices.rank == output.rank`<br>`axis` ranges from 0 to `input.rank - 1`<br>`indices[∀ coordinates] < input.shape[axis]` |
Single axis element gather (1D input, absolute indices) | PyTorch `take` | Same as above, but the input tensor is flattened to 1D first, meaning each index linearly addresses one input element.<br>`input.shape = [indexable dimension]` (flattened to 1D)<br>`indices.shape = [index dimensions...]`<br>`output.shape = [index dimensions...]`<br>`input.rank == 1` (after flattening to 1D)<br>`output.rank == indices.rank`<br>`output.shape == indices.shape`<br>`axis` is implicitly 0<br>`indices[∀ coordinates] < input.shape[axis]` |
Single axis block gather | ONNX `Gather`<br>`numpy.take`<br>TensorFlow `gather`<br>CoreML `gather` | Dimensions are selected from a given input axis, and any trailing dimensions copy entire blocks to the output (as if those dimensions in indices were broadcast to the input shape). The indices can insert additional dimensions into the output.<br>`input.shape = [leading dimensions..., input axis dimension, trailing dimensions...]`<br>`indices.shape = [index dimensions...]`<br>`output.shape = [leading dimensions..., index dimensions..., trailing dimensions...]`<br>`output.shape = input.shape[0..axis] ~ indices.shape ~ input.shape[axis+1..input.rank]`<br>`axis` ranges from 0 to `input.rank - 1`<br>`indices[∀ coordinates] < input.shape[axis]` |
Multiple contiguous axes block gather | ONNX `GatherND`<br>TensorFlow `gather_nd`<br>CoreML `gather_nd` | Axes are indirectly implied by the correspondence of input and indices shapes, the batch dimension count, and the size of the last dimension in indices (the lookup coordinate size). Axes start at dimension 0 in the input (or after the batch dimensions if the count is nonzero), and the number of indexable input dimensions depends on the coordinate size.<br>`input.shape = [batch dimensions..., indexable dimensions..., trailing dimensions...]`<br>`indices.shape = [batch dimensions..., index dimensions..., coordinate size]`<br>`output.shape = [batch dimensions..., index dimensions..., trailing dimensions...]`<br>`batch dimension count < min(input.rank, indices.rank)`<br>`axes = [batch dimension count, ..., batch dimension count + coordinate size - 1]`<br>`indices[batch and index coordinates..., ∀ i] < input.shape[batch count + ∀ i]` |
Multiaxis gather | None known, but emulatable via reshape + transpose + gatherND | Multiple noncontiguous axes are supported to gather from the input.<br>`input.shape = [mix of indexable and broadcastable dimensions...]`<br>`indices.shape = [mix of index and broadcastable dimensions...]`<br>`output.shape = [mix of indexed and broadcast dimensions...]`<br>`input.rank == indices.rank == output.rank`<br>`broadcastShape = broadcast(input.shape, indices.shape)`<br>`broadcastShape[axes[∀i]] = indices.shape[axes[∀i]]`<br>`output.shape = broadcastShape`<br>`axes = [unique axis, ...]` where each `axis < input.rank` |
Indeterminate from documentation 🤷‍♂️ | TOSA linalg gather<br>TOSA tensor gather<br>StableHLO gather | TOSA's gather is probably equivalent to one of the above, but the docs lack insight. StableHLO's gather looks quite complex, like some hybrid slice/gather chimera 😯 - it's out of scope. |
They have the following properties:
Property | GatherElements | Gather (blocks) | GatherND (blocks) | Gather Multiaxis |
---|---|---|---|---|
Multiple axes | ❌ | ❌ | ✅ | ✅ |
Non-contiguous axes (like N and C in NHWC layout) | ❌ | ❌ | ❌ | ✅ |
Custom coordinate ordering (like [x,y] or [y,x]) | ❌ | ❌ | ❌ | ✅ |
Supports broadcasting input to indices before axes | ❌ | ❌ | ✅ | ✅ |
Supports broadcasting indices to input before axes¹ | ❌ | ✅ | ❌ | ✅ |
Supports broadcasting indices to input after axes | ❌ | ✅ | ✅ | ✅ |
Supports trailing broadcasting (after axes) | ❌ | ✅ | ✅ | ✅ |
Trivial implementation² | ✅ | ❌ | ❌ | ✅ |
- ¹ Unsure if it's supposed to support broadcasting or not, but ONNX Runtime (ORT) appears to.
- ² Trivial implementations reduce the chances of bugs.
```webidl
partial interface MLGraphBuilder
{
    MLOperand gatherMultiaxis(MLOperand input, MLOperand indices, sequence<unsigned long> axes);
};
```

```javascript
const output = graphBuilder.gatherMultiaxis(input, indices, axes);
```
- **Consistent rank**: `input.rank == indices.rank == output.rank`.
- **Broadcastability**: The `input` shape and `indices` logical shape may differ for any dimensions in `axes`, but they must be bidirectionally broadcastable for all other dimensions.
- **Logical indices shape**: The `indices` tensor has an actual shape and a logical shape, as the last dimension of `indices` must be a multiple of the `axes` length since the coordinates are folded into that dimension. For 1D cases (`axes.length == 1`), the logical and actual shapes are identical anyway, and this is directly equivalent to {ONNX `GatherElements`, `take_along_axis`, `gather_along_axis`}. For 2D/3D/ND cases (`axes.length > 1`), the logical shape has the last dimension divided by the `axes` length. e.g. An `indices` logical shape `[2,4]` with 3D coordinates would have an actual shape `[2,4*3]`.
- **Output shape**: The output tensor shape takes the dimensions of the `indices` logical shape that are in `axes`, and any other dimensions are taken from `input` broadcast with `indices`.
- **Axes**: Axes need not be contiguous or strictly in order, unlike ONNX `Gather` and `GatherND`. For example, a 2D coordinate [1,2] could be mapped to the y=1 and x=2 dimensions respectively, or mapped to the (more common in graphics) x=1 and y=2 dimensions. `GatherND` cannot accommodate this without either transposes on the input or reversal of the indices.
Given the following...
Parameter | Value |
---|---|
input shape | [4,2,1,2] |
indices shape | [1,3,2,2] |
axes | [1] (1D coordinates) |
output shape | [4,3,2,2] |
...the `input` shape `[4,_,1,2]` is broadcastable with the `indices` shape `[1,_,2,2]` to form `[4,_,2,2]` (note the ignorable `_`'s in place of the dimensions in `axes`, since those are replaced later anyway). Then axis 1 is taken from the `indices` shape `[_,3,_,_]` to form the final `output` shape `[4,3,2,2]`.
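To make the shape rules concrete, here is a small sketch of that output shape computation (`computeGatherMultiaxisOutputShape` is a hypothetical helper, not part of the proposal, mirroring the `broadcastShapeWith` logic in the implementation further below):

```javascript
// Hypothetical helper illustrating output shape computation for gatherMultiaxis.
// Assumes shapes are plain arrays of equal length and axes are already validated.
function computeGatherMultiaxisOutputShape(inputShape, logicalIndicesShape, axes)
{
    // Bidirectionally broadcast: dimensions of size 1 in the input stretch to the indices dimension.
    const outputShape = inputShape.map((size, dim) => (size == 1) ? logicalIndicesShape[dim] : size);
    // Dimensions listed in axes come directly from the indices logical shape.
    axes.forEach((axis) => outputShape[axis] = logicalIndicesShape[axis]);
    return outputShape;
}

console.log(computeGatherMultiaxisOutputShape([4,2,1,2], [1,3,2,2], [1])); // [4,3,2,2]
```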
For multiple axes (2D/3D/ND coordinates where `axes.length > 1`), shape computation is more complex since each coordinate consumes multiple index elements. There are a few possible approaches:
- **indices.rank = input.rank + 1**: Append another trailing dimension onto the indices. So an `indices` logical shape `[2,4]` with 3D coordinates yields an actual shape `[2,4,3]`.
- **indices.rank = input.rank, folding the last dimension**: Fold the coordinate size into the last dimension. So an `indices` logical shape `[2,4]` with 3D coordinates yields an actual shape `[2,4*3]` (the last dimension is always an exact multiple of `axes.length`).
Consideration | indices.rank = input.rank + 1 | indices.rank = input.rank, folded last dimension |
---|---|---|
Enables you to change the lengths of non-axes dimensions. | ✅ | ✅ |
Avoids rank limitation issues. | ❌ | ✅ |
Implementation correspondence between tensor coordinates is straightforward. | ✅ | ✅ (one extra multiply) |
No approach is ideal, but the latter enables you to change existing dimension sizes, enables full rank usage up to backend limits (in other words, a `gatherElements` that already worked up to 5D in an implementation could be implemented by this operator), has a simple rank consistency validation rule (`input.rank == indices.rank == output.rank`), and is fairly easy for the implementation when mapping coordinates between output/indices/input.
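Since the folded layout stores elements in the same linear order as the `input.rank + 1` layout, converting between them is a pure reshape. A minimal sketch (`foldIndicesShape` is a hypothetical helper name):

```javascript
// Hypothetical helper: fold the trailing coordinate dimension of an unfolded indices shape
// [...index dimensions, coordinateSize] into the last index dimension, preserving element order.
function foldIndicesShape(unfoldedShape)
{
    const coordinateSize = unfoldedShape.at(-1);
    const foldedShape = unfoldedShape.slice(0, -1);
    foldedShape[foldedShape.length - 1] *= coordinateSize;
    return foldedShape;
}

console.log(foldIndicesShape([2, 4, 3])); // [2, 12] - the actual shape for logical [2,4] with 3D coordinates.
```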
```javascript
import {Tensor} from './tensor.js'; // https://fdwr.github.io/MachineLearningOperators/tensor.js

// Implementation of a multiaxis gather, which can satisfy ONNX Gather, GatherElements, and GatherND.
function gatherMultiaxis(/*Tensor*/ input, /*Tensor*/ indices, /*Array*/ axes)
{
    const coordinateSize = Math.max(axes.length, 1);
    console.assert(input.rank == indices.rank);
    console.assert(axes.length <= input.rank); // There can't be more axes than input dimensions.
    console.assert(indices.size == 0 || input.size != 0); // Input cannot be empty if indices are given.
    console.assert(axes.every((value) => value < input.rank)); // Ensure valid axes.
    console.assert(axes.every((value, index) => axes.indexOf(value) == index)); // Ensure uniqueness.
    console.assert(indices.rank == 0 || indices.shape.at(-1) % coordinateSize == 0);

    // Bail out early for scalar case to simplify later logic.
    if (indices.rank == 0)
    {
        return new Tensor(indices.shape, input.data);
    }

    // Compute output shape, and create output tensor, taking the dimensions from indices
    // that are in axes and bidirectionally broadcasting input with logical indices dimensions.
    //
    // e.g. Given input shape [2,3,1], indices shape [1,1,4], and axes [1],
    //      first broadcast the input shape to intermediate shape [2,3,4],
    //      then take axes [1] from indices shape [_,1,_] to yield output shape [2,1,4].
    // e.g. Given input shape [10,9,8], indices shape [1,5,6*2], and axes = [1,2],
    //      first broadcast the input with logical indices shape to make [10,_,_],
    //      then take axes [1,2] from logical shape [_,5,6] to yield [10,5,6].
    //
    let logicalIndicesShape = [...indices.shape];
    logicalIndicesShape[logicalIndicesShape.length - 1] /= coordinateSize;
    let outputShape = broadcastShapeWith(input.shape, logicalIndicesShape);
    axes.forEach((axis) => outputShape[axis] = logicalIndicesShape[axis]);
    let output = new Tensor(outputShape);

    // Create broadcasting masks to avoid creating large broadcasted temporaries
    // that would waste memory and time.
    const inputCoordinateMask = makeBroadcastingMask(input.shape);
    const indicesCoordinateMask = makeBroadcastingMask(logicalIndicesShape);

    for (let outputCoordinate of output.coordinates)
    {
        // Determine the corresponding input coordinate to read from given the
        // current output coordinate and indices tensor coordinate. Achieve in-place
        // broadcasting of the input and/or indices by masking the output coordinate.
        let inputCoordinate = getMaskedCoordinate(outputCoordinate, inputCoordinateMask);
        let indexCoordinate = getMaskedCoordinate(outputCoordinate, indicesCoordinateMask);
        indexCoordinate[indexCoordinate.length - 1] *= coordinateSize;
        for (const axis of axes)
        {
            inputCoordinate[axis] = indices.at(indexCoordinate);
            indexCoordinate[indexCoordinate.length - 1]++;
        }
        const inputValue = input.at(inputCoordinate);
        output.setAt(outputCoordinate, inputValue);
    }
    return output;
}

// Bidirectional broadcasting between two shapes.
// Any dimensions of size 1 in the first tensor will be broadcast to the dimension of the other tensor.
// e.g. first [1,3,1] with second [2,1,4] yields shape [2,3,4].
// If two corresponding coordinates are > 1, the first one wins (not an error). e.g. [2,1] and [3,4] yield [2,4].
function broadcastShapeWith(first, second)
{
    console.assert(first.length == second.length);
    return first.map((value, index) => (value == 1) ? second[index] : value);
}

// Mask any dimensions of length 1 to 0, so that when coordinates are masked via getMaskedCoordinate,
// they have no contribution to the element location, allowing trivial broadcasting.
// e.g. shape [3,1,4] yields a mask [1,0,1].
function makeBroadcastingMask(/*Array*/ shape)
{
    return shape.map((value) => value > 1 ? 1 : 0);
}

// Apply the mask to the coordinate.
// e.g. coordinate [1,2,3] with mask [1,0,1] yields coordinate [1,0,3].
function getMaskedCoordinate(/*Array*/ coordinate, /*Array*/ mask)
{
    console.assert(coordinate.length == mask.length);
    return coordinate.map((value, index) => value * mask[index]);
}
```
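A quick usage sketch, assuming the `Tensor` class from tensor.js accepts a shape and flat data array (as the constructor calls above suggest). The values match the first `GatherElements` example further below:

```javascript
const exampleInput = new Tensor([4, 3], [0,1,2, 10,11,12, 20,21,22, 30,31,32]);
const exampleIndices = new Tensor([2, 3], [3,1,1, 2,0,3]);
const exampleOutput = gatherMultiaxis(exampleInput, exampleIndices, [0]);
console.log(exampleOutput.shape); // [2,3]
console.log(exampleOutput.data);  // [30,11,12, 20,1,32]
```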
- ONNX GatherElements
- PyTorch gather
- PyTorch take_along_dim
- numpy.take_along_axis
- CoreML gather_along_axis
```javascript
// GatherElements is directly compatible with multiaxis gather (just wrap the axis).
function gatherSingleAxisElements(input, indices, axis)
{
    return gatherMultiaxis(input, indices, [axis]);
}
```
```
input of shape [4,3]:
[[ 0,  1,  2],
 [10, 11, 12],
 [20, 21, 22],
 [30, 31, 32]]

indices of shape [2,3]:
[[3, 1, 1],
 [2, 0, 3]]

axis = 0 (default)

output of shape [2,3]:
[[30, 11, 12],
 [20,  1, 32]]
```

```
input of shape [4,3]:
[[ 0,  1,  2],
 [10, 11, 12],
 [20, 21, 22],
 [30, 31, 32]]

indices of shape [4,1]:
[[2],
 [1],
 [0],
 [2]]

axis = 1

output of shape [4,1]:
[[ 2],
 [11],
 [20],
 [32]]
```

```
input of shape [4,2,2]:
[[[  0,   1],
  [ 10,  11]],
 [[100, 101],
  [110, 111]],
 [[200, 201],
  [210, 211]],
 [[300, 301],
  [310, 311]]]

indices of shape [1,2,2]:
[[[0, 2],
  [1, 3]]]

axis = 0

output of shape [1,2,2]:
[[[  0, 201],
  [110, 311]]]
```
PyTorch `take` is 1D only, reinterpreting the input as 1D:

```javascript
// Take gathers from the input as if it were flattened to 1D.
function gatherForced1D(input, indices)
{
    // Reshape input to 1D, then extend rank with size-1 dimensions to be broadcast compatible after axis 0.
    const inputShapeFiller = new Array(Math.max(indices.rank - 1, 0)).fill(1);
    const newInputShape = [input.size, ...inputShapeFiller];
    const inputReshaped = input.asShape(newInputShape);
    const axes = [0]; // The only gatherable axis after flattening to 1D.
    return gatherMultiaxis(inputReshaped, indices, axes);
}
```
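A usage sketch under the same `Tensor` assumptions, gathering absolute indices from the flattened input:

```javascript
const takeInput = new Tensor([2, 2], [1, 2, 3, 4]);
const takeIndices = new Tensor([2, 2], [0, 3, 2, 1]);
const takeOutput = gatherForced1D(takeInput, takeIndices);
console.log(takeOutput.shape); // [2,2]
console.log(takeOutput.data);  // [1, 4, 3, 2]
```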
```javascript
// Translates ONNX Gather operator into a generic multiaxis gather by reshaping tensors.
// https://onnx.ai/onnx/operators/onnx__Gather.html
//
// data shape | indices shape  | axis | output shape | output equation                          | reshaped input/indices/output
// ---------- | -------------- | ---- | ------------ | ---------------------------------------- | --------------------------------------
// (P, Q)     | ( ) (a scalar) | 0    | (Q)          | output[q] = data[indices, q]              | (P, Q)       (1, 1)       (1, Q)
// (P, Q, R)  | ( ) (a scalar) | 1    | (P, R)       | output[p, r] = data[p, indices, r]        | (P, Q, R)    (1, 1, 1)    (P, 1, R)
// (P, Q, R)  | (S)            | 1    | (P, S, R)    | output[p, s, r] = data[p, indices[s], r]  | (P, Q, 1, R) (1, 1, S, 1) (P, 1, S, R)
// (P, Q)     | (R, S)         | 0    | (R, S, Q)    | output[r, s, q] = data[indices[r, s], q]  | (P, 1, 1, Q) (1, R, S, 1) (1, R, S, Q)
// (P, Q)     | (R, S)         | 1    | (P, R, S)    | output[p, r, s] = data[p, indices[r, s]]  | (P, Q, 1, 1) (1, 1, R, S) (P, 1, R, S)
//
function gatherBlocks(input, indices, axis)
{
    console.assert(input.rank > 0);
    console.assert(axis >= 0);
    console.assert(axis < input.rank);

    // Split the shapes for input and output into leading/trailing parts.
    //
    // input shape [ leading input dimensions..., axis dimension, trailing non-indexed input dimensions...]
    //             <---------------leading part----------------> <------------ trailing part ----------->
    //                                                          |
    //                                             alignment filler inserted here
    //
    // indices shape [ index dimensions... ]
    //              |                       |
    // leading filler inserted here         trailing filler inserted here
    //
    const leadingInputShape = input.shape.slice(0, axis + 1);
    const trailingInputShape = input.shape.slice(axis + 1);
    const leadingOutputShape = input.shape.slice(0, axis);
    const trailingOutputShape = trailingInputShape;
    const inputShapeFiller = new Array(indices.rank).fill(1);
    const leadingIndicesFiller = new Array(axis + 1).fill(1);
    const trailingIndicesFiller = new Array(input.rank - axis - 1).fill(1);

    // Compute normalized shapes for input and indices, then the output shape, which starts with the input shape,
    // removes the dimension of the active axis from input, and substitutes it with the indices shape.
    //
    // output.shape = input.shape[0..axis] ~ indices.shape ~ input.shape[axis + 1..input.rank]
    //
    const newInputShape   = [...leadingInputShape, ...inputShapeFiller, ...trailingInputShape];
    const newIndicesShape = [...leadingIndicesFiller, ...indices.shape, ...trailingIndicesFiller];
    const outputShape     = [...leadingOutputShape, ...indices.shape, ...trailingOutputShape];
    const inputReshaped = input.asShape(newInputShape);
    const indicesReshaped = indices.asShape(newIndicesShape);
    const axes = [axis];
    const output = gatherMultiaxis(inputReshaped, indicesReshaped, axes);
    return output.asShape(outputShape);
}
```
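A usage sketch reproducing the first example below (axis 0, gathering whole rows):

```javascript
const blockInput = new Tensor([4, 3], [0,1,2, 10,11,12, 20,21,22, 30,31,32]);
const blockIndices = new Tensor([2], [3, 1]);
const blockOutput = gatherBlocks(blockInput, blockIndices, 0);
console.log(blockOutput.shape); // [2,3]
console.log(blockOutput.data);  // [30,31,32, 10,11,12]
```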
```
input of shape [4,3]:
[[ 0,  1,  2],
 [10, 11, 12],
 [20, 21, 22],
 [30, 31, 32]]

axis = 0 (default)

indices of shape [2]:
[3,1]

output of shape [2,3]:
[[30, 31, 32],
 [10, 11, 12]]

intermediate processing values:
input shape   = [4,3]
indices shape = [2,1]
output shape  = [2,3]
axes = [0]
```

```
input of shape [4,3]:
[[ 0,  1,  2],
 [10, 11, 12],
 [20, 21, 22],
 [30, 31, 32]]

axis = 1

indices of shape [5]:
[2,1,0,1,2]

output of shape [4,5]:
[[ 2,  1,  0,  1,  2],
 [12, 11, 10, 11, 12],
 [22, 21, 20, 21, 22],
 [32, 31, 30, 31, 32]]

intermediate processing values:
input shape   = [4,3]
indices shape = [1,5]
output shape  = [4,5]
axes = [1]
```

```
input of shape [4,3]:
[[ 0,  1,  2],
 [10, 11, 12],
 [20, 21, 22],
 [30, 31, 32]]

axis = 1

indices of shape [2,2]:
[[0, 1],
 [1, 2]]

output of shape [4,2,2]:
[[[ 0,  1], [ 1,  2]],
 [[10, 11], [11, 12]],
 [[20, 21], [21, 22]],
 [[30, 31], [31, 32]]]

intermediate processing values:
input shape   = [4,1,3]
indices shape = [1,2,2]
output shape  = [4,2,2]
axes = [2]
```
```
input:
[[1 2]
 [3 4]]

indices:
1

axis = 0

output:
[3 4]

intermediate processing values:
input shape   = [2,2]
indices shape = [1,1]
output shape  = [1,2]
axes = [0]
```
```
input:
[[1 2]
 [3 4]]

indices:
[1 0]

axis = 0

output:
[[3 4]
 [1 2]]

intermediate processing values:
input shape   = [2,2]
indices shape = [2,1]
output shape  = [2,2]
axes = [0]
```

```
input:
[[1 2]
 [3 4]]

indices:
[[1 0]
 [0 1]]

axis = 0

output:
[[[3 4]
  [1 2]]
 [[1 2]
  [3 4]]]

intermediate processing values:
input shape   = [2,1,2]
indices shape = [2,2,1]
output shape  = [2,2,2]
axes = [0]
```
- ONNX GatherND
- TensorFlow gather_nd
- CoreML gather_nd
- PyTorch gather_nd (missing; there are open feature requests for it)
```javascript
// Maps GatherND to multiaxis gather by reshaping tensors.
// https://onnx.ai/onnx/operators/onnx__GatherND.html
//
// Example 1
//
// batch_dims = 0
// data    = [[0,1],[2,3]]                  # data_shape    = [2, 2]
// indices = [[0,0],[1,1]]                  # indices_shape = [2, 2]
// output  = [0,3]                          # output_shape  = [2]
//
// Example 2
//
// batch_dims = 0
// data    = [[0,1],[2,3]]                  # data_shape    = [2, 2]
// indices = [[1],[0]]                      # indices_shape = [2, 1]
// output  = [[2,3],[0,1]]                  # output_shape  = [2, 2]
//
// Example 3
//
// batch_dims = 0
// data    = [[[0,1],[2,3]],[[4,5],[6,7]]]  # data_shape    = [2, 2, 2]
// indices = [[0,1],[1,0]]                  # indices_shape = [2, 2]
// output  = [[2,3],[4,5]]                  # output_shape  = [2, 2]
//
// Example 4
//
// batch_dims = 0
// data    = [[[0,1],[2,3]],[[4,5],[6,7]]]  # data_shape    = [2, 2, 2]
// indices = [[[0,1]],[[1,0]]]              # indices_shape = [2, 1, 2]
// output  = [[[2,3]],[[4,5]]]              # output_shape  = [2, 1, 2]
//
// Example 5
//
// batch_dims = 1
// data    = [[[0,1],[2,3]],[[4,5],[6,7]]]  # data_shape    = [2, 2, 2]
// indices = [[1],[0]]                      # indices_shape = [2, 1]
// output  = [[2,3],[4,5]]                  # output_shape  = [2, 2]
//
function gatherNdBlocks(input, indices, batchDimensionCount)
{
    console.assert(input.rank > 0);
    console.assert(indices.rank > 0);
    console.assert(batchDimensionCount >= 0);
    console.assert(batchDimensionCount <= input.rank);
    console.assert(batchDimensionCount <= indices.rank);
    console.assert(indices.shape.at(-1) >= 0); // 0 is allowed.
    console.assert(batchDimensionCount + indices.shape.at(-1) <= input.rank); // Can't touch elements that don't exist.

    const coordinateSize = indices.shape.at(-1);
    const axes = Array.from({length: coordinateSize}, (_, i) => batchDimensionCount + i); // Sequence. e.g. 0,1,2

    // Split the input shape into leading/trailing parts.
    //
    // input shape [ batch dimensions..., indexable input dimensions..., non-indexed input dimensions...]
    //             <-----------------leading part---------------------> <------- trailing part ------->
    //
    // Split the indices shape into leading/trailing parts (the trailing part is chopped off and multiplied back later
    // as coordinateSize).
    //
    // indices shape [ batch dimensions..., index dimensions..., coordinate tuple size dimension]
    //               <--------------leading part--------------> <--------- trailing part ------>
    //
    const leadingInputShape = input.shape.slice(0, batchDimensionCount + coordinateSize); // Exclude trailing non-indexed dimensions.
    const trailingInputShape = input.shape.slice(batchDimensionCount + coordinateSize); // Exclude batch and indexed dimensions.
    const leadingIndicesShape = indices.shape.slice(0, -1); // Exclude last dimension.

    // Align the shape fragments, so that the indexable input dimensions and the index dimensions are consistent.
    // Insert filler as needed to make them correspondent and broadcastable to the output.
    //
    // input shape   [ batch, indexable dimensions, <--- filler here --->, non-indexable dimensions]
    // indices shape [ batch, indices dimensions,   <---------------- filler here ---------------->]
    //
    const maxLeadingShapingLength = Math.max(leadingInputShape.length, leadingIndicesShape.length);
    const newShapeLength = maxLeadingShapingLength + trailingInputShape.length;
    const inputShapeFiller = new Array(maxLeadingShapingLength - leadingInputShape.length).fill(1);
    const indicesShapeFiller = new Array(newShapeLength - leadingIndicesShape.length).fill(1);

    // The output shape consists of any leading batch dimensions from the indices, intermediate indices dimensions
    // (excluding the last dimension which is the coordinate size), and any residual input dimensions after consuming
    // the number of coordinates from the leading side.
    //
    // output.shape = [ batch dimensions..., indices dimensions..., residual input dimensions...]
    const newInputShape = [...leadingInputShape, ...inputShapeFiller, ...trailingInputShape];
    let newIndicesShape = [...leadingIndicesShape, ...indicesShapeFiller /* folded below */];
    const outputShape = [...leadingIndicesShape, /* no filler */ ...trailingInputShape];
    newIndicesShape[newIndicesShape.length - 1] *= coordinateSize; // Fold the coordinates into the last dimension.
    const inputReshaped = input.asShape(newInputShape);
    const indicesReshaped = indices.asShape(newIndicesShape);
    const output = gatherMultiaxis(inputReshaped, indicesReshaped, axes);
    return output.asShape(outputShape);
}
```
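A usage sketch reproducing ONNX GatherND example 2 above (each index selects a whole row):

```javascript
const ndInput = new Tensor([2, 2], [0, 1, 2, 3]);
const ndIndices = new Tensor([2, 1], [1, 0]);
const ndOutput = gatherNdBlocks(ndInput, ndIndices, 0);
console.log(ndOutput.shape); // [2,2]
console.log(ndOutput.data);  // [2, 3, 0, 1]
```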
```
input of shape [2,2,2]:
[[[0, 1],
  [2, 3]],
 [[4, 5],
  [6, 7]]]

indices of shape [2,2]:
[[0, 1],
 [1, 0]]

output of shape [2,2]:
[[2, 3],  <= row [2, 3] from input coordinates [0, 1, *]
 [4, 5]]  <= row [4, 5] from input coordinates [1, 0, *]

intermediate processing values:
input shape   = [2,2,2]
indices shape = [1,2,1*2]
output shape  = [1,2,2]
axes = [0,1]
```

```
input of shape [2,2,2]:
[[[0, 1],
  [2, 3]],
 [[4, 5],
  [6, 7]]]

indices of shape [3,1]:
[[1],
 [0],
 [1]]

output of shape [3,2,2]:
[[[4, 5],   <= block [[4, 5], [6, 7]] from input coordinates [1, *, *]
  [6, 7]],
 [[0, 1],   <= block [[0, 1], [2, 3]] from input coordinates [0, *, *]
  [2, 3]],
 [[4, 5],   <= block [[4, 5], [6, 7]] from input coordinates [1, *, *]
  [6, 7]]]

intermediate processing values:
input shape   = [2,2,2]
indices shape = [3,1,1]
output shape  = [3,2,2]
axes = [0]
```

```
batch_dims = 1

input of shape [2,2,2]:
[[[0,1],[2,3]],[[4,5],[6,7]]]

indices of shape [2,1]:
[[1],[0]]

output of shape [2,2]:
[[2,3],[4,5]]

intermediate processing values:
input shape   = [2,2,2]
indices shape = [2,1,1]
output shape  = [2,1,2]
axes = [1]
```

```
input of shape [2,2,2]:
[[[0,1],[2,3]],[[4,5],[6,7]]]

indices of shape [5,3]:
[[0,0,1],
 [0,1,0],
 [1,0,0],
 [1,1,0],
 [1,1,1]]

output of shape [5]:
[1,2,4,6,7]

intermediate processing values:
input shape   = [2,2,2]
indices shape = [5,1,1*3]
output shape  = [5,1,1]
axes = [0,1,2]
```
This example demonstrates broadcasting of the leading batch input dimension (size 1 in the input, not 2):

```
batch_dims = 1

input of shape [1,3]:
[[0,1,2]]

indices of shape [2,1]:
[[1],[2]]

output of shape [2]:
[1,2]
```
- TOSA linalg gather
- TOSA tensor gather (possibly GatherND?)
- StableHLO gather
The documentation does not enlighten. 🤷‍♂️
Test code:

```python
# numpy==1.24.3
import numpy as np

def gather(data, indices, axis=0, mode='gather'):
    print("gather", mode)
    print("data:\n", data)
    print("indices:\n", indices)
    print("output:")
    if mode == 'gather':
        return np.take(data, indices, axis=axis)
    elif mode == 'gather_elements':
        result = np.zeros_like(indices)
        for idx, val in np.ndenumerate(indices):
            result[idx] = data[tuple(idx[:axis] + (val,) + idx[axis+1:])]
        return result
    elif mode == 'gather_nd':
        return data[tuple(indices.T)]
    else:
        raise ValueError("Unsupported mode. Use 'gather', 'gather_elements', or 'gather_nd'.")
    #endif
#enddef

# Example usage:
data = np.array([[1, 2], [3, 4]])
indices = np.array(1)
print(gather(data, indices, axis=0, mode='gather'), '\n')  # Output: [3 4]

data = np.array([[1, 2], [3, 4]])
indices = np.array([1, 0])
print(gather(data, indices, axis=0, mode='gather'), '\n')  # Output: [[3 4], [1 2]]

data = np.array([[1, 2], [3, 4]])
indices = np.array([[1, 0], [0, 1]])
print(gather(data, indices, axis=0, mode='gather'), '\n')  # Output: [[[3 4], [1 2]], [[1 2], [3 4]]]

data = np.array([[ 0, 1, 2], [10, 11, 12], [20, 21, 22], [30, 31, 32]])
indices = np.array([2,1,0,1,2])
print(gather(data, indices, axis=1, mode='gather'), '\n')
# Output: [[ 2  1  0  1  2], [12 11 10 11 12], [22 21 20 21 22], [32 31 30 31 32]]

data = np.array([[1, 2], [3, 4]])  # Reset data so the outputs below match the comments.
indices = np.array([[0, 1], [1, 0]])
print(gather(data, indices, axis=1, mode='gather_elements'), '\n')  # Output: [[1, 2], [4, 3]]

indices = np.array([[0, 0], [1, 1]])
print(gather(data, indices, mode='gather_nd'), '\n')  # Output: [1, 4]
```