Skip to content
Snippets Groups Projects
Commit ac2fcbce authored by Mathieu Beligon's avatar Mathieu Beligon
Browse files

[ROCO ZOO] allow to access a zoo list multiple times

parent 6b7f7909
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,6 @@ from research.common.datasets.roco.roco_dataset_builder import ROCODatasetBuilde ...@@ -7,7 +7,6 @@ from research.common.datasets.roco.roco_dataset_builder import ROCODatasetBuilde
from research.common.datasets.union_dataset import UnionLazyDataset from research.common.datasets.union_dataset import UnionLazyDataset
# FIXME : we should be able to access a builder 2 times
class ROCODatasetsMeta(type): class ROCODatasetsMeta(type):
def __init__(cls, name: str, bases, dct): def __init__(cls, name: str, bases, dct):
super().__init__(name, bases, dct) super().__init__(name, bases, dct)
......
...@@ -6,31 +6,50 @@ from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets ...@@ -6,31 +6,50 @@ from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
# FIXME: find a better way to do that (builder need to be instantiated once per call)
class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
DJI_ZOOMED = DJIROCOZoomedDatasets DJI_ZOOMED = DJIROCOZoomedDatasets
DJI = DJIROCODatasets DJI = DJIROCODatasets
TWITCH = TwitchROCODatasets TWITCH = TwitchROCODatasets
TWITCH_TRAIN_DATASETS = [ @property
TWITCH.T470149066, def TWITCH_TRAIN_DATASETS(self):
TWITCH.T470150052, return [
TWITCH.T470152289, self.TWITCH.T470149066,
TWITCH.T470153081, self.TWITCH.T470150052,
TWITCH.T470158483, self.TWITCH.T470152289,
] self.TWITCH.T470153081,
TWITCH_VALIDATION_DATASETS = [TWITCH.T470152932, TWITCH.T470149568] self.TWITCH.T470158483,
TWITCH_TEST_DATASETS = [TWITCH.T470152838, TWITCH.T470151286] ]
DJI_TRAIN_DATASETS = [DJI.FINAL, DJI.CENTRAL_CHINA, DJI.NORTH_CHINA, DJI.SOUTH_CHINA] @property
DJI_ZOOMED_TRAIN_DATASETS = [ def TWITCH_VALIDATION_DATASETS(self):
DJI_ZOOMED.FINAL, return [self.TWITCH.T470152932, self.TWITCH.T470149568]
DJI_ZOOMED.CENTRAL_CHINA,
DJI_ZOOMED.NORTH_CHINA, @property
DJI_ZOOMED.SOUTH_CHINA, def TWITCH_TEST_DATASETS(self):
] return [self.TWITCH.T470152838, self.TWITCH.T470151286]
TWITCH_DJI_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS @property
TWITCH_DJI_ZOOMED_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS def DJI_TRAIN_DATASETS(self):
return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA]
@property
def DJI_ZOOMED_TRAIN_DATASETS(self):
return [
self.DJI_ZOOMED.FINAL,
self.DJI_ZOOMED.CENTRAL_CHINA,
self.DJI_ZOOMED.NORTH_CHINA,
self.DJI_ZOOMED.SOUTH_CHINA,
]
@property
def TWITCH_DJI_TRAIN_DATASETS(self):
return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
@property
def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self):
return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS
DEFAULT_VALIDATION_DATASETS = TWITCH_VALIDATION_DATASETS DEFAULT_VALIDATION_DATASETS = TWITCH_VALIDATION_DATASETS
...@@ -41,3 +60,8 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]): ...@@ -41,3 +60,8 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
ROCODatasetsZoo = ROCODatasetsZoo() ROCODatasetsZoo = ROCODatasetsZoo()
if __name__ == "__main__":
ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy()
ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment