diff --git a/common/research/common/dataset/roco_dataset_descriptor.py b/common/research/common/dataset/roco_dataset_descriptor.py index f957eeaf5a21afcf55b962e0615996ca0845783a..f7c3016cc7990eb610268a063252275655d367e1 100644 --- a/common/research/common/dataset/roco_dataset_descriptor.py +++ b/common/research/common/dataset/roco_dataset_descriptor.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from itertools import chain from pathlib import Path from typing import Dict @@ -8,6 +7,7 @@ from polystar.common.models.object import Armor, ObjectType from polystar.common.utils.markdown import MarkdownFile from research.common.datasets.roco.roco_dataset import ROCODataset from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo +from research.common.datasets.union_dataset import UnionDataset @dataclass @@ -66,5 +66,7 @@ def make_markdown_dataset_report(dataset: ROCODataset, report_dir: Path): if __name__ == "__main__": - for dset in chain(*ROCODatasetsZoo()): - make_markdown_dataset_report(dset, dset.dataset_path) + for datasets in ROCODatasetsZoo(): + make_markdown_dataset_report(UnionDataset(datasets, datasets.name), datasets.directory) + for dset in datasets: + make_markdown_dataset_report(dset, dset.dataset_path) diff --git a/common/research/common/datasets/image_dataset.py b/common/research/common/datasets/image_dataset.py index a9531339b27908fbd248bf9cac9406f5d64932a8..70f8fb2715e2da72134167e71ad985d4f909336e 100644 --- a/common/research/common/datasets/image_dataset.py +++ b/common/research/common/datasets/image_dataset.py @@ -1,11 +1,11 @@ -from typing import Generic, Iterator, List, Tuple, TypeVar +from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar from polystar.common.models.image import Image TargetT = TypeVar("TargetT") -class ImageDataset(Generic[TargetT]): +class ImageDataset(Generic[TargetT], Iterable[Tuple[Image, TargetT]]): def __init__(self, name: str, images: List[Image] = None, targets: List[TargetT] = None): self.name = name self._targets = targets diff --git a/common/research/common/datasets/roco/roco_datasets.py b/common/research/common/datasets/roco/roco_datasets.py index fc0a7d522c1dab294f4a57f76c31940420afa2ed..55056b5b916d2bdf4317788997d21144c3e37368 100644 --- a/common/research/common/datasets/roco/roco_datasets.py +++ b/common/research/common/datasets/roco/roco_datasets.py @@ -1,25 +1,32 @@ -from typing import Any, Iterator, List, Tuple +from abc import abstractmethod +from pathlib import Path +from typing import Any, ClassVar, Iterable, Iterator, List, Tuple -from research.common.dataset.directory_roco_dataset import DirectoryROCODataset +from research.common.datasets.roco.directory_roco_dataset import \ + DirectoryROCODataset -class ROCODatasets: - def make_dataset(dataset_name: str, *args: Any) -> DirectoryROCODataset: +class ROCODatasets(Iterable[DirectoryROCODataset]): + name: ClassVar[str] + datasets: ClassVar[List[DirectoryROCODataset]] + directory: ClassVar[Path] + + @classmethod + @abstractmethod + def make_dataset(cls, dataset_name: str, *args: Any) -> DirectoryROCODataset: pass def __init_subclass__(cls, **kwargs): cls.datasets: List[DirectoryROCODataset] = [] for dataset_name, args in cls.__dict__.items(): - if ( - not callable(args) - and not dataset_name.startswith("_") - and dataset_name not in ("make_dataset", "datasets") - ): + if not dataset_name.islower(): if not isinstance(args, Tuple): args = (args,) dataset = cls.make_dataset(dataset_name, *args) setattr(cls, dataset_name, dataset) cls.datasets.append(dataset) + cls.name = cls.__name__[: -len("ROCODatasets")] + def __iter__(self) -> Iterator[DirectoryROCODataset]: return self.datasets.__iter__() diff --git a/common/research/common/datasets/roco/zoo/dji.py b/common/research/common/datasets/roco/zoo/dji.py index 0b7ca9fe2c3b7791706b4c339b66fc9dccc84c19..221430a7511d1583e907ea92c23f3197e5efeb80 100644 --- a/common/research/common/datasets/roco/zoo/dji.py +++ b/common/research/common/datasets/roco/zoo/dji.py @@ -5,11 +5,13 @@ from research.common.datasets.roco.roco_datasets import ROCODatasets class DJIROCODatasets(ROCODatasets): + directory = DJI_ROCO_DSET_DIR + CentralChina = "robomaster_Central China Regional Competition" NorthChina = "robomaster_North China Regional Competition" SouthChina = "robomaster_South China Regional Competition" Final = "robomaster_Final Tournament" - @staticmethod - def make_dataset(dataset_name: str, competition_name: str) -> DirectoryROCODataset: - return DirectoryROCODataset(DJI_ROCO_DSET_DIR / competition_name, dataset_name) + @classmethod + def make_dataset(cls, dataset_name: str, competition_name: str) -> DirectoryROCODataset: + return DirectoryROCODataset(cls.directory / competition_name, dataset_name) diff --git a/common/research/common/datasets/roco/zoo/dji_zoomed.py b/common/research/common/datasets/roco/zoo/dji_zoomed.py index 39444e3a3726dddf3aadf05ba7a927cfe5cb1df3..5d91133216b51424e07a9ebfa28452932ed5cc61 100644 --- a/common/research/common/datasets/roco/zoo/dji_zoomed.py +++ b/common/research/common/datasets/roco/zoo/dji_zoomed.py @@ -1,3 +1,5 @@ +from typing import Any + from polystar.common.utils.str_utils import camel2snake from research.common.constants import DJI_ROCO_ZOOMED_DSET_DIR from research.common.datasets.roco.directory_roco_dataset import \ @@ -6,11 +8,13 @@ from research.common.datasets.roco.roco_datasets import ROCODatasets class DJIROCOZoomedDatasets(ROCODatasets): + directory = DJI_ROCO_ZOOMED_DSET_DIR + CentralChina = () NorthChina = () SouthChina = () Final = () - @staticmethod - def make_dataset(dataset_name: str) -> DirectoryROCODataset: - return DirectoryROCODataset(DJI_ROCO_ZOOMED_DSET_DIR / camel2snake(dataset_name), f"{dataset_name}ZoomedV2") + @classmethod + def make_dataset(cls, dataset_name: str, *args: Any) -> DirectoryROCODataset: + return DirectoryROCODataset(cls.directory / camel2snake(dataset_name), f"{dataset_name}ZoomedV2") diff --git a/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py b/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py index cfeab521b01bdc016dcfa740b4bbe8313cb9d7a4..3de4a2822130553875c4c5e97b4c471aaebe177c 100644 --- a/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py +++ b/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py @@ -1,9 +1,12 @@ +from typing import Iterable + +from research.common.datasets.roco.roco_datasets import ROCODatasets from research.common.datasets.roco.zoo.dji import DJIROCODatasets from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets -class ROCODatasetsZoo: +class ROCODatasetsZoo(Iterable[ROCODatasets]): DJI_ZOOMED = DJIROCOZoomedDatasets() DJI = DJIROCODatasets() TWITCH = TwitchROCODatasets() diff --git a/common/research/common/datasets/roco/zoo/twitch.py b/common/research/common/datasets/roco/zoo/twitch.py index 9d7a0af4f8fa118b2c8a1fa68d170e656eeb2b85..68ca46ac3e8254fe5a3b3e86a332e019f3a7dfa8 100644 --- a/common/research/common/datasets/roco/zoo/twitch.py +++ b/common/research/common/datasets/roco/zoo/twitch.py @@ -1,3 +1,5 @@ +from typing import Any + from research.common.constants import TWITCH_DSET_DIR from research.common.datasets.roco.directory_roco_dataset import \ DirectoryROCODataset @@ -5,6 +7,8 @@ from research.common.datasets.roco.roco_datasets import ROCODatasets class TwitchROCODatasets(ROCODatasets): + directory = TWITCH_DSET_DIR / "v1" + TWITCH_470149568 = () TWITCH_470150052 = () TWITCH_470151286 = () @@ -14,7 +18,7 @@ class TwitchROCODatasets(ROCODatasets): TWITCH_470153081 = () TWITCH_470158483 = () - @staticmethod - def make_dataset(dataset_name: str) -> DirectoryROCODataset: + @classmethod + def make_dataset(cls, dataset_name: str, *args: Any) -> DirectoryROCODataset: twitch_id = dataset_name[len("TWITCH_") :] - return DirectoryROCODataset(TWITCH_DSET_DIR / "v1" / twitch_id, f"T{twitch_id}") + return DirectoryROCODataset(cls.directory / twitch_id, f"T{twitch_id}") diff --git a/common/research/common/datasets/union_dataset.py b/common/research/common/datasets/union_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..623887862ae88553bacb24aa47ef428100b65c4f --- /dev/null +++ b/common/research/common/datasets/union_dataset.py @@ -0,0 +1,19 @@ +from itertools import chain +from typing import Iterable, Iterator, List, Tuple + +from polystar.common.models.image import Image +from research.common.datasets.image_dataset import ImageDataset, TargetT + + +class UnionDataset(ImageDataset[TargetT]): + def __init__(self, datasets: Iterable[ImageDataset[TargetT]], name: str): + super().__init__(name) + self.datasets = list(datasets) + + def __iter__(self) -> Iterator[Tuple[Image, TargetT]]: + for dataset in self.datasets: + yield from dataset + + @property + def targets(self) -> List[TargetT]: + return list(chain(*(dataset.targets for dataset in self.datasets)))