diff --git a/README.md b/README.md index c4650c1..d3e5af8 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ fastText-jni is a Java wrapper for [fastText](https://github.com/facebookresearc ## Usage -`implementation group: 'com.diffbot', name: 'fasttext-jni', version: '0.9.2.7'` +`implementation group: 'com.diffbot', name: 'fasttext-jni', version: '0.9.2.8'` ```java FastTextModel model; diff --git a/build.gradle b/build.gradle index 1d4c862..e7be8d9 100644 --- a/build.gradle +++ b/build.gradle @@ -4,7 +4,7 @@ plugins { } group = 'com.diffbot' -version = '0.9.2.7' +version = '0.9.2.8' sourceCompatibility = 11 targetCompatibility = 11 diff --git a/src/main/java/com/diffbot/fasttext/FastTextModel.java b/src/main/java/com/diffbot/fasttext/FastTextModel.java index c2e0c24..4d47d63 100644 --- a/src/main/java/com/diffbot/fasttext/FastTextModel.java +++ b/src/main/java/com/diffbot/fasttext/FastTextModel.java @@ -27,6 +27,7 @@ public FastTextModel(InputStream inputStream) throws IOException { private native void load(ByteBuffer byteBuffer); public native Prediction predictProba(String s); public native Prediction[] predictProbaTopK(String s, int k); + public native Prediction[] predictProbaWithThreshold(String s, float threshold); @Override public native void close(); diff --git a/src/main/native/fasttext_jni.cc b/src/main/native/fasttext_jni.cc index 10347ab..30b034b 100644 --- a/src/main/native/fasttext_jni.cc +++ b/src/main/native/fasttext_jni.cc @@ -102,9 +102,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictPr return NULL; } - if (predictions.size() < k) { - k = predictions.size(); - } + k = predictions.size(); jobjectArray top_k_predictions = env->NewObjectArray(k, predictionClass, nullptr); for (int i = 0; i < k; i++) { @@ -118,6 +116,36 @@ JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictPr return top_k_predictions; } +JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaWithThreshold +(JNIEnv *env, jobject obj, jstring s, jfloat threshold) { + + jboolean isCopy; + const char* utf_string = env->GetStringUTFChars(s, &isCopy); + FastTextWrapper *ft = (FastTextWrapper *) env->GetLongField(obj, handleFieldID); + std::vector> predictions; + std::string text = utf_string; + std::stringstream stream; + stream << text << '\n'; + bool result = ft->predictLine(stream, predictions, -1, threshold); + env->ReleaseStringUTFChars(s, utf_string); + if (!result || predictions.size() == 0) { + return NULL; + } + + jsize k = predictions.size(); + jobjectArray ret = env->NewObjectArray(k, predictionClass, nullptr); + + for (int i = 0; i < k; i++) { + std::pair prediction = predictions[i]; + + jstring label = env->NewStringUTF(prediction.second.c_str()); + jobject pred = env->NewObject(predictionClass, predictionConstructor, prediction.first, label); + env->SetObjectArrayElement(ret, i, pred); + } + + return ret; +} + JNIEXPORT void JNICALL Java_com_diffbot_fasttext_FastTextModel_close (JNIEnv *env, jobject obj) { FastTextWrapper *ft = (FastTextWrapper *) env->GetLongField(obj, handleFieldID); diff --git a/src/main/native/fasttext_jni.h b/src/main/native/fasttext_jni.h index 37a2168..02a1d2e 100644 --- a/src/main/native/fasttext_jni.h +++ b/src/main/native/fasttext_jni.h @@ -23,13 +23,21 @@ JNIEXPORT void JNICALL Java_com_diffbot_fasttext_FastTextModel_load JNIEXPORT jobject JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProba (JNIEnv *, jobject, jstring); - /* - * Class: com_diffbot_fasttext_FastTextModel - * Method: predictProbaTopK - * Signature: (Ljava/lang/String;I)[Lcom/diffbot/fasttext/Prediction; - */ - JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaTopK - (JNIEnv *, jobject, jstring, jint); +/* + * Class: com_diffbot_fasttext_FastTextModel + * Method: predictProbaTopK + * Signature: (Ljava/lang/String;I)[Lcom/diffbot/fasttext/Prediction; + */ +JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaTopK + (JNIEnv *, jobject, jstring, jint); + +/* + * Class: com_diffbot_fasttext_FastTextModel + * Method: predictProbaWithThreshold + * Signature: (Ljava/lang/String;F)[Lcom/diffbot/fasttext/Prediction; + */ +JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaWithThreshold + (JNIEnv *, jobject, jstring, jfloat); /* * Class: com_diffbot_fasttext_FastTextModel diff --git a/src/test/java/com/diffbot/fasttext/FastTextTest.java b/src/test/java/com/diffbot/fasttext/FastTextTest.java index b6ac25b..5586ccd 100644 --- a/src/test/java/com/diffbot/fasttext/FastTextTest.java +++ b/src/test/java/com/diffbot/fasttext/FastTextTest.java @@ -11,7 +11,6 @@ public class FastTextTest { @Test public void modelTest() throws IOException { - Prediction prediction; FastTextModel model; /* * language identification @@ -21,16 +20,34 @@ public void modelTest() throws IOException { model = new FastTextModel(inputStream); } - prediction = model.predictProba("Web Data for your AI Imagine if your app could access the web like a structured database ."); - System.out.println(prediction.label + " : " + prediction.probability); - Assertions.assertEquals("__label__en", prediction.label); + { + Prediction prediction = model.predictProba("Web Data for your AI Imagine if your app could access the web like a structured database ."); + System.out.println(prediction.label + " : " + prediction.probability); + Assertions.assertEquals("__label__en", prediction.label); + } - prediction = model.predictProba("AI のための Web データ アプリが構造化データベースのように Web にアクセスできるかどうかを想像してください。"); - System.out.println(prediction.label + " : " + prediction.probability); - Assertions.assertEquals("__label__ja", prediction.label); + { + Prediction prediction = model.predictProba("AI のための Web データ アプリが構造化データベースのように Web にアクセスできるかどうかを想像してください。"); + System.out.println(prediction.label + " : " + prediction.probability); + Assertions.assertEquals("__label__ja", prediction.label); + } - Prediction[] predictions = model.predictProbaTopK(" ", -1); - Assertions.assertEquals(predictions.length, 176); + { + Prediction[] predictions = model.predictProbaTopK(" ", -1); + Assertions.assertEquals(176, predictions.length); + } + + { + float threshold = 0.01f; + Prediction[] predictions = model.predictProbaWithThreshold(" ", threshold); + Assertions.assertEquals(24, predictions.length); + double prev = Float.MAX_VALUE; + for (Prediction prediction : predictions) { + Assertions.assertTrue(prediction.probability >= threshold); + Assertions.assertTrue(prediction.probability <= prev); + prev = prediction.probability; + } + } model.close(); @@ -42,7 +59,7 @@ public void modelTest() throws IOException { model = new FastTextModel(inputStream); } - prediction = model.predictProba("web data for your ai imagine if your app could access the web like a structured database ."); + Prediction prediction = model.predictProba("web data for your ai imagine if your app could access the web like a structured database ."); System.out.println(prediction.label + " : " + prediction.probability); Assertions.assertEquals("__label__4", prediction.label); @@ -54,7 +71,7 @@ public void modelTest() throws IOException { System.out.println(p.label + " : " + p.probability); } - Assertions.assertEquals(array.length, 4); + Assertions.assertEquals(4, array.length); model.close(); }