From a1c9642b00b080d0b085b7e3d04ccb98318ad79e Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Tue, 31 Mar 2020 21:20:57 -0400
Subject: [PATCH] [robots] (armor factory) yield armor object instead of its
 color and number only

---
 .../research/dataset/armor_color_dataset_factory.py    |  6 +++---
 .../research/dataset/armor_dataset_factory.py          | 10 ++++------
 .../research/dataset/armor_digit_dataset_factory.py    |  6 +++---
 .../research/dataset/armor_image_dataset_factory.py    |  8 ++++----
 4 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/robots-at-robots/research/dataset/armor_color_dataset_factory.py b/robots-at-robots/research/dataset/armor_color_dataset_factory.py
index 893947f..c601fa1 100644
--- a/robots-at-robots/research/dataset/armor_color_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_color_dataset_factory.py
@@ -1,11 +1,11 @@
 from pathlib import Path
 
-from polystar.common.models.object import ArmorColor
+from polystar.common.models.object import Armor
 from research.dataset.armor_image_dataset_factory import ArmorImageDatasetGenerator
 
 
 class ArmorColorDatasetGenerator(ArmorImageDatasetGenerator[str]):
     task_name: str = "colors"
 
-    def _label_from_armor_info(self, color: ArmorColor, digit: int, k: int, path: Path) -> str:
-        return color.name
+    def _label_from_armor_info(self, armor: Armor, k: int, path: Path) -> str:
+        return armor.color.name
diff --git a/robots-at-robots/research/dataset/armor_dataset_factory.py b/robots-at-robots/research/dataset/armor_dataset_factory.py
index 7940046..5a47acb 100644
--- a/robots-at-robots/research/dataset/armor_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_dataset_factory.py
@@ -13,14 +13,12 @@ from research_common.dataset.roco_dataset import ROCODataset
 
 class ArmorDatasetFactory:
     @staticmethod
-    def from_image_annotation(
-        image_annotation: ImageAnnotation,
-    ) -> Iterable[Tuple[Image, ArmorColor, ArmorNumber, int, Path]]:
+    def from_image_annotation(image_annotation: ImageAnnotation) -> Iterable[Tuple[Image, Armor, int, Path]]:
         img = image_annotation.image
         armors: List[Armor] = TypeObjectValidator(ObjectType.Armor).filter(image_annotation.objects, img)
         for i, obj in enumerate(armors):
             croped_img = img[obj.y : obj.y + obj.h, obj.x : obj.x + obj.w]
-            yield croped_img, obj.color, obj.numero, i, image_annotation.image_path
+            yield croped_img, obj, i, image_annotation.image_path
 
     @staticmethod
     def from_dataset(dataset: ROCODataset) -> Iterable[Tuple[Image, ArmorColor, ArmorNumber, int, Path]]:
@@ -30,8 +28,8 @@ class ArmorDatasetFactory:
 
 
 if __name__ == "__main__":
-    for i, (armor_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.CentralChina)):
-        print(c, n, k, "in", p)
+    for i, (armor_img, armor, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.CentralChina)):
+        print(armor, k, "in", p)
         plt.imshow(armor_img)
         plt.show()
         plt.clf()
diff --git a/robots-at-robots/research/dataset/armor_digit_dataset_factory.py b/robots-at-robots/research/dataset/armor_digit_dataset_factory.py
index 8e421cf..a38d4a1 100644
--- a/robots-at-robots/research/dataset/armor_digit_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_digit_dataset_factory.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 from typing import Set
 
-from polystar.common.models.object import ArmorColor
+from polystar.common.models.object import Armor
 from research.dataset.armor_image_dataset_factory import ArmorImageDatasetGenerator
 
 
@@ -14,8 +14,8 @@ class ArmorDigitDatasetGenerator(ArmorImageDatasetGenerator[int]):
     def _label_from_str(self, label: str) -> int:
         return int(label)
 
-    def _label_from_armor_info(self, color: ArmorColor, number: int, k: int, path: Path) -> int:
-        return number
+    def _label_from_armor_info(self, armor: Armor, k: int, path: Path) -> int:
+        return armor.numero
 
     def _valid_label(self, label: int) -> bool:
         return label in self.acceptable_digits
diff --git a/robots-at-robots/research/dataset/armor_image_dataset_factory.py b/robots-at-robots/research/dataset/armor_image_dataset_factory.py
index 7bff809..b211c63 100644
--- a/robots-at-robots/research/dataset/armor_image_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_image_dataset_factory.py
@@ -6,7 +6,7 @@ from typing import TypeVar, Tuple, Iterable
 import cv2
 
 from polystar.common.models.image import Image
-from polystar.common.models.object import ArmorColor
+from polystar.common.models.object import Armor
 from polystar.common.utils.time import create_time_id
 from research.dataset.armor_dataset_factory import ArmorDatasetFactory
 from research_common.dataset.directory_roco_dataset import DirectoryROCODataset
@@ -26,8 +26,8 @@ class ArmorImageDatasetGenerator(ImageDatasetGenerator[T]):
     def _create_labelized_armor_images_from_roco(self, dataset):
         dset_path = dataset.dataset_path / self.task_name
         dset_path.mkdir(exist_ok=True)
-        for (armor_img, color, digit, k, path) in ArmorDatasetFactory.from_dataset(dataset):
-            label = self._label_from_armor_info(color, digit, k, path)
+        for (armor_img, armor, k, path) in ArmorDatasetFactory.from_dataset(dataset):
+            label = self._label_from_armor_info(armor, k, path)
             cv2.imwrite(str(dset_path / f"{path.stem}-{k}-{label}.jpg"), cv2.cvtColor(armor_img, cv2.COLOR_RGB2BGR))
         (dataset.dataset_path / self.task_name / ".lock").write_text(
             json.dumps({"version": "0.0", "date": create_time_id()})
@@ -44,7 +44,7 @@ class ArmorImageDatasetGenerator(ImageDatasetGenerator[T]):
         return self._label_from_str(image_path.stem.split("-")[-1])
 
     @abstractmethod
-    def _label_from_armor_info(self, color: ArmorColor, digit: int, k: int, path: Path) -> T:
+    def _label_from_armor_info(self, armor: Armor, k: int, path: Path) -> T:
         pass
 
     def _valid_label(self, label: T) -> bool:
-- 
GitLab