diff --git a/common/polystar/common/pipeline/classification/classification_pipeline.py b/common/polystar/common/pipeline/classification/classification_pipeline.py index becb457b4e59311ed81171682e94b25a13591e22..074cc15e5942abe0593f547939c64747224b1ba5 100644 --- a/common/polystar/common/pipeline/classification/classification_pipeline.py +++ b/common/polystar/common/pipeline/classification/classification_pipeline.py @@ -27,8 +27,13 @@ class ClassificationPipeline(Pipeline, Generic[IT, EnumT], ABC): return super().fit(x, y_indices, **fit_params) def predict(self, x: Sequence[IT]) -> List[EnumT]: - indices = asarray(self.predict_proba(x)).argmax(axis=1) - return [self.classes_[i] for i in indices] + return self.predict_proba_and_classes(x)[1] + + def predict_proba_and_classes(self, x: Sequence[IT]) -> Tuple[ndarray, List[EnumT]]: + proba = asarray(self.predict_proba(x)) + indices = proba.argmax(axis=1) + classes = [self.classes_[i] for i in indices] + return proba, classes def score(self, x: Sequence[IT], y: List[EnumT], **score_params) -> float: """It is needed to have a proper CV""" diff --git a/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py b/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py index a8d360fa790692fff02aa797ff74005d82ca9cae..13321eb01e044aee1bf8222a72bb9b123af0d5cc 100644 --- a/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py +++ b/common/tests/common/unittests/image_pipeline/test_image_classifier_pipeline.py @@ -52,6 +52,12 @@ class TestRuleBasedClassifier(TestCase): self.pipeline.predict_proba(list("aacbz")), ) + def test_predict_proba_and_classes(self): + self.pipeline.classifier.n_classes = 3 # This is normally done during fitting + proba, classes = self.pipeline.predict_proba_and_classes(list("aacbz")) + array_equal(asarray([[1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]), proba) + self.assertEqual([Letter.A, Letter.A, Letter.Z, Letter.B, Letter.Z], classes) + class StrToIntPipe(PipeABC[str, int]): def transform_single(self, example: str) -> int: