From ac2fcbce59e55b83ba71f65804bfc22880e9d074 Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Mon, 22 Mar 2021 19:50:11 -0400 Subject: [PATCH] [ROCO ZOO] allow to access a zoo list multiple times --- .../common/datasets/roco/roco_datasets.py | 1 - .../datasets/roco/zoo/roco_dataset_zoo.py | 64 +++++++++++++------ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/src/research/common/datasets/roco/roco_datasets.py b/src/research/common/datasets/roco/roco_datasets.py index b7765e9..dbecdcd 100644 --- a/src/research/common/datasets/roco/roco_datasets.py +++ b/src/research/common/datasets/roco/roco_datasets.py @@ -7,7 +7,6 @@ from research.common.datasets.roco.roco_dataset_builder import ROCODatasetBuilde from research.common.datasets.union_dataset import UnionLazyDataset -# FIXME : we should be able to access a builder 2 times class ROCODatasetsMeta(type): def __init__(cls, name: str, bases, dct): super().__init__(name, bases, dct) diff --git a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py index 65a982a..3ceab94 100644 --- a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py +++ b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py @@ -6,31 +6,50 @@ from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets +# FIXME: find a better way to do that (builder need to be instantiated once per call) class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): DJI_ZOOMED = DJIROCOZoomedDatasets DJI = DJIROCODatasets TWITCH = TwitchROCODatasets - TWITCH_TRAIN_DATASETS = [ - TWITCH.T470149066, - TWITCH.T470150052, - TWITCH.T470152289, - TWITCH.T470153081, - TWITCH.T470158483, - ] - TWITCH_VALIDATION_DATASETS = [TWITCH.T470152932, TWITCH.T470149568] - TWITCH_TEST_DATASETS = [TWITCH.T470152838, TWITCH.T470151286] - - DJI_TRAIN_DATASETS = [DJI.FINAL, DJI.CENTRAL_CHINA, DJI.NORTH_CHINA, DJI.SOUTH_CHINA] - DJI_ZOOMED_TRAIN_DATASETS = [ - DJI_ZOOMED.FINAL, - DJI_ZOOMED.CENTRAL_CHINA, - DJI_ZOOMED.NORTH_CHINA, - DJI_ZOOMED.SOUTH_CHINA, - ] - - TWITCH_DJI_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS - TWITCH_DJI_ZOOMED_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS + @property + def TWITCH_TRAIN_DATASETS(self): + return [ + self.TWITCH.T470149066, + self.TWITCH.T470150052, + self.TWITCH.T470152289, + self.TWITCH.T470153081, + self.TWITCH.T470158483, + ] + + @property + def TWITCH_VALIDATION_DATASETS(self): + return [self.TWITCH.T470152932, self.TWITCH.T470149568] + + @property + def TWITCH_TEST_DATASETS(self): + return [self.TWITCH.T470152838, self.TWITCH.T470151286] + + @property + def DJI_TRAIN_DATASETS(self): + return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA] + + @property + def DJI_ZOOMED_TRAIN_DATASETS(self): + return [ + self.DJI_ZOOMED.FINAL, + self.DJI_ZOOMED.CENTRAL_CHINA, + self.DJI_ZOOMED.NORTH_CHINA, + self.DJI_ZOOMED.SOUTH_CHINA, + ] + + @property + def TWITCH_DJI_TRAIN_DATASETS(self): + return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS + + @property + def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self): + return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS DEFAULT_VALIDATION_DATASETS = TWITCH_VALIDATION_DATASETS @@ -41,3 +60,8 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): ROCODatasetsZoo = ROCODatasetsZoo() + + +if __name__ == "__main__": + ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy() + ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy() -- GitLab