Skip to content

Commit

Permalink
Add sample utility
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Nov 14, 2024
1 parent 5f8cee4 commit cf2c47a
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 32 deletions.
22 changes: 12 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,6 @@ export declare enum DType {
BFloat16
}

/**
* Optional options describing the tensor.
*/
export interface TensorOptions {
shape?: number[];
dimOrder?: number[];
strides?: number[];
}

type Nested<T> = Nested<T>[] | T;

/**
Expand All @@ -171,7 +162,9 @@ export declare class Tensor {
* @param options.dimOrder
* @param options.strides
*/
constructor(input: Nested<boolean | number> | Uint8Array, dtype?: DType, { shape, dimOrder, strides }?: TensorOptions);
constructor(input: Nested<boolean | number> | Uint8Array,
dtype?: DType,
{ shape, dimOrder, strides }?: { shape?: number[]; dimOrder?: number[]; strides?: number[]; });
/**
* Return the tensor as a scalar.
*/
Expand Down Expand Up @@ -209,6 +202,15 @@ export declare class Tensor {
*/
get itemsize(): number;
}

/**
* Samples from the given tensor using a softmax over logits.
*/
export declare function sample(logits: Tensor,
{
temperature = 1,
topP = 1,
}?: { temperature?: number; topP?: number }): number;
```

## Development
Expand Down
1 change: 1 addition & 0 deletions bindings.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,4 @@ export const backends: Backends;
export const config: 'Debug' | 'Release';

export function elementSize(dtype: number): number;
export function sample(tensor: Tensor, temperature: number, topP: number): number;
34 changes: 34 additions & 0 deletions lib/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import bindings from '../bindings.js';
import type {Tensor} from './tensor.js';

/**
* Data type.
*/
export enum DType {
Uint8 = bindings.ScalarType.Byte,
Int8 = bindings.ScalarType.Char,
Int16 = bindings.ScalarType.Short,
Int32 = bindings.ScalarType.Int,
Float16 = bindings.ScalarType.Half,
Float32 = bindings.ScalarType.Float,
Float64 = bindings.ScalarType.Double,
Bool = bindings.ScalarType.Bool,
BFloat16 = bindings.ScalarType.BFloat16,
}

/**
* Samples from the given tensor using a softmax over logits.
*/
export function sample(logits: Tensor,
{
temperature = 1,
topP = 1,
}: {temperature?: number, topP?: number} = {}) {
if (logits.size == 0)
throw new Error('The logits must not be empty.');
if (logits.ndim == 0 ||
logits.ndim > 2 ||
logits.ndim == 2 && logits.shape[0] != 1)
throw new Error('The shape of logits must be [N] or [1, N].');
return bindings.sample(logits.holder, temperature, topP);
}
2 changes: 1 addition & 1 deletion lib/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export {backends, config} from '../bindings.js';
export {DType} from './scalar.js';
export {DType, sample} from './common.js';
export {Module} from './module.js';
export {Tensor} from './tensor.js';
2 changes: 1 addition & 1 deletion lib/module.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import bindings from '../bindings.js';
import {DType} from './scalar.js';
import {DType} from './common.js';
import {Tensor} from './tensor.js';

/**
Expand Down
16 changes: 0 additions & 16 deletions lib/scalar.ts

This file was deleted.

4 changes: 2 additions & 2 deletions lib/tensor.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import bindings from '../bindings.js';
import {DType} from './scalar.js';
import {DType} from './common.js';

type Nested<T> = Nested<T>[] | T;

Expand Down Expand Up @@ -30,7 +30,7 @@ export class Tensor {
readonly shape: number[];

// Internal binding to the executorch::aten::Tensor instance.
private readonly holder: bindings.Tensor;
readonly holder: bindings.Tensor;

/**
* @param input - A scalar, or a (nested) Array, or a Uint8Array buffer.
Expand Down
4 changes: 3 additions & 1 deletion src/bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "src/evalue.h"
#include "src/module.h"
#include "src/sample.h"
#include "src/scalar.h"
#include "src/tensor.h"

Expand Down Expand Up @@ -42,7 +43,8 @@ napi_value Init(napi_env env, napi_value exports) {
#else
"config", "Release",
#endif
"elementSize", &er::elementSize);
"elementSize", &er::elementSize,
"sample", &etjs::Sample);
return exports;
}

Expand Down
133 changes: 133 additions & 0 deletions src/sample.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#include "src/sample.h"

#include <random>

#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include "src/tensor.h"

namespace etjs {

namespace {

template<typename T>
struct ProbIndex {
T prob;
size_t index;
};

template<typename T>
size_t SampleArgMax(T* probs, size_t size) {
size_t max_i = 0;
T max_p = probs[0];
for (size_t i = 1; i < size; i++) {
if (probs[i] > max_p) {
max_i = i;
max_p = probs[i];
}
}
return max_i;
}

template<typename T>
size_t SampleMult(const std::vector<T>& probs, float coin) {
T cdf{};
for (size_t i = 0; i < probs.size(); i++) {
cdf += probs[i];
if (coin < cdf)
return i;
}
return probs.size() - 1;
}

template<typename T>
size_t SampleTopP(const std::vector<T>& probs, float top_p, float coin) {
size_t n0 = 0;
std::vector<ProbIndex<T>> probindex(probs.size());

float cutoff = (1.0f - top_p) / (probs.size() - 1);
for (size_t i = 0; i < probs.size(); i++) {
if (probs[i] >= cutoff) {
probindex[n0].index = i;
probindex[n0].prob = probs[i];
n0++;
}
}

std::sort(probindex.begin(), probindex.end(),
[](const auto& a, const auto& b) { return a.prob > b.prob; });

T cumulative_prob = 0;
int32_t last_idx = n0 - 1;
for (size_t i = 0; i < n0; i++) {
cumulative_prob += probindex[i].prob;
if (cumulative_prob > top_p) {
last_idx = i;
break;
}
}

T r = coin * cumulative_prob;
T cdf = 0;
for (size_t i = 0; i <= last_idx; i++) {
cdf += probindex[i].prob;
if (r < cdf)
return probindex[i].index;
}
return probindex[last_idx].index;
}

template<typename T>
void Softmax(std::vector<T>& x) {
T max_val = *std::max_element(x.begin(), x.end());

T sum = 0;
for (size_t i = 0; i < x.size(); i++) {
x[i] = std::expf(x[i] - max_val);
sum += x[i];
}

for (size_t i = 0; i < x.size(); i++) {
x[i] = x[i] / sum;
}
}

float RandomF32() {
static std::mt19937 engine;
static std::uniform_real_distribution<float> distribution(0.0, 1.0);
return distribution(engine);
}

template<typename T>
size_t Sample(T* input, size_t size, float temperature, float top_p) {
if (temperature == 0)
return SampleArgMax(input, size);

std::vector<T> logits(input, input + size);
for (size_t i = 0; i < size; i++)
logits[i] = logits[i] / temperature;

Softmax(logits);

float coin = RandomF32();
if (top_p <= 0 || top_p >= 1)
return SampleMult(logits, coin);
else
return SampleTopP(logits, top_p, coin);
}

} // namespace

size_t Sample(Tensor* tensor, float temperature, float top_p) {
ET_CHECK_MSG(tensor->size() > 0, "Tensor can not be empty");
ET_CHECK_MSG(tensor->ndim() == 1 ||
(tensor->ndim() == 2 && tensor->shape()[0] == 1),
"Tensor's shape must be [N] or [1, N].");
size_t ret = 0;
ET_SWITCH_REALHBBF16_TYPES(tensor->dtype(), nullptr, "sample", CTYPE, [&] {
ret = Sample(tensor->data<CTYPE>(), tensor->size(), temperature, top_p);
});
return ret;
}

} // namespace etjs
14 changes: 14 additions & 0 deletions src/sample.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef SRC_SAMPLE_H_
#define SRC_SAMPLE_H_

#include <stddef.h>

namespace etjs {

class Tensor;

size_t Sample(Tensor* tensor, float temperature, float top_p);

} // namespace etjs

#endif // SRC_SAMPLE_H_
2 changes: 1 addition & 1 deletion src/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Tensor {
size_t itemsize() const { return impl_.element_size(); }

template<typename T>
const T* data() const { return static_cast<const T*>(data_.data); }
T* data() { return static_cast<T*>(data_.data); }

private:
Buffer data_;
Expand Down
18 changes: 18 additions & 0 deletions tests/sample.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import {DType, Tensor, sample} from '..';
import {assert} from 'chai';

describe('Sample', () => {
it('argmax', () => {
const logits = Array.from({length: 128}, () => Math.random());
logits[89] = 1;
const index = sample(new Tensor(logits), {temperature: 0});
assert.equal(index, 89);
});

it('argmax bfloat16', () => {
const logits = Array.from({length: 128}, () => Math.random());
logits[64] = 1;
const index = sample(new Tensor(logits, DType.BFloat16), {temperature: 0});
assert.equal(index, 64);
});
});

0 comments on commit cf2c47a

Please sign in to comment.