From b2ef7dc8643884648d71000c980aa250c5f16ba5 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Sat, 12 Dec 2020 13:22:23 +0100
Subject: [PATCH] [common] (classificationPipeline) add a
 test_proba_and_classes method

---
 .../pipeline/classification/classification_pipeline.py   | 9 +++++++--
 .../image_pipeline/test_image_classifier_pipeline.py     | 6 ++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/common/polystar/common/pipeline/classification/classification_pipeline.py b/common/polystar/common/pipeline/classification/classification_pipeline.py
index becb457..074cc15 100644
--- a/common/polystar/common/pipeline/classification/classification_pipeline.py
+++ b/common/polystar/common/pipeline/classification/classification_pipeline.py
@@ -27,8 +27,13 @@ class ClassificationPipeline(Pipeline, Generic[IT, EnumT], ABC):
         return super().fit(x, y_indices, **fit_params)
 
     def predict(self, x: Sequence[IT]) -> List[EnumT]:
-        indices = asarray(self.predict_proba(x)).argmax(axis=1)
-        return [self.classes_[i] for i in indices]
+        return self.predict_proba_and_classes(x)[1]
+
+    def predict_proba_and_classes(self, x: Sequence[IT]) -> Tuple[ndarray, List[EnumT]]:
+        proba = asarray(self.predict_proba(x))
+        indices = proba.argmax(axis=1)
+        classes = [self.classes_[i] for i in indices]
+        return proba, classes
 
     def score(self, x: Sequence[IT], y: List[EnumT], **score_params) -> float:
         """It is needed to have a proper CV"""
diff --git a/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py b/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py
index a8d360f..13321eb 100644
--- a/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py
+++ b/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py
@@ -52,6 +52,12 @@ class TestRuleBasedClassifier(TestCase):
             self.pipeline.predict_proba(list("aacbz")),
         )
 
+    def test_predict_proba_and_classes(self):
+        self.pipeline.classifier.n_classes = 3  # This is normally done during fitting
+        proba, classes = self.pipeline.predict_proba_and_classes(list("aacbz"))
+        array_equal(asarray([[1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]), proba)
+        self.assertEqual([Letter.A, Letter.A, Letter.Z, Letter.B, Letter.Z], classes)
+
 
 class StrToIntPipe(PipeABC[str, int]):
     def transform_single(self, example: str) -> int:
-- 
GitLab