From ac2fcbce59e55b83ba71f65804bfc22880e9d074 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Mon, 22 Mar 2021 19:50:11 -0400
Subject: [PATCH] [ROCO ZOO] allow to access a zoo list multiple times

---
 .../common/datasets/roco/roco_datasets.py     |  1 -
 .../datasets/roco/zoo/roco_dataset_zoo.py     | 64 +++++++++++++------
 2 files changed, 44 insertions(+), 21 deletions(-)

diff --git a/src/research/common/datasets/roco/roco_datasets.py b/src/research/common/datasets/roco/roco_datasets.py
index b7765e9..dbecdcd 100644
--- a/src/research/common/datasets/roco/roco_datasets.py
+++ b/src/research/common/datasets/roco/roco_datasets.py
@@ -7,7 +7,6 @@ from research.common.datasets.roco.roco_dataset_builder import ROCODatasetBuilde
 from research.common.datasets.union_dataset import UnionLazyDataset
 
 
-# FIXME : we should be able to access a builder 2 times
 class ROCODatasetsMeta(type):
     def __init__(cls, name: str, bases, dct):
         super().__init__(name, bases, dct)
diff --git a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py
index 65a982a..3ceab94 100644
--- a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py
+++ b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py
@@ -6,31 +6,50 @@ from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
 from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
 
 
+# FIXME: find a better way to do that (builder need to be instantiated once per call)
 class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
     DJI_ZOOMED = DJIROCOZoomedDatasets
     DJI = DJIROCODatasets
     TWITCH = TwitchROCODatasets
 
-    TWITCH_TRAIN_DATASETS = [
-        TWITCH.T470149066,
-        TWITCH.T470150052,
-        TWITCH.T470152289,
-        TWITCH.T470153081,
-        TWITCH.T470158483,
-    ]
-    TWITCH_VALIDATION_DATASETS = [TWITCH.T470152932, TWITCH.T470149568]
-    TWITCH_TEST_DATASETS = [TWITCH.T470152838, TWITCH.T470151286]
-
-    DJI_TRAIN_DATASETS = [DJI.FINAL, DJI.CENTRAL_CHINA, DJI.NORTH_CHINA, DJI.SOUTH_CHINA]
-    DJI_ZOOMED_TRAIN_DATASETS = [
-        DJI_ZOOMED.FINAL,
-        DJI_ZOOMED.CENTRAL_CHINA,
-        DJI_ZOOMED.NORTH_CHINA,
-        DJI_ZOOMED.SOUTH_CHINA,
-    ]
-
-    TWITCH_DJI_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS
-    TWITCH_DJI_ZOOMED_TRAIN_DATASETS = TWITCH_TRAIN_DATASETS + DJI_TRAIN_DATASETS
+    @property
+    def TWITCH_TRAIN_DATASETS(self):
+        return [
+            self.TWITCH.T470149066,
+            self.TWITCH.T470150052,
+            self.TWITCH.T470152289,
+            self.TWITCH.T470153081,
+            self.TWITCH.T470158483,
+        ]
+
+    @property
+    def TWITCH_VALIDATION_DATASETS(self):
+        return [self.TWITCH.T470152932, self.TWITCH.T470149568]
+
+    @property
+    def TWITCH_TEST_DATASETS(self):
+        return [self.TWITCH.T470152838, self.TWITCH.T470151286]
+
+    @property
+    def DJI_TRAIN_DATASETS(self):
+        return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA]
+
+    @property
+    def DJI_ZOOMED_TRAIN_DATASETS(self):
+        return [
+            self.DJI_ZOOMED.FINAL,
+            self.DJI_ZOOMED.CENTRAL_CHINA,
+            self.DJI_ZOOMED.NORTH_CHINA,
+            self.DJI_ZOOMED.SOUTH_CHINA,
+        ]
+
+    @property
+    def TWITCH_DJI_TRAIN_DATASETS(self):
+        return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
+
+    @property
+    def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self):
+        return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
 
     DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS
     DEFAULT_VALIDATION_DATASETS = TWITCH_VALIDATION_DATASETS
@@ -41,3 +60,8 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
 
 
 ROCODatasetsZoo = ROCODatasetsZoo()
+
+
+if __name__ == "__main__":
+    ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy()
+    ROCODatasetsZoo.DEFAULT_TEST_DATASETS[0].build_lazy()
-- 
GitLab