From 92d5ba7cc6b68fc2249e12d9388c09b5ae38f27e Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Thu, 24 Sep 2020 19:42:46 +0200
Subject: [PATCH] [common] (datasets) refactor: adapt script to new format

---
 .../common/datasets_v3/dataset_builder.py     |  7 ++-
 .../common/datasets_v3/lazy_dataset.py        |  5 +++
 .../roco/roco_annotation.py                   |  0
 .../common/datasets_v3/roco/roco_dataset.py   |  2 +-
 .../datasets_v3/roco/roco_dataset_builder.py  |  2 +-
 .../roco/roco_dataset_descriptor.py           |  4 +-
 .../common/datasets_v3/roco/roco_datasets.py  |  9 ++--
 .../datasets_v3/roco/zoo/roco_dataset_zoo.py  |  4 +-
 ...t_twith_datasets_from_manual_annotation.py | 33 ++++++--------
 .../scripts/create_tensorflow_records.py      | 21 +++++----
 .../common/scripts/improve_roco_by_zooming.py | 44 ++++++++++++-------
 .../common/scripts/visualize_dataset.py       |  8 ++--
 12 files changed, 78 insertions(+), 61 deletions(-)
 rename common/research/common/{datasets => datasets_v3}/roco/roco_annotation.py (100%)

diff --git a/common/research/common/datasets_v3/dataset_builder.py b/common/research/common/datasets_v3/dataset_builder.py
index fdc890f..2bb1665 100644
--- a/common/research/common/datasets_v3/dataset_builder.py
+++ b/common/research/common/datasets_v3/dataset_builder.py
@@ -1,4 +1,4 @@
-from typing import Callable, Generic, Iterable
+from typing import Callable, Generic, Iterable, Iterator, Tuple
 
 from polystar.common.filters.filter_abc import FilterABC
 from polystar.common.filters.pass_through_filter import PassThroughFilter
@@ -9,7 +9,7 @@ from research.common.datasets_v3.lazy_dataset import ExampleT, LazyDataset, Targ
 from research.common.datasets_v3.transform_dataset import TransformDataset
 
 
-class DatasetBuilder(Generic[ExampleT, TargetT]):
+class DatasetBuilder(Generic[ExampleT, TargetT], Iterable[Tuple[ExampleT, TargetT, str]]):
     def __init__(self, dataset: LazyDataset[ExampleT, TargetT]):
         self.dataset = dataset
         self._built = False
@@ -19,6 +19,9 @@ class DatasetBuilder(Generic[ExampleT, TargetT]):
         self._built = True
         return self.dataset
 
+    def __iter__(self) -> Iterator[Tuple[ExampleT, TargetT, str]]:
+        return iter(self.build_lazy())
+
     def build(self) -> Dataset[ExampleT, TargetT]:
         assert not self._built
         self._built = True
