From e670a7acedf1d6af5d1f802904d54554f31d010c Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Tue, 31 Mar 2020 17:11:31 -0400
Subject: [PATCH] [common] (image pipeline evaluation) Add size of each dataset

---
 .../image_dataset_generator.py                | 20 ++++++++++---------
 .../image_pipeline_evaluator.py               |  9 ++++++---
 .../dataset/armor_image_dataset_factory.py    |  9 +++------
 3 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/common/research_common/image_pipeline_evaluation/image_dataset_generator.py b/common/research_common/image_pipeline_evaluation/image_dataset_generator.py
index b526836..b2713f6 100644
--- a/common/research_common/image_pipeline_evaluation/image_dataset_generator.py
+++ b/common/research_common/image_pipeline_evaluation/image_dataset_generator.py
@@ -8,14 +8,16 @@ T = TypeVar("T")
 
 
 class ImageDatasetGenerator(Generic[T]):
+    def from_roco_datasets(self, datasets: Iterable[DirectoryROCODataset]) -> Tuple[List[Image], List[T], List[int]]:
+        images, labels, dataset_sizes = [], [], []
+        for dataset in datasets:
+            prev_total_size = len(images)
+            for img, label in self.from_roco_dataset(dataset):
+                images.append(img)
+                labels.append(label)
+            dataset_sizes.append(len(images) - prev_total_size)
+        return images, labels, dataset_sizes
+
     @abstractmethod
-    def from_roco_dataset(self, dataset: DirectoryROCODataset) -> Tuple[List[Image], List[T]]:
+    def from_roco_dataset(self, dataset: DirectoryROCODataset) -> Iterable[Tuple[Image, T]]:
         pass
-
-    def from_roco_datasets(self, datasets: Iterable[DirectoryROCODataset]) -> Tuple[List[Image], List[T]]:
-        images, labels = [], []
-        for dataset in datasets:
-            imgs, lbls = self.from_roco_dataset(dataset)
-            images.extend(imgs)
-            labels.extend(lbls)
-        return images, labels
diff --git a/common/research_common/image_pipeline_evaluation/image_pipeline_evaluator.py b/common/research_common/image_pipeline_evaluation/image_pipeline_evaluator.py
index d952269..993aa88 100644
--- a/common/research_common/image_pipeline_evaluation/image_pipeline_evaluator.py
+++ b/common/research_common/image_pipeline_evaluation/image_pipeline_evaluator.py
@@ -8,7 +8,6 @@ from sklearn.metrics import classification_report
 from polystar.common.image_pipeline.image_pipeline import ImagePipeline
 from polystar.common.models.image import Image
 from research_common.dataset.directory_roco_dataset import DirectoryROCODataset
-from research_common.dataset.roco_dataset import ROCODataset
 from research_common.image_pipeline_evaluation.image_dataset_generator import ImageDatasetGenerator
 
 
@@ -31,8 +30,12 @@ class ImagePipelineEvaluator:
         logging.info("Loading data")
         self.train_roco_datasets = train_roco_datasets
         self.test_roco_datasets = test_roco_datasets
-        self.train_images, self.train_labels = image_dataset_generator.from_roco_datasets(train_roco_datasets)
-        self.test_images, self.test_labels = image_dataset_generator.from_roco_datasets(test_roco_datasets)
+        self.train_images, self.train_labels, self.train_dataset_sizes = image_dataset_generator.from_roco_datasets(
+            train_roco_datasets
+        )
+        self.test_images, self.test_labels, self.test_dataset_sizes = image_dataset_generator.from_roco_datasets(
+            test_roco_datasets
+        )
 
     def evaluate_pipelines(self, pipelines: Iterable[ImagePipeline]) -> Dict[str, ClassificationResults]:
         return {str(pipeline): self.evaluate(pipeline) for pipeline in pipelines}
diff --git a/robots-at-robots/research/dataset/armor_image_dataset_factory.py b/robots-at-robots/research/dataset/armor_image_dataset_factory.py
index 9ce0a1a..f90d0a3 100644
--- a/robots-at-robots/research/dataset/armor_image_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_image_dataset_factory.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from pathlib import Path
-from typing import TypeVar, Tuple, List
+from typing import TypeVar, Tuple, List, Iterable
 
 from polystar.common.models.image import Image
 from polystar.common.models.object import ArmorColor
@@ -12,14 +12,11 @@ T = TypeVar("T")
 
 
 class ArmorImageDatasetGenerator(ImageDatasetGenerator[T]):
-    def from_roco_dataset(self, dataset: DirectoryROCODataset) -> Tuple[List[Image], List[T]]:
-        images, labels = [], []
+    def from_roco_dataset(self, dataset: DirectoryROCODataset) -> Iterable[Tuple[Image, T]]:
         for (armor_img, color, digit, k, path) in ArmorDatasetFactory.from_dataset(dataset):
             label = self._label(color, digit, k, path)
             if self._valid_label(label):
-                images.append(armor_img)
-                labels.append(label)
-        return images, labels
+                yield armor_img, label
 
     @abstractmethod
     def _label(self, color: ArmorColor, digit: int, k: int, path: Path) -> T:
-- 
GitLab