Skip to content

Commit

Permalink
Add predictProbaWithThreshold method
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidek committed Jun 10, 2024
1 parent 3a9cee1 commit f6c5230
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ fastText-jni is a Java wrapper for [fastText](https://github.com/facebookresearc

## Usage

`implementation group: 'com.diffbot', name: 'fasttext-jni', version: '0.9.2.7'`
`implementation group: 'com.diffbot', name: 'fasttext-jni', version: '0.9.2.8'`

```java
FastTextModel model;
Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ plugins {
}

group = 'com.diffbot'
version = '0.9.2.7'
version = '0.9.2.8'

sourceCompatibility = 11
targetCompatibility = 11
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/diffbot/fasttext/FastTextModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public FastTextModel(InputStream inputStream) throws IOException {
private native void load(ByteBuffer byteBuffer);
public native Prediction predictProba(String s);
public native Prediction[] predictProbaTopK(String s, int k);
public native Prediction[] predictProbaWithThreshold(String s, float threshold);

@Override
public native void close();
Expand Down
34 changes: 31 additions & 3 deletions src/main/native/fasttext_jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictPr
return NULL;
}

if (predictions.size() < k) {
k = predictions.size();
}
k = predictions.size();
jobjectArray top_k_predictions = env->NewObjectArray(k, predictionClass, nullptr);

for (int i = 0; i < k; i++) {
Expand All @@ -118,6 +116,36 @@ JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictPr
return top_k_predictions;
}

JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaWithThreshold
(JNIEnv *env, jobject obj, jstring s, jfloat threshold) {

jboolean isCopy;
const char* utf_string = env->GetStringUTFChars(s, &isCopy);
FastTextWrapper *ft = (FastTextWrapper *) env->GetLongField(obj, handleFieldID);
std::vector<std::pair<fasttext::real,std::string>> predictions;
std::string text = utf_string;
std::stringstream stream;
stream << text << '\n';
bool result = ft->predictLine(stream, predictions, -1, threshold);
env->ReleaseStringUTFChars(s, utf_string);
if (!result || predictions.size() == 0) {
return NULL;
}

jsize k = predictions.size();
jobjectArray ret = env->NewObjectArray(k, predictionClass, nullptr);

for (int i = 0; i < k; i++) {
std::pair<fasttext::real, std::string> prediction = predictions[i];

jstring label = env->NewStringUTF(prediction.second.c_str());
jobject pred = env->NewObject(predictionClass, predictionConstructor, prediction.first, label);
env->SetObjectArrayElement(ret, i, pred);
}

return ret;
}

JNIEXPORT void JNICALL Java_com_diffbot_fasttext_FastTextModel_close
(JNIEnv *env, jobject obj) {
FastTextWrapper *ft = (FastTextWrapper *) env->GetLongField(obj, handleFieldID);
Expand Down
22 changes: 15 additions & 7 deletions src/main/native/fasttext_jni.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ JNIEXPORT void JNICALL Java_com_diffbot_fasttext_FastTextModel_load
JNIEXPORT jobject JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProba
(JNIEnv *, jobject, jstring);

/*
* Class: com_diffbot_fasttext_FastTextModel
* Method: predictProbaTopK
* Signature: (Ljava/lang/String;I)[Lcom/diffbot/fasttext/Prediction;
*/
JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaTopK
(JNIEnv *, jobject, jstring, jint);
/*
* Class: com_diffbot_fasttext_FastTextModel
* Method: predictProbaTopK
* Signature: (Ljava/lang/String;I)[Lcom/diffbot/fasttext/Prediction;
*/
JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaTopK
(JNIEnv *, jobject, jstring, jint);

/*
* Class: com_diffbot_fasttext_FastTextModel
* Method: predictProbaWithThreshold
* Signature: (Ljava/lang/String;F)[Lcom/diffbot/fasttext/Prediction;
*/
JNIEXPORT jobjectArray JNICALL Java_com_diffbot_fasttext_FastTextModel_predictProbaWithThreshold
(JNIEnv *, jobject, jstring, jfloat);

/*
* Class: com_diffbot_fasttext_FastTextModel
Expand Down
39 changes: 28 additions & 11 deletions src/test/java/com/diffbot/fasttext/FastTextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ public class FastTextTest {

@Test
public void modelTest() throws IOException {
Prediction prediction;
FastTextModel model;
/*
* language identification
Expand All @@ -21,16 +20,34 @@ public void modelTest() throws IOException {
model = new FastTextModel(inputStream);
}

prediction = model.predictProba("Web Data for your AI Imagine if your app could access the web like a structured database .");
System.out.println(prediction.label + " : " + prediction.probability);
Assertions.assertEquals("__label__en", prediction.label);
{
Prediction prediction = model.predictProba("Web Data for your AI Imagine if your app could access the web like a structured database .");
System.out.println(prediction.label + " : " + prediction.probability);
Assertions.assertEquals("__label__en", prediction.label);
}

prediction = model.predictProba("AI のための Web データ アプリが構造化データベースのように Web にアクセスできるかどうかを想像してください。");
System.out.println(prediction.label + " : " + prediction.probability);
Assertions.assertEquals("__label__ja", prediction.label);
{
Prediction prediction = model.predictProba("AI のための Web データ アプリが構造化データベースのように Web にアクセスできるかどうかを想像してください。");
System.out.println(prediction.label + " : " + prediction.probability);
Assertions.assertEquals("__label__ja", prediction.label);
}

Prediction[] predictions = model.predictProbaTopK(" ", -1);
Assertions.assertEquals(predictions.length, 176);
{
Prediction[] predictions = model.predictProbaTopK(" ", -1);
Assertions.assertEquals(176, predictions.length);
}

{
float threshold = 0.01f;
Prediction[] predictions = model.predictProbaWithThreshold(" ", threshold);
Assertions.assertEquals(24, predictions.length);
double prev = Float.MAX_VALUE;
for (Prediction prediction : predictions) {
Assertions.assertTrue(prediction.probability >= threshold);
Assertions.assertTrue(prediction.probability <= prev);
prev = prediction.probability;
}
}

model.close();

Expand All @@ -42,7 +59,7 @@ public void modelTest() throws IOException {
model = new FastTextModel(inputStream);
}

prediction = model.predictProba("web data for your ai imagine if your app could access the web like a structured database .");
Prediction prediction = model.predictProba("web data for your ai imagine if your app could access the web like a structured database .");
System.out.println(prediction.label + " : " + prediction.probability);
Assertions.assertEquals("__label__4", prediction.label);

Expand All @@ -54,7 +71,7 @@ public void modelTest() throws IOException {
System.out.println(p.label + " : " + p.probability);
}

Assertions.assertEquals(array.length, 4);
Assertions.assertEquals(4, array.length);

model.close();
}
Expand Down

0 comments on commit f6c5230

Please sign in to comment.