Skip to content

Commit

Permalink
Add context params and init OpenVINO method (#13)
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Álvarez <miguelwork92@gmail.com>
  • Loading branch information
GiviMAD authored Nov 11, 2023
1 parent 86d7879 commit 5a7821e
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 25 deletions.
1 change: 1 addition & 0 deletions gen_header.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set -xe
LIB_SRC=src/main/java/io/github/givimad/whisperjni
javac -h src/main/native \
$LIB_SRC/internal/LibraryUtils.java \
$LIB_SRC/WhisperContextParams.java \
$LIB_SRC/WhisperContext.java \
$LIB_SRC/WhisperSamplingStrategy.java \
$LIB_SRC/WhisperFullParams.java \
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.github.givimad.whisperjni;

/**
* The {@link WhisperContextParams} class contains the {@link WhisperContext} params.
*
* @author Miguel Álvarez Díez - Initial contribution
*/
public class WhisperContextParams {
/**
* Enables GPU usage.
*/
public boolean useGPU = true;

/**
* Public constructor.
*/
public WhisperContextParams() {

}
}
53 changes: 42 additions & 11 deletions src/main/java/io/github/givimad/whisperjni/WhisperJNI.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.github.givimad.whisperjni.internal.LibraryUtils;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.function.Consumer;

Expand All @@ -16,12 +17,14 @@ public class WhisperJNI {
private static LibraryLogger libraryLogger;

//region native api
private native int init(String model);
private native int init(String model, WhisperContextParams params);

private native int initNoState(String model);
private native int initNoState(String model, WhisperContextParams params);

private native int initState(int model);

private native void initOpenVINOEncoder(int model, String device);

private native boolean isMultilingual(int model);

private native int full(int context, WhisperFullParams params, float[] samples, int numSamples);
Expand Down Expand Up @@ -62,9 +65,23 @@ public class WhisperJNI {
* @throws IOException if model file is missing.
*/
public WhisperContext init(Path model) throws IOException {
var absModelPath = model.toAbsolutePath();
assertModelExists(model, absModelPath);
return new WhisperContext(this, init(absModelPath.toString()));
return init(model, null);
}

/**
* Creates a new whisper context.
*
* @param model {@link Path} to the whisper ggml model file.
* @param params {@link WhisperContextParams} params for context initialization.
* @return A new {@link WhisperContext}.
* @throws IOException if model file is missing.
*/
public WhisperContext init(Path model, WhisperContextParams params) throws IOException {
assertModelExists(model);
if(params == null) {
params = new WhisperContextParams();
}
return new WhisperContext(this, init(model.toAbsolutePath().toString(), params));
}

/**
Expand All @@ -75,9 +92,23 @@ public WhisperContext init(Path model) throws IOException {
* @throws IOException if model file is missing.
*/
public WhisperContext initNoState(Path model) throws IOException {
var absModelPath = model.toAbsolutePath();
assertModelExists(model, absModelPath);
return new WhisperContext(this, initNoState(absModelPath.toString()));
return initNoState(model, null);
}

/**
* Creates a new whisper context without state.
*
* @param model {@link Path} to the whisper ggml model file.
* @param params {@link WhisperContextParams} params for context initialization.
* @return A new {@link WhisperContext} without state.
* @throws IOException if model file is missing.
*/
public WhisperContext initNoState(Path model, WhisperContextParams params) throws IOException {
assertModelExists(model);
if(params == null) {
params = new WhisperContextParams();
}
return new WhisperContext(this, initNoState(model.toAbsolutePath().toString(), params));
}

/**
Expand Down Expand Up @@ -392,9 +423,9 @@ protected void release() {
}
}

private static void assertModelExists(Path model, Path path) throws IOException {
if (!model.toFile().exists() || !model.toFile().isFile()) {
throw new IOException("Missing model file: " + path);
private static void assertModelExists(Path model) throws IOException {
if (!Files.exists(model) || Files.isDirectory(model)) {
throw new IOException("Missing model file: " + model);
}
}
}
33 changes: 23 additions & 10 deletions src/main/native/io_github_givimad_whisperjni_WhisperJNI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,21 @@ int insertModel(whisper_context *ctx)
contextMap.insert({ref, ctx});
return ref;
}
struct whisper_full_params parseJParams(JNIEnv *env, jobject jParams)

struct whisper_context_params newWhisperContextParams(JNIEnv *env, jobject jParams)
{
jclass paramsJClass = env->GetObjectClass(jParams);
struct whisper_context_params params = whisper_context_default_params();
params.use_gpu = (jboolean)env->GetBooleanField(jParams, env->GetFieldID(paramsJClass, "useGPU", "Z"));
return params;
}
struct whisper_full_params newWhisperFullParams(JNIEnv *env, jobject jParams)
{
jclass paramsJClass = env->GetObjectClass(jParams);

struct whisper_full_params params = whisper_full_default_params(
(whisper_sampling_strategy)env->GetIntField(jParams, env->GetFieldID(paramsJClass, "strategy", "I")));
// int params
whisper_sampling_strategy samplingStrategy = (whisper_sampling_strategy)env->GetIntField(jParams, env->GetFieldID(paramsJClass, "strategy", "I"));
struct whisper_full_params params = whisper_full_default_params(samplingStrategy);

int nThreads = (jint)env->GetIntField(jParams, env->GetFieldID(paramsJClass, "nThreads", "I"));
if (nThreads > 0)
{
Expand Down Expand Up @@ -82,16 +90,16 @@ struct whisper_full_params parseJParams(JNIEnv *env, jobject jParams)
return params;
}

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_init(JNIEnv *env, jobject thisObject, jstring modelPath)
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_init(JNIEnv *env, jobject thisObject, jstring modelPath, jobject jParams)
{
const char *path = env->GetStringUTFChars(modelPath, NULL);
return insertModel(whisper_init_from_file_with_params(path, whisper_context_default_params()));
return insertModel(whisper_init_from_file_with_params(path, newWhisperContextParams(env, jParams)));
}

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initNoState(JNIEnv *env, jobject thisObject, jstring modelPath)
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initNoState(JNIEnv *env, jobject thisObject, jstring modelPath, jobject jParams)
{
const char *path = env->GetStringUTFChars(modelPath, NULL);
return insertModel(whisper_init_from_file_with_params_no_state(path, whisper_context_default_params()));
return insertModel(whisper_init_from_file_with_params_no_state(path, newWhisperContextParams(env, jParams)));
}

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initState(JNIEnv *env, jobject thisObject, jint ctxRef)
Expand All @@ -102,6 +110,11 @@ JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initState(JN
return stateRef;
}

JNIEXPORT void JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_initOpenVINOEncoder(JNIEnv *env, jobject thisObject, jint ctxRef, jstring deviceString) {
const char *device = env->GetStringUTFChars(deviceString, NULL);
whisper_ctx_init_openvino_encoder(contextMap.at(ctxRef), nullptr, device, nullptr);
}

JNIEXPORT jboolean JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_isMultilingual(JNIEnv *env, jobject thisObject, jint ctxRef)
{
return whisper_is_multilingual(contextMap.at(ctxRef));
Expand All @@ -110,13 +123,13 @@ JNIEXPORT jboolean JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_isMultil
JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_full(JNIEnv *env, jobject thisObject, jint ctxRef, jobject jParams, jfloatArray samples, jint numSamples)
{
const float *samplesPointer = env->GetFloatArrayElements(samples, NULL);
return whisper_full(contextMap.at(ctxRef), parseJParams(env, jParams), samplesPointer, numSamples);
return whisper_full(contextMap.at(ctxRef), newWhisperFullParams(env, jParams), samplesPointer, numSamples);
}

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullWithState(JNIEnv *env, jobject thisObject, jint ctxRef, jint stateRef, jobject jParams, jfloatArray samples, jint numSamples)
{
const float *samplesPointer = env->GetFloatArrayElements(samples, NULL);
return whisper_full_with_state(contextMap.at(ctxRef), stateMap.at(stateRef), parseJParams(env, jParams), samplesPointer, numSamples);
return whisper_full_with_state(contextMap.at(ctxRef), stateMap.at(stateRef), newWhisperFullParams(env, jParams), samplesPointer, numSamples);
}

JNIEXPORT jint JNICALL Java_io_github_givimad_whisperjni_WhisperJNI_fullNSegments(JNIEnv *env, jobject thisObject, jint ctxRef)
Expand Down
16 changes: 12 additions & 4 deletions src/main/native/io_github_givimad_whisperjni_WhisperJNI.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5a7821e

Please sign in to comment.