Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add dilate and erode operations to Tensor #1136

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
330 changes: 328 additions & 2 deletions src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ const DataTypeMap = Object.freeze({
});

/**
* @typedef {keyof typeof DataTypeMap} DataType
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
* @typedef {keyof typeof DataTypeMap} DataType A Tensor data type, for example `uint8`.
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray A typed array or an array of values.
* @typedef {{x: number, y: number}} Point An object representing coordinates.
* @typedef {'RECT' | 'CROSS' | 'ELLIPSE'} Shape A shape for morphological operations.
* @typedef {{width: number, height: number}} Size An object representing the size of an object.
*
* @typedef {number | [number, number] | Size} KernelSize A kernel size for morphological operations.
*/


Expand Down Expand Up @@ -789,6 +794,189 @@ export class Tensor {
return new Tensor('int64', [BigInt(index)], []);
}

/**
* Mutates the data through a dilation morphological operation.
*
* @param {KernelSize} kernelSize The width and height of the kernel.
* @param {Shape} [shape='RECT'] The shape of the kernel.
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
* @returns {Promise<Tensor>} Returns `this`.
*/
async dilate_(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) {
const this_data = this.data;
const data = await this.morphologicalOperation('DILATE', this_data, kernelSize, shape, anchor);
for (let i = 0; i < this_data.length; ++i) {
this.data[i] = data[i];
}
return this;
}

/**
* Returns a new Tensor where the data is mutated through a dilation
* morphological operation.
*
* @param {KernelSize} kernelSize The width and height of the kernel.
* @param {Shape} [shape='RECT'] The shape of the kernel.
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
* @returns {Promise<Tensor>} The new Tensor.
*/
async dilate(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) {
return this.clone().dilate_(kernelSize, shape, anchor);
}

/**
* * Mutates the data through a erosion morphological operation.
*
* @param {KernelSize} kernelSize The width and height of the kernel.
* @param {Shape} [shape='RECT'] The shape of the kernel.
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
* @returns {Promise<Tensor>} Returns `this`.
*/
async erode_(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) {
const this_data = this.data;
const data = await this.morphologicalOperation('ERODE', this_data, kernelSize, shape, anchor);
for (let i = 0; i < this_data.length; ++i) {
this.data[i] = data[i];
}
return this;
}

/**
* Returns a new Tensor where the data is mutated through a erosion
* morphological operation.
*
* @param {KernelSize} kernelSize The width and height of the kernel.
* @param {Shape} [shape='RECT'] The shape of the kernel.
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
* @returns {Promise<Tensor>} The new Tensor.
*/
async erode(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) {
return this.clone().erode_(kernelSize, shape, anchor);
}

