From ee40a5b3b6babfeb38c833c7fed91a1367e3109f Mon Sep 17 00:00:00 2001 From: Dusan Malusev Date: Tue, 16 Apr 2024 20:18:10 +0200 Subject: [PATCH] Bring back multiline predictions Signed-off-by: Dusan Malusev --- fasttext.go | 28 ++++++++++++++-------------- fasttext_test.go | 42 +++++++++++++++++++++++++----------------- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/fasttext.go b/fasttext.go index 8a13ada..c7704a7 100644 --- a/fasttext.go +++ b/fasttext.go @@ -65,24 +65,24 @@ func (handle *Model) Close() error { return nil } -// func (handle *Model) MultiLinePredict(lines []string, k int32, threshoad float32) ([]Predictions, error) { -// predics := make([]Predictions, 0, len(lines)) +func (handle *Model) MultiLinePredict(lines []string, k int32, threshoad float32) ([]Predictions, error) { + predics := make([]Predictions, 0, len(lines)) -// for _, line := range lines { -// predictions, err := handle.Predict(line, k, threshoad) -// if err != nil && errors.Is(err, ErrPredictionFailed) { -// return nil, err -// } + for _, line := range lines { + predictions, err := handle.Predict(line, k, threshoad) + if err != nil && errors.Is(err, ErrPredictionFailed) { + return nil, err + } -// predics = append(predics, predictions) -// } + predics = append(predics, predictions) + } -// if len(predics) == 0 { -// return nil, ErrNoPredictions -// } + if len(predics) == 0 { + return nil, ErrNoPredictions + } -// return predics, nil -// } + return predics, nil +} // func (handle *Model) PredictOne(query string, threshoad float32) (Prediction, error) { // r := C.FastText_Predict( diff --git a/fasttext_test.go b/fasttext_test.go index 6b19d76..6c316fe 100644 --- a/fasttext_test.go +++ b/fasttext_test.go @@ -12,6 +12,7 @@ func TestOpen(t *testing.T) { assert := require.New(t) t.Run("Success", func(t *testing.T) { + t.Parallel() model, err := fasttext.Open("testdata/lid.176.ftz") assert.NoError(err) @@ -20,6 +21,7 @@ func TestOpen(t *testing.T) { }) t.Run("FailedToOpen", func(t *testing.T) { + t.Parallel() model, err := fasttext.Open("testdata/lid-not-found.176.ftz") assert.EqualError(err, "testdata/lid-not-found.176.ftz cannot be opened for loading!") @@ -42,34 +44,36 @@ func TestOpen(t *testing.T) { // assert.Greater(prediction.Probability, float32(0.7)) // } -// func TestMultilinePredict(t *testing.T) { -// t.Parallel() -// assert := require.New(t) +func TestMultilinePredict(t *testing.T) { + t.Parallel() + assert := require.New(t) -// model, err := fasttext.Open("testdata/lid.176.ftz") + model, err := fasttext.Open("testdata/lid.176.ftz") -// assert.NoError(err) + assert.NoError(err) -// predictions, err := model.MultiLinePredict([]string{ -// "Πες γεια στον μικρό μου φίλο", -// "Say 'ello to my little friend", -// }, 1, 0.5) + predictions, err := model.MultiLinePredict([]string{ + "Πες γεια στον μικρό μου φίλο", + "Say 'ello to my little friend", + }, 1, 0.5) -// assert.NoError(err) -// assert.NotEmpty(predictions) -// assert.Len(predictions, 2) + assert.NoError(err) + assert.NotEmpty(predictions) + assert.Len(predictions, 2) -// assert.Len(predictions[0], 1) -// assert.Equal(predictions[0][0].Label, "el") // el => for greek -// assert.Len(predictions[1], 1) -// assert.Equal(predictions[1][0].Label, "en") -// } + assert.Len(predictions[0], 1) + assert.Equal(predictions[0][0].Label, "el") // el => for greek + assert.Len(predictions[1], 1) + assert.Equal(predictions[1][0].Label, "en") +} func TestPredict(t *testing.T) { t.Parallel() + assert := require.New(t) t.Run("WithOnePrediction", func(t *testing.T) { + t.Parallel() model, err := fasttext.Open("testdata/lid.176.ftz") assert.NoError(err) @@ -113,6 +117,8 @@ func TestPredict(t *testing.T) { }) t.Run("WithMultiple", func(t *testing.T) { + t.Parallel() + model, err := fasttext.Open("testdata/lid.176.ftz") assert.NoError(err) @@ -128,6 +134,8 @@ func TestPredict(t *testing.T) { }) t.Run("Gibberish", func(t *testing.T) { + t.Parallel() + model, err := fasttext.Open("testdata/lid.176.ftz") assert.NoError(err)