diff --git a/common/research/common/datasets_v3/lazy_dataset.py b/common/research/common/datasets_v3/lazy_dataset.py
index c240969..4449ba7 100644
--- a/common/research/common/datasets_v3/lazy_dataset.py
+++ b/common/research/common/datasets_v3/lazy_dataset.py
@@ -15,3 +15,8 @@ class LazyDataset(Generic[ExampleT, TargetT], Iterable[Tuple[ExampleT, TargetT,
 
     def __len__(self):
         raise NotImplemented()
+
+    def __str__(self):
+        return f"dataset {self.name}"
+
+    __repr__ = __str__
diff --git a/common/research/common/datasets/roco/roco_annotation.py b/common/research/common/datasets_v3/roco/roco_annotation.py
similarity index 100%
rename from common/research/common/datasets/roco/roco_annotation.py
rename to common/research/common/datasets_v3/roco/roco_annotation.py
diff --git a/common/research/common/datasets_v3/roco/roco_dataset.py b/common/research/common/datasets_v3/roco/roco_dataset.py
index e365294..646c6c5 100644
--- a/common/research/common/datasets_v3/roco/roco_dataset.py
+++ b/common/research/common/datasets_v3/roco/roco_dataset.py
@@ -1,9 +1,9 @@
 from pathlib import Path
 
 from polystar.common.models.image import Image
-from research.common.datasets.roco.roco_annotation import ROCOAnnotation
 from research.common.datasets_v3.dataset import Dataset
 from research.common.datasets_v3.lazy_dataset import LazyDataset
+from research.common.datasets_v3.roco.roco_annotation import ROCOAnnotation
 
 LazyROCOFileDataset = LazyDataset[Path, ROCOAnnotation]
 ROCOFileDataset = Dataset[Path, ROCOAnnotation]
diff --git a/common/research/common/datasets_v3/roco/roco_dataset_builder.py b/common/research/common/datasets_v3/roco/roco_dataset_builder.py
index ce09a0d..73ebc12 100644
--- a/common/research/common/datasets_v3/roco/roco_dataset_builder.py
+++ b/common/research/common/datasets_v3/roco/roco_dataset_builder.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 
-from research.common.datasets.roco.roco_annotation import ROCOAnnotation
 from research.common.datasets_v3.image_file_dataset_builder import DirectoryDatasetBuilder
+from research.common.datasets_v3.roco.roco_annotation import ROCOAnnotation
 
 
 class ROCODatasetBuilder(DirectoryDatasetBuilder):
diff --git a/common/research/common/datasets_v3/roco/roco_dataset_descriptor.py b/common/research/common/datasets_v3/roco/roco_dataset_descriptor.py
index f72ff11..0904aad 100644
--- a/common/research/common/datasets_v3/roco/roco_dataset_descriptor.py
+++ b/common/research/common/datasets_v3/roco/roco_dataset_descriptor.py
@@ -3,10 +3,10 @@ from pathlib import Path
 from typing import Dict
 
 from pandas import DataFrame
-from tqdm import tqdm
 
 from polystar.common.models.object import Armor, ObjectType
 from polystar.common.utils.markdown import MarkdownFile
+from polystar.common.utils.tqdm import smart_tqdm
 from research.common.datasets_v3.roco.roco_dataset import LazyROCOFileDataset
 from research.common.datasets_v3.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo
 
@@ -29,7 +29,7 @@ class ROCODatasetStats:
         rv.armors_color2num2count = {c: {n: 0 for n in range(10)} for c in colors}
         for c in colors:
             rv.armors_color2num2count[c]["total"] = 0
-        for (_, annotation, _) in tqdm(dataset, desc=dataset.name, unit="frame"):
+        for (_, annotation, _) in smart_tqdm(dataset, desc=dataset.name, unit="frame"):
             rv.n_images += 1
             rv.n_runes += annotation.has_rune
             for obj in annotation.objects:
diff --git a/common/research/common/datasets_v3/roco/roco_datasets.py b/common/research/common/datasets_v3/roco/roco_datasets.py
index 7c7767d..d64b454 100644
--- a/common/research/common/datasets_v3/roco/roco_datasets.py
+++ b/common/research/common/datasets_v3/roco/roco_datasets.py
@@ -1,10 +1,10 @@
 from abc import abstractmethod
-from enum import Enum
+from enum import Enum, EnumMeta
 from pathlib import Path
 from typing import Iterator
 
 from polystar.common.utils.str_utils import snake2camel
-from research.common.datasets.roco.roco_annotation import ROCOAnnotation
+from research.common.datasets_v3.roco.roco_annotation import ROCOAnnotation
 from research.common.datasets_v3.roco.roco_dataset import (
     LazyROCODataset,
     LazyROCOFileDataset,
@@ -45,8 +45,9 @@ class ROCODatasets(Enum):
     def datasets_dir(cls) -> Path:  # Fixme: in python 37, we can define a class var using the _ignore_ attribute
         pass
 
-    def __iter__(self) -> Iterator["ROCODatasets"]:  # needed for pycharm typing, dont know why
-        return self.__iter__()
+    @classmethod
+    def __iter__(cls) -> Iterator["ROCODatasets"]:  # needed for pycharm typing, dont know why
+        return EnumMeta.__iter__(cls)
 
     @classmethod
     def union(cls) -> UnionLazyDataset[Path, ROCOAnnotation]:
diff --git a/common/research/common/datasets_v3/roco/zoo/roco_dataset_zoo.py b/common/research/common/datasets_v3/roco/zoo/roco_dataset_zoo.py
index 4e1a3ba..ba571b9 100644
--- a/common/research/common/datasets_v3/roco/zoo/roco_dataset_zoo.py
+++ b/common/research/common/datasets_v3/roco/zoo/roco_dataset_zoo.py
@@ -1,4 +1,4 @@
-from typing import Iterable, Type
+from typing import Iterable, Iterator, Type
 
 from research.common.datasets_v3.roco.roco_datasets import ROCODatasets
 from research.common.datasets_v3.roco.zoo.dji import DJIROCODatasets
@@ -11,7 +11,7 @@ class ROCODatasetsZoo(Iterable[Type[ROCODatasets]]):
     DJI = DJIROCODatasets
     TWITCH = TwitchROCODatasets
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Type[ROCODatasets]]:
         return iter((self.DJI, self.DJI_ZOOMED, self.TWITCH))
 
 
diff --git a/common/research/common/scripts/construct_twith_datasets_from_manual_annotation.py b/common/research/common/scripts/construct_twith_datasets_from_manual_annotation.py
index 459ae3d..0cbabc4 100644
--- a/common/research/common/scripts/construct_twith_datasets_from_manual_annotation.py
+++ b/common/research/common/scripts/construct_twith_datasets_from_manual_annotation.py
@@ -1,15 +1,10 @@
 from os import remove
 from shutil import copy, make_archive, move, rmtree
 
-from research.common.constants import (TWITCH_DSET_DIR,
-                                       TWITCH_DSET_ROBOTS_VIEWS_DIR,
-                                       TWITCH_ROBOTS_VIEWS_DIR)
-from research.common.datasets.roco.directory_roco_dataset import \
-    DirectoryROCODataset
-from research.common.datasets.roco.roco_dataset_descriptor import \
-    make_markdown_dataset_report
-from research.common.scripts.construct_dataset_from_manual_annotation import \
-    construct_dataset_from_manual_annotations
+from research.common.constants import TWITCH_DSET_DIR, TWITCH_DSET_ROBOTS_VIEWS_DIR, TWITCH_ROBOTS_VIEWS_DIR
+from research.common.datasets_v3.roco.roco_dataset_builder import ROCODatasetBuilder
+from research.common.datasets_v3.roco.roco_dataset_descriptor import make_markdown_dataset_report
+from research.common.scripts.construct_dataset_from_manual_annotation import construct_dataset_from_manual_annotations
 from research.common.scripts.correct_annotations import AnnotationFileCorrector
 
 
@@ -25,15 +20,15 @@ def _correct_manual_annotations():
 
 
 def _extract_runes_images():
-    all_twitch_dataset = _get_mixed_dataset()
+    all_twitch_dataset = _get_mixed_dataset_builder()
     for image_file, annotation, _ in all_twitch_dataset:
         if annotation.has_rune:
             copy(str(image_file), str(TWITCH_DSET_DIR / "runes" / image_file.name))
 
 
 def _separate_twitch_videos():
-    all_twitch_dataset = _get_mixed_dataset()
-    for image_file, annotation, _ in all_twitch_dataset:
+    all_twitch_dataset_builder = _get_mixed_dataset_builder()
+    for image_file, annotation, _ in all_twitch_dataset_builder:
         video_name = image_file.name.split("-")[0]
         dset_path = TWITCH_DSET_ROBOTS_VIEWS_DIR / video_name
         images_path = dset_path / "image"
@@ -42,7 +37,7 @@ def _separate_twitch_videos():
         annotations_path.mkdir(exist_ok=True, parents=True)
         move(str(image_file), str(images_path / image_file.name))
         xml_name = f"{image_file.stem}.xml"
-        move(str(all_twitch_dataset.annotations_dir / xml_name), str(annotations_path / xml_name))
+        move(str(all_twitch_dataset_builder.annotations_dir / xml_name), str(annotations_path / xml_name))
     if list((TWITCH_DSET_ROBOTS_VIEWS_DIR / "image").glob("*")):
         raise Exception(f"Some images remains unmoved")
     for remaining_file in (TWITCH_DSET_ROBOTS_VIEWS_DIR / "image_annotation").glob("*"):
@@ -53,19 +48,19 @@ def _separate_twitch_videos():
 
 
 def _make_global_report():
-    all_twitch_dataset = _get_mixed_dataset()
-    make_markdown_dataset_report(all_twitch_dataset, all_twitch_dataset.main_dir)
+    all_twitch_dataset_builder = _get_mixed_dataset_builder()
+    make_markdown_dataset_report(all_twitch_dataset_builder.build_lazy(), all_twitch_dataset_builder.main_dir)
 
 
-def _get_mixed_dataset() -> DirectoryROCODataset:
-    return DirectoryROCODataset(TWITCH_DSET_ROBOTS_VIEWS_DIR, "Twitch")
+def _get_mixed_dataset_builder() -> ROCODatasetBuilder:
+    return ROCODatasetBuilder(TWITCH_DSET_ROBOTS_VIEWS_DIR, "Twitch")
 
 
 def _make_separate_reports():
     for video_dset_path in TWITCH_DSET_ROBOTS_VIEWS_DIR.glob("*"):
         if video_dset_path.is_dir():
-            twitch_dset = DirectoryROCODataset(video_dset_path, f"TWITCH_{video_dset_path.name}")
-            make_markdown_dataset_report(twitch_dset, twitch_dset.main_dir)
+            twitch_dset = ROCODatasetBuilder(video_dset_path, f"TWITCH_{video_dset_path.name}")
+            make_markdown_dataset_report(twitch_dset.build_lazy(), twitch_dset.main_dir)
 
 
 if __name__ == "__main__":
diff --git a/common/research/common/scripts/create_tensorflow_records.py b/common/research/common/scripts/create_tensorflow_records.py
index 7563a0a..768b1c2 100644
--- a/common/research/common/scripts/create_tensorflow_records.py
+++ b/common/research/common/scripts/create_tensorflow_records.py
@@ -1,15 +1,14 @@
-from itertools import chain
-
 from research.common.dataset.tensorflow_record import TensorflowRecordFactory
-from research.common.datasets.roco.zoo.dji import DJIROCODatasets
-from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
-from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
-from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
+from research.common.datasets_v3.roco.zoo.dji import DJIROCODatasets
+from research.common.datasets_v3.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
+from research.common.datasets_v3.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo
+from research.common.datasets_v3.roco.zoo.twitch import TwitchROCODatasets
 
 
 def create_one_record_per_roco_dset():
-    for roco_set in chain(*(datasets for datasets in ROCODatasetsZoo())):
-        TensorflowRecordFactory.from_dataset(roco_set)
+    for datasets in ROCODatasetsZoo:
+        for dataset in datasets:
+            TensorflowRecordFactory.from_dataset(dataset)
 
 
 def create_twitch_records():
@@ -30,17 +29,17 @@ def create_twitch_records():
 
 def create_dji_records():
     TensorflowRecordFactory.from_datasets(
-        [DJIROCODatasets.CentralChina, DJIROCODatasets.NorthChina, DJIROCODatasets.SouthChina], "DJI_Train_"
+        [DJIROCODatasets.CENTRAL_CHINA, DJIROCODatasets.NORTH_CHINA, DJIROCODatasets.SOUTH_CHINA], "DJI_Train_"
     )
     TensorflowRecordFactory.from_dataset(DJIROCODatasets.Final, "DJI_Test_")
 
 
 def create_dji_zoomed_records():
     TensorflowRecordFactory.from_datasets(
-        [DJIROCOZoomedDatasets.CentralChina, DJIROCOZoomedDatasets.NorthChina, DJIROCOZoomedDatasets.SouthChina],
+        [DJIROCOZoomedDatasets.CENTRAL_CHINA, DJIROCOZoomedDatasets.NORTH_CHINA, DJIROCOZoomedDatasets.SOUTH_CHINA],
         "DJIZoomedV2_Train_",
     )
-    TensorflowRecordFactory.from_dataset(DJIROCOZoomedDatasets.Final, "DJIZoomedV2_Test_")
+    TensorflowRecordFactory.from_dataset(DJIROCOZoomedDatasets.FINAL, "DJIZoomedV2_Test_")
 
 
 if __name__ == "__main__":
diff --git a/common/research/common/scripts/improve_roco_by_zooming.py b/common/research/common/scripts/improve_roco_by_zooming.py
index ed7d40a..5547422 100644
--- a/common/research/common/scripts/improve_roco_by_zooming.py
+++ b/common/research/common/scripts/improve_roco_by_zooming.py
@@ -1,32 +1,46 @@
+from pathlib import Path
+from typing import Tuple
+
+from polystar.common.models.image import save_image
+from polystar.common.utils.str_utils import camel2snake
 from polystar.common.utils.tqdm import smart_tqdm
+from research.common.constants import DJI_ROCO_ZOOMED_DSET_DIR
 from research.common.dataset.improvement.zoom import Zoomer
-from research.common.dataset.perturbations.image_modifiers.brightness import \
-    BrightnessModifier
-from research.common.dataset.perturbations.image_modifiers.contrast import \
-    ContrastModifier
-from research.common.dataset.perturbations.image_modifiers.saturation import \
-    SaturationModifier
+from research.common.dataset.perturbations.image_modifiers.brightness import BrightnessModifier
+from research.common.dataset.perturbations.image_modifiers.contrast import ContrastModifier
+from research.common.dataset.perturbations.image_modifiers.saturation import SaturationModifier
 from research.common.dataset.perturbations.perturbator import ImagePerturbator
-from research.common.datasets.roco.roco_dataset import ROCOFileDataset
-from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
-from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
+from research.common.datasets_v3.roco.roco_dataset import LazyROCODataset
+from research.common.datasets_v3.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo
 
 
 def improve_dji_roco_dataset_by_zooming_and_perturbating(
-    dset: ROCOFileDataset, zoomer: Zoomer, perturbator: ImagePerturbator
+    dset: LazyROCODataset, zoomer: Zoomer, perturbator: ImagePerturbator
 ):
-    zoomed_dset = DJIROCOZoomedDatasets.make_dataset(dset.name)
-    zoomed_dset.create()
+    image_dir, annotation_dir = _prepare_empty_zoomed_dir(DJI_ROCO_ZOOMED_DSET_DIR / camel2snake(dset.name).lower())
 
-    for img, annotation, name in smart_tqdm(dset.open(), desc=f"Processing {dset}", unit="image", total=len(dset)):
+    for img, annotation, name in smart_tqdm(dset, desc=f"Processing {dset}", unit="image"):
         for zoomed_image, zoomed_annotation, zoomed_name in zoomer.zoom(img, annotation, name):
             zoomed_image = perturbator.perturbate(zoomed_image)
-            zoomed_dset.add(zoomed_image, zoomed_annotation, zoomed_name)
+            save_image(zoomed_image, image_dir / f"{zoomed_name}.jpg")
+            (annotation_dir / f"{zoomed_name}.xml").write_text(zoomed_annotation.to_xml())
 
 
 def improve_all_dji_datasets_by_zooming_and_perturbating(zoomer: Zoomer, perturbator: ImagePerturbator):
     for _dset in ROCODatasetsZoo.DJI:
-        improve_dji_roco_dataset_by_zooming_and_perturbating(zoomer=zoomer, dset=_dset, perturbator=perturbator)
+        improve_dji_roco_dataset_by_zooming_and_perturbating(zoomer=zoomer, dset=_dset.lazy(), perturbator=perturbator)
+
+
+def _prepare_empty_zoomed_dir(dir_path: Path) -> Tuple[Path, Path]:
+    dir_path.mkdir()
+
+    annotation_dir = dir_path / "image_annotation"
+    image_dir = dir_path / "image"
+
+    annotation_dir.mkdir()
+    image_dir.mkdir()
+
+    return image_dir, annotation_dir
 
 
 if __name__ == "__main__":
diff --git a/common/research/common/scripts/visualize_dataset.py b/common/research/common/scripts/visualize_dataset.py
index eda5a19..a74f2e4 100644
--- a/common/research/common/scripts/visualize_dataset.py
+++ b/common/research/common/scripts/visualize_dataset.py
@@ -1,9 +1,9 @@
 from polystar.common.view.plt_results_viewer import PltResultViewer
-from research.common.datasets.roco.roco_dataset import ROCODataset
-from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
+from research.common.datasets_v3.roco.roco_dataset import LazyROCODataset
+from research.common.datasets_v3.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo
 
 
-def visualize_dataset(dataset: ROCODataset, n_images: int):
+def visualize_dataset(dataset: LazyROCODataset, n_images: int):
     viewer = PltResultViewer(dataset.name)
 
     for i, (image, annotation, name) in enumerate(dataset, 1):
@@ -14,4 +14,4 @@ def visualize_dataset(dataset: ROCODataset, n_images: int):
 
 
 if __name__ == "__main__":
-    visualize_dataset(ROCODatasetsZoo.DJI_ZOOMED.CentralChina.open(), 20)
+    visualize_dataset(ROCODatasetsZoo.DJI_ZOOMED.CENTRAL_CHINA.lazy(), 20)
-- 
GitLab