-
Notifications
You must be signed in to change notification settings - Fork 813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] Add dilate and erode operations to Tensor #1136
Open
BritishWerewolf
wants to merge
3
commits into
huggingface:main
Choose a base branch
from
BritishWerewolf:add-dilate-erode
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,8 +39,13 @@ const DataTypeMap = Object.freeze({ | |
}); | ||
|
||
/** | ||
* @typedef {keyof typeof DataTypeMap} DataType | ||
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray | ||
* @typedef {keyof typeof DataTypeMap} DataType A Tensor data type, for example `uint8`. | ||
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray A typed array or an array of values. | ||
* @typedef {{x: number, y: number}} Point An object representing coordinates. | ||
* @typedef {'RECT' | 'CROSS' | 'ELLIPSE'} Shape A shape for morphological operations. | ||
* @typedef {{width: number, height: number}} Size An object representing the size of an object. | ||
* | ||
* @typedef {number | [number, number] | Size} KernelSize A kernel size for morphological operations. | ||
*/ | ||
|
||
|
||
|
@@ -789,6 +794,189 @@ export class Tensor { | |
return new Tensor('int64', [BigInt(index)], []); | ||
} | ||
|
||
/** | ||
* Mutates the data through a dilation morphological operation. | ||
* | ||
* @param {KernelSize} kernelSize The width and height of the kernel. | ||
* @param {Shape} [shape='RECT'] The shape of the kernel. | ||
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. | ||
* @returns {Promise<Tensor>} Returns `this`. | ||
*/ | ||
async dilate_(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { | ||
const this_data = this.data; | ||
const data = await this.morphologicalOperation('DILATE', this_data, kernelSize, shape, anchor); | ||
for (let i = 0; i < this_data.length; ++i) { | ||
this.data[i] = data[i]; | ||
} | ||
return this; | ||
} | ||
|
||
/** | ||
* Returns a new Tensor where the data is mutated through a dilation | ||
* morphological operation. | ||
* | ||
* @param {KernelSize} kernelSize The width and height of the kernel. | ||
* @param {Shape} [shape='RECT'] The shape of the kernel. | ||
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. | ||
* @returns {Promise<Tensor>} The new Tensor. | ||
*/ | ||
async dilate(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { | ||
return this.clone().dilate_(kernelSize, shape, anchor); | ||
} | ||
|
||
/** | ||
* * Mutates the data through a erosion morphological operation. | ||
* | ||
* @param {KernelSize} kernelSize The width and height of the kernel. | ||
* @param {Shape} [shape='RECT'] The shape of the kernel. | ||
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. | ||
* @returns {Promise<Tensor>} Returns `this`. | ||
*/ | ||
async erode_(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { | ||
const this_data = this.data; | ||
const data = await this.morphologicalOperation('ERODE', this_data, kernelSize, shape, anchor); | ||
for (let i = 0; i < this_data.length; ++i) { | ||
this.data[i] = data[i]; | ||
} | ||
return this; | ||
} | ||
|
||
/** | ||
* Returns a new Tensor where the data is mutated through a erosion | ||
* morphological operation. | ||
* | ||
* @param {KernelSize} kernelSize The width and height of the kernel. | ||
* @param {Shape} [shape='RECT'] The shape of the kernel. | ||
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. | ||
* @returns {Promise<Tensor>} The new Tensor. | ||
*/ | ||
async erode(kernelSize = 3, shape = 'RECT', anchor = { x: -1, y: -1 }) { | ||
return this.clone().erode_(kernelSize, shape, anchor); | ||
} | ||
|
||
/** | ||
* Applies a morphological operation to this tensor. | ||
* | ||
* @param {'DILATE' | 'ERODE'} operation The operation to apply. | ||
* @param {DataArray} data The input tensor data. | ||
* @param {KernelSize} kernelSize The width and height of the kernel. | ||
* @param {Shape} [shape='RECT'] The shape of the kernel. | ||
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. | ||
* @returns {Promise<DataArray>} The cloned, modified output tensor. | ||
*/ | ||
async morphologicalOperation(operation, data, kernelSize, shape = 'RECT', anchor = { x: -1, y: -1 }) { | ||
kernelSize = validateKernel(kernelSize); | ||
// We don't need to perform the operation if the kernel is empty. | ||
if (kernelSize.width * kernelSize.height === 1) { | ||
return; | ||
} | ||
|
||
anchor = normalizeAnchor(anchor, kernelSize); | ||
let kernel = getStructuringElement(shape, kernelSize, anchor); | ||
|
||
const [batches, rows, cols] = this.dims; | ||
const paddingSize = { width: Math.floor(kernelSize.width / 2), height: Math.floor(kernelSize.height / 2) }; | ||
const outputData = new Float32Array(this.data.length); | ||
const operationFunction = (operationType => { | ||
switch (operationType) { | ||
case 'DILATE': | ||
return Math.max; | ||
case 'ERODE': | ||
return Math.min; | ||
default: | ||
throw new Error(`Unknown operation: ${operationType}`); | ||
} | ||
})(operation); | ||
|
||
const processChunk = async chunk => { | ||
for (const { batchIndex, rowIndex, colIndex } of chunk) { | ||
const kernelValues = []; | ||
|
||
// Collect values in the kernel window. | ||
for (let kernelRowOffset = -paddingSize.height; kernelRowOffset <= paddingSize.height; kernelRowOffset++) { | ||
for (let kernelColOffset = -paddingSize.width; kernelColOffset <= paddingSize.width; kernelColOffset++) { | ||
const neighborRowIndex = rowIndex + kernelRowOffset; | ||
const neighborColIndex = colIndex + kernelColOffset; | ||
if (neighborRowIndex >= 0 && neighborRowIndex < rows && neighborColIndex >= 0 && neighborColIndex < cols) { | ||
const neighborIndex = (batchIndex * rows * cols) + neighborRowIndex * cols + neighborColIndex; | ||
// Only include values where the kernel has a value | ||
// of 1. | ||
// Rather than multiply against this value, we use | ||
// the if check to reduce the size of the array. | ||
const kernelValue = kernel[kernelRowOffset + paddingSize.height][kernelColOffset + paddingSize.width]; | ||
if (kernelValue === 1) { | ||
kernelValues.push(data[neighborIndex] * kernelValue); | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Apply operation function to the values. | ||
const outputIndex = batchIndex * rows * cols + rowIndex * cols + colIndex; | ||
outputData[outputIndex] = operationFunction(...kernelValues); | ||
} | ||
}; | ||
|
||
// Divide work into chunks for parallel processing. | ||
const chunks = []; | ||
const chunkSize = Math.ceil((batches * rows * cols) / (navigator.hardwareConcurrency || 4)); | ||
let currentChunk = []; | ||
|
||
for (let rowIndex = 0; rowIndex < rows; rowIndex++) { | ||
for (let colIndex = 0; colIndex < cols; colIndex++) { | ||
for (let batchIndex = 0; batchIndex < batches; batchIndex++) { | ||
currentChunk.push({ batchIndex, rowIndex, colIndex }); | ||
// Store the chunk now that it is the right size. | ||
if (currentChunk.length >= chunkSize) { | ||
chunks.push([...currentChunk]); | ||
currentChunk = []; | ||
} | ||
} | ||
} | ||
} | ||
// Get any elements that may not fit neatly in the defined chunk size. | ||
if (currentChunk.length > 0) { | ||
chunks.push(currentChunk); | ||
} | ||
|
||
// Process all chunks in parallel. | ||
await Promise.all(chunks.map(chunk => processChunk(chunk))); | ||
|
||
return outputData; | ||
} | ||
|
||
/** | ||
* Performs a morphological operation on the input image. | ||
* | ||
* @param {'ERODE' | 'DILATE' | 'OPEN' | 'CLOSE'} operation | ||
* @param {KernelSize} kernelSize The width and height of the kernel. | ||
* @param {Shape} [shape='RECT'] The shape of the kernel. | ||
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. | ||
* @returns {Promise<Tensor>} The cloned, modified output tensor. | ||
*/ | ||
async morph(operation, kernelSize, shape = 'RECT', anchor = { x: -1, y: -1 }) { | ||
switch (operation) { | ||
case 'ERODE': | ||
return this.erode(kernelSize, shape, anchor); | ||
|
||
case 'DILATE': | ||
return this.dilate(kernelSize, shape, anchor); | ||
|
||
case 'OPEN': | ||
return (await this | ||
.erode_(kernelSize, shape, anchor)) | ||
.dilate_(kernelSize, shape, anchor); | ||
|
||
case 'CLOSE': | ||
return (await this | ||
.dilate_(kernelSize, shape, anchor)) | ||
.erode_(kernelSize, shape, anchor); | ||
|
||
default: | ||
throw new Error("Unknown morphological operation"); | ||
} | ||
} | ||
|
||
/** | ||
* Performs Tensor dtype conversion. | ||
* @param {DataType} type The desired data type. | ||
|
@@ -1542,3 +1730,141 @@ export function quantize_embeddings(tensor, precision) { | |
|
||
return new Tensor(dtype, outputData, [tensor.dims[0], tensor.dims[1] / 8]); | ||
} | ||
|
||
/** | ||
* Ensure that an anchor lies within the kernel size. | ||
* Passing in a `-1` will center the anchor. | ||
* | ||
* @param {Point} anchor The input anchor point. | ||
* @param {Size} kernelSize The width and height of the kernel. | ||
* @returns {Point} The normalized anchor point. | ||
*/ | ||
function normalizeAnchor(anchor, kernelSize) { | ||
// Centralize the x coordinate. | ||
if (anchor.x === -1) { | ||
anchor.x = Math.floor(kernelSize.width / 2); | ||
} | ||
// Centralize the y coordinate. | ||
if (anchor.y === -1) { | ||
anchor.y = Math.floor(kernelSize.height / 2); | ||
} | ||
// Check if the anchor is within the kernel size. | ||
if (anchor.x < 0 || anchor.x >= kernelSize.width || | ||
anchor.y < 0 || anchor.y >= kernelSize.height | ||
) { | ||
throw new Error("Anchor is out of bounds for the given kernel size."); | ||
} | ||
return anchor; | ||
} | ||
|
||
/** | ||
* Creates a Size object that represents a kernel. | ||
* Performs some validation on the kernel size. | ||
* | ||
* @param {KernelSize} kernelSize The size of the kernel. | ||
* @returns {Size} An object representing the kernel width and height. | ||
* @throws {Error} If the kernel size is invalid. | ||
* @throws {Error} If kernel size is even. | ||
*/ | ||
function validateKernel(kernelSize) { | ||
let kernel; | ||
if (typeof kernelSize === 'object' && 'width' in kernelSize && 'height' in kernelSize) { | ||
// This is a Size object, so no conversion required. | ||
kernel = kernelSize; | ||
} else if (typeof kernelSize === 'number' && Number.isInteger(kernelSize)) { | ||
// A single whole number is assumed as the width and height. | ||
kernel = { width: kernelSize, height: kernelSize }; | ||
} else if (Array.isArray(kernelSize) && kernelSize.length === 2 && kernelSize.every(Number.isInteger)) { | ||
// An array of two values is assumed as width then height. | ||
kernel = { width: kernelSize[0], height: kernelSize[1] }; | ||
} else { | ||
throw new Error("Invalid kernel size."); | ||
} | ||
|
||
if (kernel.width % 2 === 0 || kernel.height % 2 === 0) { | ||
throw new Error("Kernel size must be odd"); | ||
} | ||
|
||
return kernel; | ||
} | ||
Comment on lines
+1769
to
+1789
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The purpose of this function is to help users with passing in kernel values. To get a 3x3 kernel to the
Most likely we want a symmetrical kernel, and it's easier to just pass in a single number. |
||
|
||
/** | ||
* Creates a structuring element for morphological operations. | ||
* | ||
* This function is a JavaScript translation of the [OpenCV C++ function of the same name](https://github.com/egonSchiele/OpenCV/blob/master/modules/imgproc/src/morph.cpp#L981). | ||
* | ||
* @param {Shape} shape The shape of the kernel. | ||
* @param {Size} kernelSize The width and height of the kernel. | ||
* @param {Point} [anchor={x: -1, y: -1}] The central position of the kernel. | ||
* @returns {Array<Array<number>>} The structuring element as a 2D array. | ||
* @throws {Error} If the shape, or kernel size, is invalid is invalid. | ||
*/ | ||
function getStructuringElement(shape, kernelSize, anchor = { x: -1, y: -1 }) { | ||
if (!['RECT', 'CROSS', 'ELLIPSE'].includes(shape)) { | ||
throw new Error("Invalid shape. Must be 'RECT', 'CROSS', or 'ELLIPSE'."); | ||
} | ||
|
||
// Get a kernel object that represents the kernel width and height. | ||
let kernel = validateKernel(kernelSize); | ||
|
||
// Normalize anchor to default to the center if not specified. | ||
anchor = normalizeAnchor(anchor, kernel); | ||
|
||
// If the kernel size is 1x1, treat as a rectangle. | ||
if (kernel.width === 1 && kernel.height === 1) { | ||
shape = 'RECT'; | ||
} | ||
|
||
let rowRadius = 0; // Radius along the height. | ||
let colRadius = 0; // Radius along the width. | ||
let inverseRowRadiusSquared = 0; // Inverse squared radius for ellipses. | ||
|
||
if (shape === 'ELLIPSE') { | ||
// Calculate radii and inverse squared radius for the ellipse equation. | ||
rowRadius = Math.floor(kernel.height / 2); | ||
colRadius = Math.floor(kernel.width / 2); | ||
inverseRowRadiusSquared = rowRadius > 0 ? 1 / (rowRadius * rowRadius) : 0; | ||
} | ||
|
||
// Create a 2D array to represent the kernel. | ||
const kernelArray = Array.from({ length: kernel.height }, () => Array(kernel.width).fill(0)); | ||
|
||
for (let row = 0; row < kernel.height; row++) { | ||
let startColumn = 0; | ||
let endColumn = 0; | ||
|
||
if (shape === 'RECT' || (shape === 'CROSS' && row === anchor.y)) { | ||
// Full width for rectangle or horizontal line for cross shape. | ||
endColumn = kernel.width; | ||
} else if (shape === 'CROSS') { | ||
// Single column for cross shape. | ||
// A cross will be a single row and column, so only add 1. | ||
startColumn = anchor.x; | ||
endColumn = startColumn + 1; | ||
} else if (shape === 'ELLIPSE') { | ||
// Calculate elliptical bounds for this row. | ||
|
||
// Distance from the anchor row. | ||
const verticalOffset = row - anchor.y; | ||
|
||
if (Math.abs(verticalOffset) <= rowRadius) { | ||
// Solve for horizontal bounds using the ellipse equation: x^2/a^2 + y^2/b^2 = 1 | ||
const horizontalRadius = Math.floor( | ||
colRadius * Math.sqrt(Math.max(0, (rowRadius * rowRadius) - (verticalOffset * verticalOffset)) * inverseRowRadiusSquared) | ||
); | ||
|
||
// Left and right bound of the ellipse. | ||
// Add 1 to endColumn because it's not inclusive in the for loop. | ||
startColumn = Math.max(anchor.x - horizontalRadius, 0); | ||
endColumn = Math.min(anchor.x + horizontalRadius + 1, kernel.width); | ||
} | ||
} | ||
|
||
// Fill the kernel row with 1s within the range (startColumn, endColumn). | ||
for (let col = startColumn; col < endColumn; col++) { | ||
kernelArray[row][col] = 1; | ||
} | ||
} | ||
|
||
return kernelArray; | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Originally I was not using the if statement, and utilising the fact that a kernel contains either
0
or1
, however that meant thatkernelValues
array would be bigger than it needs to be.Also,
erode
wouldn't work properly because0
is always the minimum value.