Skip to content
Snippets Groups Projects
Commit b2ef7dc8 authored by Mathieu Beligon's avatar Mathieu Beligon
Browse files

[common] (classificationPipeline) add a test_proba_and_classes method

parent dc623cf7
No related branches found
No related tags found
No related merge requests found
......@@ -27,8 +27,13 @@ class ClassificationPipeline(Pipeline, Generic[IT, EnumT], ABC):
return super().fit(x, y_indices, **fit_params)
def predict(self, x: Sequence[IT]) -> List[EnumT]:
indices = asarray(self.predict_proba(x)).argmax(axis=1)
return [self.classes_[i] for i in indices]
return self.predict_proba_and_classes(x)[1]
def predict_proba_and_classes(self, x: Sequence[IT]) -> Tuple[ndarray, List[EnumT]]:
proba = asarray(self.predict_proba(x))
indices = proba.argmax(axis=1)
classes = [self.classes_[i] for i in indices]
return proba, classes
def score(self, x: Sequence[IT], y: List[EnumT], **score_params) -> float:
"""It is needed to have a proper CV"""
......
......@@ -52,6 +52,12 @@ class TestRuleBasedClassifier(TestCase):
self.pipeline.predict_proba(list("aacbz")),
)
def test_predict_proba_and_classes(self):
self.pipeline.classifier.n_classes = 3 # This is normally done during fitting
proba, classes = self.pipeline.predict_proba_and_classes(list("aacbz"))
array_equal(asarray([[1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1]]), proba)
self.assertEqual([Letter.A, Letter.A, Letter.Z, Letter.B, Letter.Z], classes)
class StrToIntPipe(PipeABC[str, int]):
def transform_single(self, example: str) -> int:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment