From 060af02d60c1885fd7acf0c2e0fff631f5ed8d54 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Fri, 11 Sep 2020 21:33:53 +0200
Subject: [PATCH] [common] (UnionDataset) Add in the new format

---
 .../common/dataset/roco_dataset_descriptor.py |  8 +++---
 .../research/common/datasets/image_dataset.py |  4 +--
 .../common/datasets/roco/roco_datasets.py     | 25 ++++++++++++-------
 .../research/common/datasets/roco/zoo/dji.py  |  8 +++---
 .../common/datasets/roco/zoo/dji_zoomed.py    | 10 +++++---
 .../datasets/roco/zoo/roco_datasets_zoo.py    |  5 +++-
 .../common/datasets/roco/zoo/twitch.py        | 10 +++++---
 .../research/common/datasets/union_dataset.py | 19 ++++++++++++++
 8 files changed, 65 insertions(+), 24 deletions(-)
 create mode 100644 common/research/common/datasets/union_dataset.py

diff --git a/common/research/common/dataset/roco_dataset_descriptor.py b/common/research/common/dataset/roco_dataset_descriptor.py
index f957eea..f7c3016 100644
--- a/common/research/common/dataset/roco_dataset_descriptor.py
+++ b/common/research/common/dataset/roco_dataset_descriptor.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass, field
-from itertools import chain
 from pathlib import Path
 from typing import Dict
 
@@ -8,6 +7,7 @@ from polystar.common.models.object import Armor, ObjectType
 from polystar.common.utils.markdown import MarkdownFile
 from research.common.datasets.roco.roco_dataset import ROCODataset
 from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
+from research.common.datasets.union_dataset import UnionDataset
 
 
 @dataclass
@@ -66,5 +66,7 @@ def make_markdown_dataset_report(dataset: ROCODataset, report_dir: Path):
 
 
 if __name__ == "__main__":
-    for dset in chain(*ROCODatasetsZoo()):
-        make_markdown_dataset_report(dset, dset.dataset_path)
+    for datasets in ROCODatasetsZoo():  # one group of datasets per iteration
+        make_markdown_dataset_report(UnionDataset(datasets, datasets.name), datasets.directory)  # report over the whole group
+        for dset in datasets:
+            make_markdown_dataset_report(dset, dset.dataset_path)  # then one report per dataset of the group
diff --git a/common/research/common/datasets/image_dataset.py b/common/research/common/datasets/image_dataset.py
index a953133..70f8fb2 100644
--- a/common/research/common/datasets/image_dataset.py
+++ b/common/research/common/datasets/image_dataset.py
@@ -1,11 +1,11 @@
-from typing import Generic, Iterator, List, Tuple, TypeVar
+from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar
 
 from polystar.common.models.image import Image
 
 TargetT = TypeVar("TargetT")
 
 
-class ImageDataset(Generic[TargetT]):
+class ImageDataset(Generic[TargetT], Iterable[Tuple[Image, TargetT]]):
     def __init__(self, name: str, images: List[Image] = None, targets: List[TargetT] = None):
         self.name = name
         self._targets = targets
diff --git a/common/research/common/datasets/roco/roco_datasets.py b/common/research/common/datasets/roco/roco_datasets.py
index fc0a7d5..55056b5 100644
--- a/common/research/common/datasets/roco/roco_datasets.py
+++ b/common/research/common/datasets/roco/roco_datasets.py
@@ -1,25 +1,47 @@
-from typing import Any, Iterator, List, Tuple
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any, ClassVar, Iterable, Iterator, List, Tuple
 
-from research.common.dataset.directory_roco_dataset import DirectoryROCODataset
+from research.common.datasets.roco.directory_roco_dataset import \
+    DirectoryROCODataset
 
 
-class ROCODatasets:
-    def make_dataset(dataset_name: str, *args: Any) -> DirectoryROCODataset:
+class ROCODatasets(Iterable[DirectoryROCODataset]):
+    """Group of datasets declared as class attributes of a subclass.
+
+    When a subclass is created, every attribute of its body whose name fails
+    `str.islower()` is replaced by the dataset that `make_dataset` builds for it.
+    """
+
+    name: ClassVar[str]
+    datasets: ClassVar[List[DirectoryROCODataset]]
+    directory: ClassVar[Path]
+
+    @classmethod
+    @abstractmethod
+    def make_dataset(cls, dataset_name: str, *args: Any) -> DirectoryROCODataset:
+        """Build the dataset declared as `dataset_name`; `args` is its declared value, unpacked."""
         pass
 
     def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
         cls.datasets: List[DirectoryROCODataset] = []
         for dataset_name, args in cls.__dict__.items():
-            if (
-                not callable(args)
-                and not dataset_name.startswith("_")
-                and dataset_name not in ("make_dataset", "datasets")
-            ):
+            # All-lowercase names (dunders, `directory`, `make_dataset`, ...) are not datasets.
+            if not dataset_name.islower():
                 if not isinstance(args, Tuple):
                     args = (args,)
                 dataset = cls.make_dataset(dataset_name, *args)
                 setattr(cls, dataset_name, dataset)
                 cls.datasets.append(dataset)
 
+        # Strip the "Datasets" suffix and the "ROCO" marker separately. Slicing off
+        # len("ROCODatasets") characters is only right when the class name ends with that
+        # exact suffix: "DJIROCOZoomedDatasets" would otherwise be named "DJIROCOZo".
+        group_name = cls.__name__
+        if group_name.endswith("Datasets"):
+            group_name = group_name[: -len("Datasets")]
+        cls.name = group_name.replace("ROCO", "", 1)
+
     def __iter__(self) -> Iterator[DirectoryROCODataset]:
         return self.datasets.__iter__()
diff --git a/common/research/common/datasets/roco/zoo/dji.py b/common/research/common/datasets/roco/zoo/dji.py
index 0b7ca9f..221430a 100644
--- a/common/research/common/datasets/roco/zoo/dji.py
+++ b/common/research/common/datasets/roco/zoo/dji.py
@@ -5,11 +5,13 @@ from research.common.datasets.roco.roco_datasets import ROCODatasets
 
 
 class DJIROCODatasets(ROCODatasets):
+    directory = DJI_ROCO_DSET_DIR  # folder the competition sub-folders below are resolved against
+
     CentralChina = "robomaster_Central China Regional Competition"
     NorthChina = "robomaster_North China Regional Competition"
     SouthChina = "robomaster_South China Regional Competition"
     Final = "robomaster_Final Tournament"
 
-    @staticmethod
-    def make_dataset(dataset_name: str, competition_name: str) -> DirectoryROCODataset:
-        return DirectoryROCODataset(DJI_ROCO_DSET_DIR / competition_name, dataset_name)
+    @classmethod  # invoked on the class itself, once per attribute declared above
+    def make_dataset(cls, dataset_name: str, competition_name: str) -> DirectoryROCODataset:
+        return DirectoryROCODataset(cls.directory / competition_name, dataset_name)  # competition_name is the attribute's string value
diff --git a/common/research/common/datasets/roco/zoo/dji_zoomed.py b/common/research/common/datasets/roco/zoo/dji_zoomed.py
index 39444e3..5d91133 100644
--- a/common/research/common/datasets/roco/zoo/dji_zoomed.py
+++ b/common/research/common/datasets/roco/zoo/dji_zoomed.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from polystar.common.utils.str_utils import camel2snake
 from research.common.constants import DJI_ROCO_ZOOMED_DSET_DIR
 from research.common.datasets.roco.directory_roco_dataset import \
@@ -6,11 +8,13 @@ from research.common.datasets.roco.roco_datasets import ROCODatasets
 
 
 class DJIROCOZoomedDatasets(ROCODatasets):
+    directory = DJI_ROCO_ZOOMED_DSET_DIR  # folder the per-dataset sub-folders are resolved against
+
     CentralChina = ()
     NorthChina = ()
     SouthChina = ()
     Final = ()
 
-    @staticmethod
-    def make_dataset(dataset_name: str) -> DirectoryROCODataset:
-        return DirectoryROCODataset(DJI_ROCO_ZOOMED_DSET_DIR / camel2snake(dataset_name), f"{dataset_name}ZoomedV2")
+    @classmethod
+    def make_dataset(cls, dataset_name: str, *args: Any) -> DirectoryROCODataset:  # args is unused: every entry above is ()
+        return DirectoryROCODataset(cls.directory / camel2snake(dataset_name), f"{dataset_name}ZoomedV2")  # sub-folder is camel2snake of the attribute name
diff --git a/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py b/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py
index cfeab52..3de4a28 100644
--- a/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py
+++ b/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py
@@ -1,9 +1,12 @@
+from typing import Iterable
+
+from research.common.datasets.roco.roco_datasets import ROCODatasets
 from research.common.datasets.roco.zoo.dji import DJIROCODatasets
 from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
 from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
 
 
-class ROCODatasetsZoo:
+class ROCODatasetsZoo(Iterable[ROCODatasets]):
     DJI_ZOOMED = DJIROCOZoomedDatasets()
     DJI = DJIROCODatasets()
     TWITCH = TwitchROCODatasets()
diff --git a/common/research/common/datasets/roco/zoo/twitch.py b/common/research/common/datasets/roco/zoo/twitch.py
index 9d7a0af..68ca46a 100644
--- a/common/research/common/datasets/roco/zoo/twitch.py
+++ b/common/research/common/datasets/roco/zoo/twitch.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from research.common.constants import TWITCH_DSET_DIR
 from research.common.datasets.roco.directory_roco_dataset import \
     DirectoryROCODataset
@@ -5,6 +7,8 @@ from research.common.datasets.roco.roco_datasets import ROCODatasets
 
 
 class TwitchROCODatasets(ROCODatasets):
+    directory = TWITCH_DSET_DIR / "v1"
+
     TWITCH_470149568 = ()
     TWITCH_470150052 = ()
     TWITCH_470151286 = ()
@@ -14,7 +18,7 @@ class TwitchROCODatasets(ROCODatasets):
     TWITCH_470153081 = ()
     TWITCH_470158483 = ()
 
-    @staticmethod
-    def make_dataset(dataset_name: str) -> DirectoryROCODataset:
+    @classmethod
+    def make_dataset(cls, dataset_name: str, *args: Any) -> DirectoryROCODataset:  # args is not used by this implementation
         twitch_id = dataset_name[len("TWITCH_") :]
-        return DirectoryROCODataset(TWITCH_DSET_DIR / "v1" / twitch_id, f"T{twitch_id}")
+        return DirectoryROCODataset(cls.directory / twitch_id, f"T{twitch_id}")  # TWITCH_470149568 -> sub-folder "470149568", name "T470149568"
diff --git a/common/research/common/datasets/union_dataset.py b/common/research/common/datasets/union_dataset.py
new file mode 100644
index 0000000..6238878
--- /dev/null
+++ b/common/research/common/datasets/union_dataset.py
@@ -0,0 +1,23 @@
+from itertools import chain
+from typing import Iterable, Iterator, List, Tuple
+
+from polystar.common.models.image import Image
+from research.common.datasets.image_dataset import ImageDataset, TargetT
+
+
+class UnionDataset(ImageDataset[TargetT]):
+    """Dataset that exposes several datasets as a single one, in the order they were given."""
+
+    def __init__(self, datasets: Iterable[ImageDataset[TargetT]], name: str):
+        super().__init__(name)
+        # Copied into a list: the argument may be any iterable, possibly a one-shot one.
+        self.datasets = list(datasets)
+
+    def __iter__(self) -> Iterator[Tuple[Image, TargetT]]:
+        """Yield the items of each underlying dataset, one dataset after the other."""
+        yield from chain.from_iterable(self.datasets)
+
+    @property
+    def targets(self) -> List[TargetT]:
+        """Targets of all the underlying datasets, concatenated into a new list."""
+        return list(chain.from_iterable(dataset.targets for dataset in self.datasets))
-- 
GitLab