From b2ef7dc8643884648d71000c980aa250c5f16ba5 Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Sat, 12 Dec 2020 13:22:23 +0100 Subject: [PATCH] [common] (classificationPipeline) add a test_proba_and_classes method --- .../pipeline/classification/classification_pipeline.py | 9 +++++++-- .../image_pipeline/test_image_classifier_pipeline.py | 6 ++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/common/polystar/common/pipeline/classification/classification_pipeline.py b/common/polystar/common/pipeline/classification/classification_pipeline.py index becb457..074cc15 100644 --- a/common/polystar/common/pipeline/classification/classification_pipeline.py +++ b/common/polystar/common/pipeline/classification/classification_pipeline.py @@ -27,8 +27,13 @@ class ClassificationPipeline(Pipeline, Generic[IT, EnumT], ABC): return super().fit(x, y_indices, **fit_params) def predict(self, x: Sequence[IT]) -> List[EnumT]: - indices = asarray(self.predict_proba(x)).argmax(axis=1) - return [self.classes_[i] for i in indices] + return self.predict_proba_and_classes(x)[1] + + def predict_proba_and_classes(self, x: Sequence[IT]) -> Tuple[ndarray, List[EnumT]]: + proba = asarray(self.predict_proba(x)) + indices = proba.argmax(axis=1) + classes = [self.classes_[i] for i in indices] + return proba, classes def score(self, x: Sequence[IT], y: List[EnumT], **score_params) -> float: """It is needed to have a proper CV""" diff --git a/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py b/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py index a8d360f..13321eb 100644 --- a/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py +++ b/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py @@ -52,6 +52,12 @@ class TestRuleBasedClassifier(TestCase): self.pipeline.predict_proba(list("aacbz")), ) + def test_predict_proba_and_classes(self): + self.pipeline.classifier.n_classes = 3 # This is normally done during fitting + proba, classes = self.pipeline.predict_proba_and_classes(list("aacbz")) + array_equal(asarray([[1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]), proba) + self.assertEqual([Letter.A, Letter.A, Letter.Z, Letter.B, Letter.Z], classes) + class StrToIntPipe(PipeABC[str, int]): def transform_single(self, example: str) -> int: -- GitLab