From 39f68a74864eca5e21d4e5455439b133140bb73a Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Tue, 31 Mar 2020 20:36:30 -0400 Subject: [PATCH] [robots] (armor dataset factory) store images in files after extraction --- .../dataset/armor_color_dataset_factory.py | 4 ++- .../dataset/armor_digit_dataset_factory.py | 7 ++++- .../dataset/armor_image_dataset_factory.py | 30 +++++++++++++++---- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/robots-at-robots/research/dataset/armor_color_dataset_factory.py b/robots-at-robots/research/dataset/armor_color_dataset_factory.py index 83cf323..893947f 100644 --- a/robots-at-robots/research/dataset/armor_color_dataset_factory.py +++ b/robots-at-robots/research/dataset/armor_color_dataset_factory.py @@ -5,5 +5,7 @@ from research.dataset.armor_image_dataset_factory import ArmorImageDatasetGenera class ArmorColorDatasetGenerator(ArmorImageDatasetGenerator[str]): - def _label(self, color: ArmorColor, digit: int, k: int, path: Path) -> str: + task_name: str = "colors" + + def _label_from_armor_info(self, color: ArmorColor, digit: int, k: int, path: Path) -> str: return color.name diff --git a/robots-at-robots/research/dataset/armor_digit_dataset_factory.py b/robots-at-robots/research/dataset/armor_digit_dataset_factory.py index f1ed6f1..8e421cf 100644 --- a/robots-at-robots/research/dataset/armor_digit_dataset_factory.py +++ b/robots-at-robots/research/dataset/armor_digit_dataset_factory.py @@ -6,10 +6,15 @@ from research.dataset.armor_image_dataset_factory import ArmorImageDatasetGenera class ArmorDigitDatasetGenerator(ArmorImageDatasetGenerator[int]): + task_name: str = "digits" + def __init__(self, acceptable_digits: Set[int]): self.acceptable_digits = acceptable_digits - def _label(self, color: ArmorColor, number: int, k: int, path: Path) -> int: + def _label_from_str(self, label: str) -> int: + return int(label) + + def _label_from_armor_info(self, color: ArmorColor, number: int, k: int, path: Path) -> int: return number def _valid_label(self, label: int) -> bool: diff --git a/robots-at-robots/research/dataset/armor_image_dataset_factory.py b/robots-at-robots/research/dataset/armor_image_dataset_factory.py index f90d0a3..85dd6cc 100644 --- a/robots-at-robots/research/dataset/armor_image_dataset_factory.py +++ b/robots-at-robots/research/dataset/armor_image_dataset_factory.py @@ -1,6 +1,9 @@ from abc import abstractmethod +from dataclasses import dataclass from pathlib import Path -from typing import TypeVar, Tuple, List, Iterable +from typing import TypeVar, Tuple, Iterable + +import cv2 from polystar.common.models.image import Image from polystar.common.models.object import ArmorColor @@ -12,15 +15,32 @@ T = TypeVar("T") class ArmorImageDatasetGenerator(ImageDatasetGenerator[T]): + task_name: str + def from_roco_dataset(self, dataset: DirectoryROCODataset) -> Iterable[Tuple[Image, T]]: + if not (dataset.dataset_path / self.task_name).exists(): + self._create_labelized_armor_images_from_roco(dataset) + return self._get_saved_images_and_labels(dataset) + + def _create_labelized_armor_images_from_roco(self, dataset): + dset_path = dataset.dataset_path / self.task_name + dset_path.mkdir() for (armor_img, color, digit, k, path) in ArmorDatasetFactory.from_dataset(dataset): - label = self._label(color, digit, k, path) - if self._valid_label(label): - yield armor_img, label + label = self._label_from_armor_info(color, digit, k, path) + cv2.imwrite(str(dset_path / f"{path.stem}-{k}-{label}.jpg"), cv2.cvtColor(armor_img, cv2.COLOR_RGB2BGR)) + + def _get_saved_images_and_labels(self, dataset: DirectoryROCODataset) -> Iterable[Tuple[Image, T]]: + return ( + (Image.from_path(image_path), self._label_from_str(image_path.stem.split("-")[-1])) + for image_path in (dataset.dataset_path / self.task_name).glob("*.jpg") + ) @abstractmethod - def _label(self, color: ArmorColor, digit: int, k: int, path: Path) -> T: + def _label_from_armor_info(self, color: ArmorColor, digit: int, k: int, path: Path) -> T: pass def _valid_label(self, label: T) -> bool: return True + + def _label_from_str(self, label: str) -> T: + return label -- GitLab