Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement scatterND #120

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,69 @@ export function validateScatterElementsParams(input, indices, updates, {axis = 0
}
}

export function validateScatterNDParams(input, indices, updates) {
// Refer to https://onnx.ai/onnx/operators/onnx__ScatterND.html
// ScatterND takes three inputs data tensor of rank r >= 1, indices tensor of rank q >= 1,
// and updates tensor of rank q + r - indices.shape[-1] - 1.

const inputRank = input.rank;
if (inputRank < 1) {
throw new Error(`The input should be at least a 1-D tensor.`);
}

const indicesRank = indices.rank;
if (indicesRank < 1) {
throw new Error(`The indices should be at least a 1-D tensor.`);
}

const indicesShape = indices.shape;
const lastIndicesSize = indicesShape[indicesRank - 1];
if (lastIndicesSize < 1 || lastIndicesSize > inputRank) {
throw new Error(`The indices.shape[-1] should be in the range [1, ${inputRank}].`);
}

const inputShape = input.shape;
const indicesTotal = sizeOfShape(indicesShape);
const updatedLocationDict = {};
for (let indicesIndex = 0; indicesIndex < indicesTotal; indicesIndex += lastIndicesSize) {
const originIndicesArray = [];
const indicesArray = [];
for (let i = 0; i < lastIndicesSize; i++) {
const indicesValue = indices.getValueByIndex(indicesIndex + i);
const maxSize = inputShape[i];
if (!Number.isInteger(indicesValue) ||
indicesValue < -maxSize ||
indicesValue > maxSize - 1) {
throw new Error(`Invalid indices value - it should be an integer in the interval ` +
`[${-maxSize}, ${maxSize - 1}]`);
}
originIndicesArray.push(indicesValue);
indicesArray.push(indicesValue >= 0 ? indicesValue : inputShape[i] + indicesValue);
}
const locationString = indicesArray.toString();
if (Object.hasOwn(updatedLocationDict, locationString)) {
throw new Error(`Invalid indices, [${originIndicesArray}] and ` +
`[${updatedLocationDict[locationString]}] point to the same output location.`);
} else {
updatedLocationDict[locationString] = originIndicesArray;
}
}

const updatesRank = updates.rank;
const targetUpdatesRank = indicesRank + inputRank - lastIndicesSize -1;
if (updatesRank !== targetUpdatesRank) {
throw new Error(
`Invalid updates value - updates rank should be equal to ${targetUpdatesRank}.`);
}

const updatesShape = updates.shape;
const targetUpdatesShape =
indicesShape.slice(0, indicesRank - 1).concat(inputShape.slice(lastIndicesSize));
if (!updatesShape.every((size, index) => size === targetUpdatesShape[index])) {
throw new Error(`Invalid updates shape, it should be [${targetUpdatesShape}].`);
}
}

export function validateGatherElementsParams(input, indices, {axis = 0} = {}) {
const inputRank = input.rank;
const indicesRank = indices.rank;
Expand Down
51 changes: 51 additions & 0 deletions src/scatter_nd.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateScatterNDParams} from './lib/validate-input.js';
import {identity} from './unary.js';

/**
* Scatter values using multidimensional indices.
* @param {Tensor} input
* @param {Tensor} indices
* @param {Tensor} updates
* @return {Tensor}
*/
export function scatterND(input, indices, updates) {
// Refer to https://onnx.ai/onnx/operators/onnx__ScatterND.html
// output = np.copy(data)
// update_indices = indices.shape[:-1]
// for idx in np.ndindex(update_indices):
// output[indices[idx]] = updates[idx]

validateScatterNDParams(input, indices, updates);

const output = identity(input);
const inputRank = input.rank;
const inputShape = input.shape;
const indicesRank = indices.rank;
const indicesShape = indices.shape;
const indicesTotal = sizeOfShape(indicesShape);
const lastIndicesSize = indicesShape[indicesRank - 1];
const tmpShape = inputShape.slice(lastIndicesSize, inputRank);
const tmp = new Tensor(tmpShape);
const tmpTotal = sizeOfShape(tmpShape);

for (let indicesIndex = 0; indicesIndex < indicesTotal; indicesIndex += lastIndicesSize) {
const indicesLocation = indices.locationFromIndex(indicesIndex);
const indicesArray = [];
for (let i = 0; i < lastIndicesSize; i++) {
const indicesValue = indices.getValueByIndex(indicesIndex + i);
indicesArray.push(indicesValue >= 0 ? indicesValue : inputShape[i] + indicesValue);
}
for (let tmpIndex = 0; tmpIndex < tmpTotal; ++tmpIndex) {
const tmpLocation = tmp.locationFromIndex(tmpIndex);
const outputLocation = indicesArray.concat(tmpLocation);
const updateValue =
updates.getValueByLocation(indicesLocation.slice(0, indicesRank - 1).concat(tmpLocation));
output.setValueByLocation(outputLocation, updateValue);
}
}

return output;
}
124 changes: 124 additions & 0 deletions test/scatter_nd_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
'use strict';

import {scatterND} from '../src/scatter_nd.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test scatterND', function() {
function testscatterND(input, indices, updates, expected) {
const inputTensor = new Tensor(input.shape, input.data);
const indicesTensor = new Tensor(indices.shape, indices.data);
const updatesTensor = new Tensor(updates.shape, updates.data);
const outputTensor = scatterND(inputTensor, indicesTensor, updatesTensor);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('scatterND to insert individual elements in a tensor by index', function() {
// Refer to Example 1 on https://onnx.ai/onnx/operators/onnx__ScatterND.html
const input = {
shape: [8],
data: [1, 2, 3, 4, 5, 6, 7, 8],
};
const indices = {
shape: [4, 1],
data: [4, 3, 1, 7],
};
const updates = {
shape: [4],
data: [9, 10, 11, 12],
};
const expected = {
shape: [8],
data: [1, 11, 3, 10, 9, 6, 7, 12],
};
testscatterND(input, indices, updates, expected);
});

it('scatterND to insert entire slices of a higher rank tensor', function() {
// Refer to Example 2 on https://onnx.ai/onnx/operators/onnx__ScatterND.html
const input = {
shape: [4, 4, 4],
data: [
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
],
};
const indices = {
shape: [2, 1],
data: [0, 2],
};
const updates = {
shape: [2, 4, 4],
data: [
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
],
};
const expected = {
shape: [4, 4, 4],
data: [
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
],
};
testscatterND(input, indices, updates, expected);
});

it('scatterND with negative indices', function() {
const input = {
shape: [4, 4, 4],
data: [
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
],
};
const indices = {
shape: [2, 1],
data: [-4, -2],
};
const updates = {
shape: [2, 4, 4],
data: [
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
],
};
const expected = {
shape: [4, 4, 4],
data: [
5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
],
};
testscatterND(input, indices, updates, expected);
});

it('scatterND with 0D updates', function() {
const input = {
shape: [2, 2],
data: [1, 2, 3, 4],
};
const indices = {
shape: [2],
data: [1, 0],
};
const updates = {
shape: [],
data: [100],
};
const expected = {
shape: [2, 2],
data: [1, 2, 100, 4],
};
testscatterND(input, indices, updates, expected);
});
});
Loading