diff --git a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py index 3ceab94fe181e843f8d6b22f07ed21a89b7dbffb..46bb835df3e967413951dcc7815bd51d1fa37bbe 100644 --- a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py +++ b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py @@ -1,4 +1,4 @@ -from typing import Iterable, Iterator, Type +from typing import Iterable, Iterator, List, Type from research.common.datasets.roco.roco_datasets import ROCODatasets from research.common.datasets.roco.zoo.dji import DJIROCODatasets @@ -7,13 +7,14 @@ from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets # FIXME: find a better way to do that (builder need to be instantiated once per call) -class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): +# FIXME: improve the singleton pattern here +class ROCODatasetsZooClass(Iterable[Type[ROCODatasets]]): DJI_ZOOMED = DJIROCOZoomedDatasets DJI = DJIROCODatasets TWITCH = TwitchROCODatasets @property - def TWITCH_TRAIN_DATASETS(self): + def TWITCH_TRAIN_DATASETS(self) -> List[ROCODatasets]: return [ self.TWITCH.T470149066, self.TWITCH.T470150052, @@ -23,19 +24,19 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): ] @property - def TWITCH_VALIDATION_DATASETS(self): + def TWITCH_VALIDATION_DATASETS(self) -> List[ROCODatasets]: return [self.TWITCH.T470152932, self.TWITCH.T470149568] @property - def TWITCH_TEST_DATASETS(self): + def TWITCH_TEST_DATASETS(self) -> List[ROCODatasets]: return [self.TWITCH.T470152838, self.TWITCH.T470151286] @property - def DJI_TRAIN_DATASETS(self): + def DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]: return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA] @property - def DJI_ZOOMED_TRAIN_DATASETS(self): + def DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]: return [ self.DJI_ZOOMED.FINAL, self.DJI_ZOOMED.CENTRAL_CHINA, @@ -44,11 +45,11 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): ] @property - def TWITCH_DJI_TRAIN_DATASETS(self): + def TWITCH_DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]: return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS @property - def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self): + def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]: return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS @@ -59,7 +60,7 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): return iter((self.DJI, self.DJI_ZOOMED, self.TWITCH)) -ROCODatasetsZoo = ROCODatasetsZoo() +ROCODatasetsZoo = ROCODatasetsZooClass() if __name__ == "__main__":