Skip to content
Snippets Groups Projects
Commit bce85aff authored by Mathieu Beligon's avatar Mathieu Beligon
Browse files

[ROCO ZOO] Fix property typing

parent ac2fcbce
No related branches found
No related tags found
No related merge requests found
from typing import Iterable, Iterator, Type
from typing import Iterable, Iterator, List, Type
from research.common.datasets.roco.roco_datasets import ROCODatasets
from research.common.datasets.roco.zoo.dji import DJIROCODatasets
......@@ -7,13 +7,14 @@ from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
# FIXME: find a better way to do that (builder need to be instantiated once per call)
class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
# FIXME: improve the singleton pattern here
class ROCODatasetsZooClass(Iterable[Type[ROCODatasets]]):
DJI_ZOOMED = DJIROCOZoomedDatasets
DJI = DJIROCODatasets
TWITCH = TwitchROCODatasets
@property
def TWITCH_TRAIN_DATASETS(self):
def TWITCH_TRAIN_DATASETS(self) -> List[ROCODatasets]:
return [
self.TWITCH.T470149066,
self.TWITCH.T470150052,
......@@ -23,19 +24,19 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
]
@property
def TWITCH_VALIDATION_DATASETS(self):
def TWITCH_VALIDATION_DATASETS(self) -> List[ROCODatasets]:
return [self.TWITCH.T470152932, self.TWITCH.T470149568]
@property
def TWITCH_TEST_DATASETS(self):
def TWITCH_TEST_DATASETS(self) -> List[ROCODatasets]:
return [self.TWITCH.T470152838, self.TWITCH.T470151286]
@property
def DJI_TRAIN_DATASETS(self):
def DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]:
return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA]
@property
def DJI_ZOOMED_TRAIN_DATASETS(self):
def DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]:
return [
self.DJI_ZOOMED.FINAL,
self.DJI_ZOOMED.CENTRAL_CHINA,
......@@ -44,11 +45,11 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
]
@property
def TWITCH_DJI_TRAIN_DATASETS(self):
def TWITCH_DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]:
return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
@property
def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self):
def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]:
return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS
......@@ -59,7 +60,7 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
return iter((self.DJI, self.DJI_ZOOMED, self.TWITCH))
ROCODatasetsZoo = ROCODatasetsZoo()
ROCODatasetsZoo = ROCODatasetsZooClass()
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment