From cf2c47afaa8b7b3c073455b4991f5533fd5ba717 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 14 Nov 2024 17:36:02 +0900 Subject: [PATCH] Add sample utility --- README.md | 22 +++---- bindings.d.ts | 1 + lib/common.ts | 34 +++++++++++ lib/index.ts | 2 +- lib/module.ts | 2 +- lib/scalar.ts | 16 ------ lib/tensor.ts | 4 +- src/bindings.cc | 4 +- src/sample.cc | 133 +++++++++++++++++++++++++++++++++++++++++++ src/sample.h | 14 +++++ src/tensor.h | 2 +- tests/sample.spec.ts | 18 ++++++ 12 files changed, 220 insertions(+), 32 deletions(-) create mode 100644 lib/common.ts delete mode 100644 lib/scalar.ts create mode 100644 src/sample.cc create mode 100644 src/sample.h create mode 100644 tests/sample.spec.ts diff --git a/README.md b/README.md index 8d1d0f5..cb70168 100644 --- a/README.md +++ b/README.md @@ -136,15 +136,6 @@ export declare enum DType { BFloat16 } -/** - * Optional options describing the tensor. - */ -export interface TensorOptions { - shape?: number[]; - dimOrder?: number[]; - strides?: number[]; -} - type Nested = Nested[] | T; /** @@ -171,7 +162,9 @@ export declare class Tensor { * @param options.dimOrder * @param options.strides */ - constructor(input: Nested | Uint8Array, dtype?: DType, { shape, dimOrder, strides }?: TensorOptions); + constructor(input: Nested | Uint8Array, + dtype?: DType, + { shape, dimOrder, strides }?: { shape?: number[]; dimOrder?: number[]; strides?: number[]; }); /** * Return the tensor as a scalar. */ @@ -209,6 +202,15 @@ export declare class Tensor { */ get itemsize(): number; } + +/** + * Samples from the given tensor using a softmax over logits. + */ +export declare function sample(logits: Tensor, + { + temperature = 1, + topP = 1, + }?: { temperature?: number; topP?: number }): number; ``` ## Development diff --git a/bindings.d.ts b/bindings.d.ts index 34bf6ee..746d732 100644 --- a/bindings.d.ts +++ b/bindings.d.ts @@ -85,3 +85,4 @@ export const backends: Backends; export const config: 'Debug' | 'Release'; export function elementSize(dtype: number): number; +export function sample(tensor: Tensor, temperature: number, topP: number): number; diff --git a/lib/common.ts b/lib/common.ts new file mode 100644 index 0000000..b4c818d --- /dev/null +++ b/lib/common.ts @@ -0,0 +1,34 @@ +import bindings from '../bindings.js'; +import type {Tensor} from './tensor.js'; + +/** + * Data type. + */ +export enum DType { + Uint8 = bindings.ScalarType.Byte, + Int8 = bindings.ScalarType.Char, + Int16 = bindings.ScalarType.Short, + Int32 = bindings.ScalarType.Int, + Float16 = bindings.ScalarType.Half, + Float32 = bindings.ScalarType.Float, + Float64 = bindings.ScalarType.Double, + Bool = bindings.ScalarType.Bool, + BFloat16 = bindings.ScalarType.BFloat16, +} + +/** + * Samples from the given tensor using a softmax over logits. + */ +export function sample(logits: Tensor, + { + temperature = 1, + topP = 1, + }: {temperature?: number, topP?: number} = {}) { + if (logits.size == 0) + throw new Error('The logits must not be empty.'); + if (logits.ndim == 0 || + logits.ndim > 2 || + logits.ndim == 2 && logits.shape[0] != 1) + throw new Error('The shape of logits must be [N] or [1, N].'); + return bindings.sample(logits.holder, temperature, topP); +} diff --git a/lib/index.ts b/lib/index.ts index 9164b7a..a70a92e 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -1,4 +1,4 @@ export {backends, config} from '../bindings.js'; -export {DType} from './scalar.js'; +export {DType, sample} from './common.js'; export {Module} from './module.js'; export {Tensor} from './tensor.js'; diff --git a/lib/module.ts b/lib/module.ts index b608305..d2e1206 100644 --- a/lib/module.ts +++ b/lib/module.ts @@ -1,5 +1,5 @@ import bindings from '../bindings.js'; -import {DType} from './scalar.js'; +import {DType} from './common.js'; import {Tensor} from './tensor.js'; /** diff --git a/lib/scalar.ts b/lib/scalar.ts deleted file mode 100644 index f71f5eb..0000000 --- a/lib/scalar.ts +++ /dev/null @@ -1,16 +0,0 @@ -import bindings from '../bindings.js'; - -/** - * Data type. - */ -export enum DType { - Uint8 = bindings.ScalarType.Byte, - Int8 = bindings.ScalarType.Char, - Int16 = bindings.ScalarType.Short, - Int32 = bindings.ScalarType.Int, - Float16 = bindings.ScalarType.Half, - Float32 = bindings.ScalarType.Float, - Float64 = bindings.ScalarType.Double, - Bool = bindings.ScalarType.Bool, - BFloat16 = bindings.ScalarType.BFloat16, -} diff --git a/lib/tensor.ts b/lib/tensor.ts index 53b83e8..0ed0eb4 100644 --- a/lib/tensor.ts +++ b/lib/tensor.ts @@ -1,5 +1,5 @@ import bindings from '../bindings.js'; -import {DType} from './scalar.js'; +import {DType} from './common.js'; type Nested = Nested[] | T; @@ -30,7 +30,7 @@ export class Tensor { readonly shape: number[]; // Internal binding to the executorch::aten::Tensor instance. - private readonly holder: bindings.Tensor; + readonly holder: bindings.Tensor; /** * @param input - A scalar, or a (nested) Array, or a Uint8Array buffer. diff --git a/src/bindings.cc b/src/bindings.cc index 1ca9052..a89be5b 100644 --- a/src/bindings.cc +++ b/src/bindings.cc @@ -3,6 +3,7 @@ #include "src/evalue.h" #include "src/module.h" +#include "src/sample.h" #include "src/scalar.h" #include "src/tensor.h" @@ -42,7 +43,8 @@ napi_value Init(napi_env env, napi_value exports) { #else "config", "Release", #endif - "elementSize", &er::elementSize); + "elementSize", &er::elementSize, + "sample", &etjs::Sample); return exports; } diff --git a/src/sample.cc b/src/sample.cc new file mode 100644 index 0000000..99c3ac3 --- /dev/null +++ b/src/sample.cc @@ -0,0 +1,133 @@ +#include "src/sample.h" + +#include + +#include + +#include "src/tensor.h" + +namespace etjs { + +namespace { + +template +struct ProbIndex { + T prob; + size_t index; +}; + +template +size_t SampleArgMax(T* probs, size_t size) { + size_t max_i = 0; + T max_p = probs[0]; + for (size_t i = 1; i < size; i++) { + if (probs[i] > max_p) { + max_i = i; + max_p = probs[i]; + } + } + return max_i; +} + +template +size_t SampleMult(const std::vector& probs, float coin) { + T cdf{}; + for (size_t i = 0; i < probs.size(); i++) { + cdf += probs[i]; + if (coin < cdf) + return i; + } + return probs.size() - 1; +} + +template +size_t SampleTopP(const std::vector& probs, float top_p, float coin) { + size_t n0 = 0; + std::vector> probindex(probs.size()); + + float cutoff = (1.0f - top_p) / (probs.size() - 1); + for (size_t i = 0; i < probs.size(); i++) { + if (probs[i] >= cutoff) { + probindex[n0].index = i; + probindex[n0].prob = probs[i]; + n0++; + } + } + + std::sort(probindex.begin(), probindex.end(), + [](const auto& a, const auto& b) { return a.prob > b.prob; }); + + T cumulative_prob = 0; + int32_t last_idx = n0 - 1; + for (size_t i = 0; i < n0; i++) { + cumulative_prob += probindex[i].prob; + if (cumulative_prob > top_p) { + last_idx = i; + break; + } + } + + T r = coin * cumulative_prob; + T cdf = 0; + for (size_t i = 0; i <= last_idx; i++) { + cdf += probindex[i].prob; + if (r < cdf) + return probindex[i].index; + } + return probindex[last_idx].index; +} + +template +void Softmax(std::vector& x) { + T max_val = *std::max_element(x.begin(), x.end()); + + T sum = 0; + for (size_t i = 0; i < x.size(); i++) { + x[i] = std::expf(x[i] - max_val); + sum += x[i]; + } + + for (size_t i = 0; i < x.size(); i++) { + x[i] = x[i] / sum; + } +} + +float RandomF32() { + static std::mt19937 engine; + static std::uniform_real_distribution distribution(0.0, 1.0); + return distribution(engine); +} + +template +size_t Sample(T* input, size_t size, float temperature, float top_p) { + if (temperature == 0) + return SampleArgMax(input, size); + + std::vector logits(input, input + size); + for (size_t i = 0; i < size; i++) + logits[i] = logits[i] / temperature; + + Softmax(logits); + + float coin = RandomF32(); + if (top_p <= 0 || top_p >= 1) + return SampleMult(logits, coin); + else + return SampleTopP(logits, top_p, coin); +} + +} // namespace + +size_t Sample(Tensor* tensor, float temperature, float top_p) { + ET_CHECK_MSG(tensor->size() > 0, "Tensor can not be empty"); + ET_CHECK_MSG(tensor->ndim() == 1 || + (tensor->ndim() == 2 && tensor->shape()[0] == 1), + "Tensor's shape must be [N] or [1, N]."); + size_t ret = 0; + ET_SWITCH_REALHBBF16_TYPES(tensor->dtype(), nullptr, "sample", CTYPE, [&] { + ret = Sample(tensor->data(), tensor->size(), temperature, top_p); + }); + return ret; +} + +} // namespace etjs diff --git a/src/sample.h b/src/sample.h new file mode 100644 index 0000000..c9afbd7 --- /dev/null +++ b/src/sample.h @@ -0,0 +1,14 @@ +#ifndef SRC_SAMPLE_H_ +#define SRC_SAMPLE_H_ + +#include + +namespace etjs { + +class Tensor; + +size_t Sample(Tensor* tensor, float temperature, float top_p); + +} // namespace etjs + +#endif // SRC_SAMPLE_H_ diff --git a/src/tensor.h b/src/tensor.h index 9ed664e..0f73e98 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -42,7 +42,7 @@ class Tensor { size_t itemsize() const { return impl_.element_size(); } template - const T* data() const { return static_cast(data_.data); } + T* data() { return static_cast(data_.data); } private: Buffer data_; diff --git a/tests/sample.spec.ts b/tests/sample.spec.ts new file mode 100644 index 0000000..f15199a --- /dev/null +++ b/tests/sample.spec.ts @@ -0,0 +1,18 @@ +import {DType, Tensor, sample} from '..'; +import {assert} from 'chai'; + +describe('Sample', () => { + it('argmax', () => { + const logits = Array.from({length: 128}, () => Math.random()); + logits[89] = 1; + const index = sample(new Tensor(logits), {temperature: 0}); + assert.equal(index, 89); + }); + + it('argmax bfloat16', () => { + const logits = Array.from({length: 128}, () => Math.random()); + logits[64] = 1; + const index = sample(new Tensor(logits, DType.BFloat16), {temperature: 0}); + assert.equal(index, 64); + }); +});