diff --git a/src/research/common/datasets/roco/roco_datasets.py b/src/research/common/datasets/roco/roco_datasets.py index b7765e94746ea8fadd1a20ada6c71d72c1e0fb2d..dbecdcdec9add9f5939986cdcc3e4d8955f19e1d 100644 --- a/src/research/common/datasets/roco/roco_datasets.py +++ b/src/research/common/datasets/roco/roco_datasets.py @@ -7,7 +7,6 @@ from research.common.datasets.roco.roco_dataset_builder import ROCODatasetBuilde from research.common.datasets.union_dataset import UnionLazyDataset -# FIXME : we should be able to access a builder 2 times class ROCODatasetsMeta(type): def __init__(cls, name: str, bases, dct): super().__init__(name, bases, dct) diff --git a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py index 65a982acf125e98c6726df18f28eb8fb285b7788..3ceab94fe181e843f8d6b22f07ed21a89b7dbffb 100644 --- a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py +++ b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py @@ -6,31 +6,50 @@ from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets +# FIXME: find a better way to do that (builder need to be instantiated once per call) class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): DJI_ZOOMED = DJIROCOZoomedDatasets DJI = DJIROCODatasets TWITCH = TwitchROCODatasets - TWITCH_TRAIN_DATASETS = [ - TWITCH.T470149066, - TWITCH.T470150052, - TWITCH.T470152289, - TWITCH.T470153081, - TWITCH.T470158483, - ] - TWITCH_VALIDATION_DATASETS = [TWITCH.T470152932, TWITCH.T470149568] - TWITCH_TEST_DATASETS = [TWITCH.T470152838, TWITCH.T470151286] - - DJI_TRAIN_DATASETS = [DJI.FINAL, DJI.CENTRAL_CHINA, DJI.NORTH_CHINA, DJI.SOUTH_CHINA] - DJI_ZOOMED_TRAIN_DATASETS = [ - DJI_ZOOMED.FINAL, - DJI_ZOOMED.CENTRAL_CHINA, - DJI_ZOOMED.NORTH_CHINA, - DJI_ZOOMED.SOUTH_CHINA, - ] - - TWITCH_DJI_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS - TWITCH_DJI_ZOOMED_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS + @property + def TWITCH_TRAIN_DATASETS(self): + return [ + self.TWITCH.T470149066, + self.TWITCH.T470150052, + self.TWITCH.T470152289, + self.TWITCH.T470153081, + self.TWITCH.T470158483, + ] + + @property + def TWITCH_VALIDATION_DATASETS(self): + return [self.TWITCH.T470152932, self.TWITCH.T470149568] + + @property + def TWITCH_TEST_DATASETS(self): + return [self.TWITCH.T470152838, self.TWITCH.T470151286] + + @property + def DJI_TRAIN_DATASETS(self): + return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA] + + @property + def DJI_ZOOMED_TRAIN_DATASETS(self): + return [ + self.DJI_ZOOMED.FINAL, + self.DJI_ZOOMED.CENTRAL_CHINA, + self.DJI_ZOOMED.NORTH_CHINA, + self.DJI_ZOOMED.SOUTH_CHINA, + ] + + @property + def TWITCH_DJI_TRAIN_DATASETS(self): + return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS + + @property + def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self): + return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS DEFAULT_VALIDATION_DATASETS = TWITCH_VALIDATION_DATASETS @@ -41,3 +60,8 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): ROCODatasetsZoo = ROCODatasetsZoo() + + +if __name__ == "__main__": + ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy() + ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy()