From 39f68a74864eca5e21d4e5455439b133140bb73a Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Tue, 31 Mar 2020 20:36:30 -0400
Subject: [PATCH] [robots] (armor dataset factory) store images in files after
 extraction

---
 .../dataset/armor_color_dataset_factory.py    |  4 ++-
 .../dataset/armor_digit_dataset_factory.py    |  7 ++++-
 .../dataset/armor_image_dataset_factory.py    | 30 +++++++++++++++----
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/robots-at-robots/research/dataset/armor_color_dataset_factory.py b/robots-at-robots/research/dataset/armor_color_dataset_factory.py
index 83cf323..893947f 100644
--- a/robots-at-robots/research/dataset/armor_color_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_color_dataset_factory.py
@@ -5,5 +5,7 @@ from research.dataset.armor_image_dataset_factory import ArmorImageDatasetGenera
 
 
 class ArmorColorDatasetGenerator(ArmorImageDatasetGenerator[str]):
-    def _label(self, color: ArmorColor, digit: int, k: int, path: Path) -> str:
+    task_name: str = "colors"
+
+    def _label_from_armor_info(self, color: ArmorColor, digit: int, k: int, path: Path) -> str:
         return color.name
diff --git a/robots-at-robots/research/dataset/armor_digit_dataset_factory.py b/robots-at-robots/research/dataset/armor_digit_dataset_factory.py
index f1ed6f1..8e421cf 100644
--- a/robots-at-robots/research/dataset/armor_digit_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_digit_dataset_factory.py
@@ -6,10 +6,15 @@ from research.dataset.armor_image_dataset_factory import ArmorImageDatasetGenera
 
 
 class ArmorDigitDatasetGenerator(ArmorImageDatasetGenerator[int]):
+    task_name: str = "digits"
+
     def __init__(self, acceptable_digits: Set[int]):
         self.acceptable_digits = acceptable_digits
 
-    def _label(self, color: ArmorColor, number: int, k: int, path: Path) -> int:
+    def _label_from_str(self, label: str) -> int:
+        return int(label)
+
+    def _label_from_armor_info(self, color: ArmorColor, number: int, k: int, path: Path) -> int:
         return number
 
     def _valid_label(self, label: int) -> bool:
diff --git a/robots-at-robots/research/dataset/armor_image_dataset_factory.py b/robots-at-robots/research/dataset/armor_image_dataset_factory.py
index f90d0a3..85dd6cc 100644
--- a/robots-at-robots/research/dataset/armor_image_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_image_dataset_factory.py
@@ -1,6 +1,9 @@
 from abc import abstractmethod
+from dataclasses import dataclass
 from pathlib import Path
-from typing import TypeVar, Tuple, List, Iterable
+from typing import TypeVar, Tuple, Iterable
+
+import cv2
 
 from polystar.common.models.image import Image
 from polystar.common.models.object import ArmorColor
@@ -12,15 +15,32 @@ T = TypeVar("T")
 
 
 class ArmorImageDatasetGenerator(ImageDatasetGenerator[T]):
+    task_name: str
+
     def from_roco_dataset(self, dataset: DirectoryROCODataset) -> Iterable[Tuple[Image, T]]:
+        if not (dataset.dataset_path / self.task_name).exists():
+            self._create_labelized_armor_images_from_roco(dataset)
+        return self._get_saved_images_and_labels(dataset)
+
+    def _create_labelized_armor_images_from_roco(self, dataset):
+        dset_path = dataset.dataset_path / self.task_name
+        dset_path.mkdir()
         for (armor_img, color, digit, k, path) in ArmorDatasetFactory.from_dataset(dataset):
-            label = self._label(color, digit, k, path)
-            if self._valid_label(label):
-                yield armor_img, label
+            label = self._label_from_armor_info(color, digit, k, path)
+            cv2.imwrite(str(dset_path / f"{path.stem}-{k}-{label}.jpg"), cv2.cvtColor(armor_img, cv2.COLOR_RGB2BGR))
+
+    def _get_saved_images_and_labels(self, dataset: DirectoryROCODataset) -> Iterable[Tuple[Image, T]]:
+        return (
+            (Image.from_path(image_path), self._label_from_str(image_path.stem.split("-")[-1]))
+            for image_path in (dataset.dataset_path / self.task_name).glob("*.jpg")
+        )
 
     @abstractmethod
-    def _label(self, color: ArmorColor, digit: int, k: int, path: Path) -> T:
+    def _label_from_armor_info(self, color: ArmorColor, digit: int, k: int, path: Path) -> T:
         pass
 
     def _valid_label(self, label: T) -> bool:
         return True
+
+    def _label_from_str(self, label: str) -> T:
+        return label
-- 
GitLab