From bce85afffd677b9da5f4921e8865c0c0c941bf61 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Mon, 22 Mar 2021 19:53:58 -0400
Subject: [PATCH] [ROCO ZOO] Fix property typing

---
 .../datasets/roco/zoo/roco_dataset_zoo.py     | 21 ++++++++++---------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py
index 3ceab94..46bb835 100644
--- a/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py
+++ b/src/research/common/datasets/roco/zoo/roco_dataset_zoo.py
@@ -1,4 +1,4 @@
-from typing import Iterable, Iterator, Type
+from typing import Iterable, Iterator, List, Type
 
 from research.common.datasets.roco.roco_datasets import ROCODatasets
 from research.common.datasets.roco.zoo.dji import DJIROCODatasets
@@ -7,13 +7,14 @@ from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
 
 
 # FIXME: find a better way to do that (builder need to be instantiated once per call)
-class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
+# FIXME: improve the singleton pattern here
+class ROCODatasetsZooClass(Iterable[Type[ROCODatasets]]):
     DJI_ZOOMED = DJIROCOZoomedDatasets
     DJI = DJIROCODatasets
     TWITCH = TwitchROCODatasets
 
     @property
-    def TWITCH_TRAIN_DATASETS(self):
+    def TWITCH_TRAIN_DATASETS(self) -> List[ROCODatasets]:
         return [
             self.TWITCH.T470149066,
             self.TWITCH.T470150052,
@@ -23,19 +24,19 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
         ]
 
     @property
-    def TWITCH_VALIDATION_DATASETS(self):
+    def TWITCH_VALIDATION_DATASETS(self) -> List[ROCODatasets]:
         return [self.TWITCH.T470152932, self.TWITCH.T470149568]
 
     @property
-    def TWITCH_TEST_DATASETS(self):
+    def TWITCH_TEST_DATASETS(self) -> List[ROCODatasets]:
         return [self.TWITCH.T470152838, self.TWITCH.T470151286]
 
     @property
-    def DJI_TRAIN_DATASETS(self):
+    def DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]:
         return [self.DJI.FINAL, self.DJI.CENTRAL_CHINA, self.DJI.NORTH_CHINA, self.DJI.SOUTH_CHINA]
 
     @property
-    def DJI_ZOOMED_TRAIN_DATASETS(self):
+    def DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]:
         return [
             self.DJI_ZOOMED.FINAL,
             self.DJI_ZOOMED.CENTRAL_CHINA,
@@ -44,11 +45,11 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
         ]
 
     @property
-    def TWITCH_DJI_TRAIN_DATASETS(self):
+    def TWITCH_DJI_TRAIN_DATASETS(self) -> List[ROCODatasets]:
         return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
 
     @property
-    def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self):
+    def TWITCH_DJI_ZOOMED_TRAIN_DATASETS(self) -> List[ROCODatasets]:
         return self.TWITCH_TRAIN_DATASETS + self.DJI_TRAIN_DATASETS
 
     DEFAULT_TEST_DATASETS = TWITCH_TEST_DATASETS
@@ -59,7 +60,7 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
         return iter((self.DJI, self.DJI_ZOOMED, self.TWITCH))
 
 
-ROCODatasetsZoo = ROCODatasetsZoo()
+ROCODatasetsZoo = ROCODatasetsZooClass()
 
 
 if __name__ == "__main__":
-- 
GitLab