From bce85afffd677b9da5f4921e8865c0c0c941bf61 Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Mon, 22 Mar 2021 19:53:58 -0400 Subject: [PATCH] [ROCO ZOO] Fix property typing --- .../datasets/roco/zoo/roco_dataset_zoo.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py index 3ceab94..46bb835 100644 --- a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py +++ b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py @@ -1,4 +1,4 @@ -from typing import Iterable, Iterator, Type +from typing import Iterable, Iterator, List, Type from research.common.datasets.roco.roco_datasets import ROCODatasets from research.common.datasets.roco.zoo.dji import DJIROCODatasets @@ -7,13 +7,14 @@ from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets # FIXME: find a better way to do that (builder need to be instantiated once per call) -class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): +# FIXME: improve the singleton pattern here +class ROCODatasetsZooClass(Iterable[Type[ROCODatasets]]): DJI_ZOOMED = DJIROCOZoomedDatasets DJI = DJIROCODatasets TWITCH = TwitchROCODatasets @property - def TWITCH_TRAIN_DATASETS(self): + def TWITCH_TRAIN_DATASETS(self) -> List[ROCODatasets]: return [ self.TWITCH.T470149066, self.TWITCH.T470150052, @@ -23,19 +24,19 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): ] @property - def TWITCH_VALIDATION_DATASETS(self): + def TWITCH_VALIDATION_DATASETS(self) -> List[ROCODatasets]: return [self.TWITCH.T470152932, self.TWITCH.T470149568] @property - def TWITCH_TEST_DATASETS(self): + def TWITCH_TEST_DATASETS(self) -> List[ROCODatasets]: return [self.TWITCH.T470152838, self.TWITCH.T470151286] @property - def DJI_TRAIN_DATASETS(self): + def DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]: return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA] @property - def DJI_ZOOMED_TRAIN_DATASETS(self): + def DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]: return [ self.DJI_ZOOMED.FINAL, self.DJI_ZOOMED.CENTRAL_CHINA, @@ -44,11 +45,11 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): ] @property - def TWITCH_DJI_TRAIN_DATASETS(self): + def TWITCH_DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]: return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS @property - def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self): + def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]: return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS @@ -59,7 +60,7 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): return iter((self.DJI, self.DJI_ZOOMED, self.TWITCH)) -ROCODatasetsZoo = ROCODatasetsZoo() +ROCODatasetsZoo = ROCODatasetsZooClass() if __name__ == "__main__": -- GitLab