/**
* Applies a morphological operation to this tensor.
*
* @param {'DILATE' | 'ERODE'} operation The operation to apply.
* @param {DataArray} data The input tensor data.
* @param {KernelSize} kernelSize The width and height of the kernel.
* @param {Shape} [shape='RECT'] The shape of the kernel.
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
* @returns {Promise<DataArray>} The cloned, modified output tensor.
*/
async morphologicalOperation(operation, data, kernelSize, shape = 'RECT', anchor = { x: -1, y: -1 }) {
kernelSize = validateKernel(kernelSize);
// We don't need to perform the operation if the kernel is empty.
if (kernelSize.width * kernelSize.height === 1) {
return;
}

anchor = normalizeAnchor(anchor, kernelSize);
let kernel = getStructuringElement(shape, kernelSize, anchor);

const [batches, rows, cols] = this.dims;
const paddingSize = { width: Math.floor(kernelSize.width / 2), height: Math.floor(kernelSize.height / 2) };
const outputData = new Float32Array(this.data.length);
const operationFunction = (operationType => {
switch (operationType) {
case 'DILATE':
return Math.max;
case 'ERODE':
return Math.min;
default:
throw new Error(`Unknown operation: ${operationType}`);
}
})(operation);

const processChunk = async chunk => {
for (const { batchIndex, rowIndex, colIndex } of chunk) {
const kernelValues = [];

// Collect values in the kernel window.
for (let kernelRowOffset = -paddingSize.height; kernelRowOffset <= paddingSize.height; kernelRowOffset++) {
for (let kernelColOffset = -paddingSize.width; kernelColOffset <= paddingSize.width; kernelColOffset++) {
const neighborRowIndex = rowIndex + kernelRowOffset;
const neighborColIndex = colIndex + kernelColOffset;
if (neighborRowIndex >= 0 && neighborRowIndex < rows && neighborColIndex >= 0 && neighborColIndex < cols) {
const neighborIndex = (batchIndex * rows * cols) + neighborRowIndex * cols + neighborColIndex;
// Only include values where the kernel has a value
// of 1.
// Rather than multiply against this value, we use
// the if check to reduce the size of the array.
const kernelValue = kernel[kernelRowOffset + paddingSize.height][kernelColOffset + paddingSize.width];
if (kernelValue === 1) {
kernelValues.push(data[neighborIndex] * kernelValue);
}
Comment on lines +902 to +909
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally I was not using the if statement, and utilising the fact that a kernel contains either 0 or 1, however that meant that kernelValues array would be bigger than it needs to be.

Also, erode wouldn't work properly because 0 is always the minimum value.

}
}
}

// Apply operation function to the values.
const outputIndex = batchIndex * rows * cols + rowIndex * cols + colIndex;
outputData[outputIndex] = operationFunction(...kernelValues);
}
};

// Divide work into chunks for parallel processing.
const chunks = [];
const chunkSize = Math.ceil((batches * rows * cols) / (navigator.hardwareConcurrency || 4));
let currentChunk = [];

for (let rowIndex = 0; rowIndex < rows; rowIndex++) {
for (let colIndex = 0; colIndex < cols; colIndex++) {
for (let batchIndex = 0; batchIndex < batches; batchIndex++) {
currentChunk.push({ batchIndex, rowIndex, colIndex });
// Store the chunk now that it is the right size.
if (currentChunk.length >= chunkSize) {
chunks.push([...currentChunk]);
currentChunk = [];
}
}
}
}
// Get any elements that may not fit neatly in the defined chunk size.
if (currentChunk.length > 0) {
chunks.push(currentChunk);
}

// Process all chunks in parallel.
await Promise.all(chunks.map(chunk => processChunk(chunk)));

return outputData;
}

/**
* Performs a morphological operation on the input image.
*
* @param {'ERODE' | 'DILATE' | 'OPEN' | 'CLOSE'} operation
* @param {KernelSize} kernelSize The width and height of the kernel.
* @param {Shape} [shape='RECT'] The shape of the kernel.
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
* @returns {Promise<Tensor>} The cloned, modified output tensor.
*/
async morph(operation, kernelSize, shape = 'RECT', anchor = { x: -1, y: -1 }) {
switch (operation) {
case 'ERODE':
return this.erode(kernelSize, shape, anchor);

case 'DILATE':
return this.dilate(kernelSize, shape, anchor);

case 'OPEN':
return (await this
.erode_(kernelSize, shape, anchor))
.dilate_(kernelSize, shape, anchor);

case 'CLOSE':
return (await this
.dilate_(kernelSize, shape, anchor))
.erode_(kernelSize, shape, anchor);

default:
throw new Error("Unknown morphological operation");
}
}

/**
* Performs Tensor dtype conversion.
* @param {DataType} type The desired data type.
Expand Down Expand Up @@ -1542,3 +1730,141 @@ export function quantize_embeddings(tensor, precision) {

return new Tensor(dtype, outputData, [tensor.dims[0], tensor.dims[1] / 8]);
}

/**
* Ensure that an anchor lies within the kernel size.
* Passing in a `-1` will center the anchor.
*
* @param {Point} anchor The input anchor point.
* @param {Size} kernelSize The width and height of the kernel.
* @returns {Point} The normalized anchor point.
*/
function normalizeAnchor(anchor, kernelSize) {
// Centralize the x coordinate.
if (anchor.x === -1) {
anchor.x = Math.floor(kernelSize.width / 2);
}
// Centralize the y coordinate.
if (anchor.y === -1) {
anchor.y = Math.floor(kernelSize.height / 2);
}
// Check if the anchor is within the kernel size.
if (anchor.x < 0 || anchor.x >= kernelSize.width ||
anchor.y < 0 || anchor.y >= kernelSize.height
) {
throw new Error("Anchor is out of bounds for the given kernel size.");
}
return anchor;
}

/**
* Creates a Size object that represents a kernel.
* Performs some validation on the kernel size.
*
* @param {KernelSize} kernelSize The size of the kernel.
* @returns {Size} An object representing the kernel width and height.
* @throws {Error} If the kernel size is invalid.
* @throws {Error} If kernel size is even.
*/
function validateKernel(kernelSize) {
let kernel;
if (typeof kernelSize === 'object' && 'width' in kernelSize && 'height' in kernelSize) {
// This is a Size object, so no conversion required.
kernel = kernelSize;
} else if (typeof kernelSize === 'number' && Number.isInteger(kernelSize)) {
// A single whole number is assumed as the width and height.
kernel = { width: kernelSize, height: kernelSize };
} else if (Array.isArray(kernelSize) && kernelSize.length === 2 && kernelSize.every(Number.isInteger)) {
// An array of two values is assumed as width then height.
kernel = { width: kernelSize[0], height: kernelSize[1] };
} else {
throw new Error("Invalid kernel size.");
}

if (kernel.width % 2 === 0 || kernel.height % 2 === 0) {
throw new Error("Kernel size must be odd");
}

return kernel;
}
Comment on lines +1769 to +1789
Copy link
Contributor Author

@BritishWerewolf BritishWerewolf Jan 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this function is to help users with passing in kernel values.

To get a 3x3 kernel to the erode function we can do this:

  • tensor.erode(3)
  • tensor.erode([3, 3])
  • tensor.erode({ width: 3, height: 3 })

Most likely we want a symmetrical kernel, and it's easier to just pass in a single number.


/**
* Creates a structuring element for morphological operations.
*
* This function is a JavaScript translation of the [OpenCV C++ function of the same name](https://github.com/egonSchiele/OpenCV/blob/master/modules/imgproc/src/morph.cpp#L981).
*
* @param {Shape} shape The shape of the kernel.
* @param {Size} kernelSize The width and height of the kernel.
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel.
* @returns {Array<Array<number>>} The structuring element as a 2D array.
* @throws {Error} If the shape, or kernel size, is invalid is invalid.
*/
function getStructuringElement(shape, kernelSize, anchor = { x: -1, y: -1 }) {
if (!['RECT', 'CROSS', 'ELLIPSE'].includes(shape)) {
throw new Error("Invalid shape. Must be 'RECT', 'CROSS', or 'ELLIPSE'.");
}

// Get a kernel object that represents the kernel width and height.
let kernel = validateKernel(kernelSize);

// Normalize anchor to default to the center if not specified.
anchor = normalizeAnchor(anchor, kernel);

// If the kernel size is 1x1, treat as a rectangle.
if (kernel.width === 1 && kernel.height === 1) {
shape = 'RECT';
}

let rowRadius = 0; // Radius along the height.
let colRadius = 0; // Radius along the width.
let inverseRowRadiusSquared = 0; // Inverse squared radius for ellipses.

if (shape === 'ELLIPSE') {
// Calculate radii and inverse squared radius for the ellipse equation.
rowRadius = Math.floor(kernel.height / 2);
colRadius = Math.floor(kernel.width / 2);
inverseRowRadiusSquared = rowRadius > 0 ? 1 / (rowRadius * rowRadius) : 0;
}

// Create a 2D array to represent the kernel.
const kernelArray = Array.from({ length: kernel.height }, () => Array(kernel.width).fill(0));

for (let row = 0; row < kernel.height; row++) {
let startColumn = 0;
let endColumn = 0;

if (shape === 'RECT' || (shape === 'CROSS' && row === anchor.y)) {
// Full width for rectangle or horizontal line for cross shape.
endColumn = kernel.width;
} else if (shape === 'CROSS') {
// Single column for cross shape.
// A cross will be a single row and column, so only add 1.
startColumn = anchor.x;
endColumn = startColumn + 1;
} else if (shape === 'ELLIPSE') {
// Calculate elliptical bounds for this row.

// Distance from the anchor row.
const verticalOffset = row - anchor.y;

if (Math.abs(verticalOffset) <= rowRadius) {
// Solve for horizontal bounds using the ellipse equation: x^2/a^2 + y^2/b^2 = 1
const horizontalRadius = Math.floor(
colRadius * Math.sqrt(Math.max(0, (rowRadius * rowRadius) - (verticalOffset * verticalOffset)) * inverseRowRadiusSquared)
);

// Left and right bound of the ellipse.
// Add 1 to endColumn because it's not inclusive in the for loop.
startColumn = Math.max(anchor.x - horizontalRadius, 0);
endColumn = Math.min(anchor.x + horizontalRadius + 1, kernel.width);
}
}

// Fill the kernel row with 1s within the range (startColumn, endColumn).
for (let col = startColumn; col < endColumn; col++) {
kernelArray[row][col] = 1;
}
}

return kernelArray;
}