From 459fb01c72c7912d982f700fad1624a6899905f7 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Fri, 11 Sep 2020 22:44:48 +0200
Subject: [PATCH] [common] (tf records) adapt to new datasets

---
 .../common/dataset/tensorflow_record.py       | 38 +++++++-------
 .../research/common/datasets/image_dataset.py | 10 +++-
 .../datasets/roco/directory_roco_dataset.py   | 12 ++---
 .../scripts/create_tensorflow_records.py      | 51 +++++++++----------
 4 files changed, 60 insertions(+), 51 deletions(-)

diff --git a/common/research/common/dataset/tensorflow_record.py b/common/research/common/dataset/tensorflow_record.py
index 7e83538..762cf2f 100644
--- a/common/research/common/dataset/tensorflow_record.py
+++ b/common/research/common/dataset/tensorflow_record.py
@@ -1,44 +1,48 @@
 import hashlib
+from pathlib import Path
 from shutil import move
-from typing import Iterable
+from typing import List
 
 import tensorflow as tf
-from tensorflow_core.python.lib.io import python_io
-from tqdm import tqdm
-
-from polystar.common.models.image_annotation import ImageAnnotation
 from polystar.common.models.label_map import label_map
 from research.common.constants import TENSORFLOW_RECORDS_DIR
-from research.common.dataset.roco_dataset import ROCODataset
+from research.common.datasets.roco.directory_roco_dataset import \
+    DirectoryROCODataset
+from research.common.datasets.roco.roco_annotation import ROCOAnnotation
+from tensorflow_core.python.lib.io import python_io
+from tqdm import tqdm
 
 
 class TensorflowRecordFactory:
     @staticmethod
-    def from_datasets(datasets: Iterable[ROCODataset], name: str):
+    def from_datasets(datasets: List[DirectoryROCODataset], prefix: str = ""):
+        name = prefix + "_".join(d.name for d in datasets)
         writer = python_io.TFRecordWriter(str(TENSORFLOW_RECORDS_DIR / f"{name}.record"))
         c = 0
-        for dataset in datasets:
-            for image_annotation in tqdm(dataset.image_annotations, desc=dataset.dataset_name, total=len(dataset)):
-                writer.write(_example_from_image_annotation(image_annotation).SerializeToString())
+        for dataset in tqdm(datasets, desc=name, total=len(datasets)):
+            for image_path, annotation in tqdm(
+                dataset.unloaded_items(), desc=dataset.name, total=len(dataset), unit="img", leave=False
+            ):
+                writer.write(_example_from_image_annotation(image_path, annotation).SerializeToString())
                 c += 1
         writer.close()
         move(str(TENSORFLOW_RECORDS_DIR / f"{name}.record"), str(TENSORFLOW_RECORDS_DIR / f"{name}_{c}_imgs.record"))
 
     @staticmethod
-    def from_dataset(dataset: ROCODataset):
-        TensorflowRecordFactory.from_datasets([dataset], name=dataset.dataset_name)
+    def from_dataset(dataset: DirectoryROCODataset, prefix: str = ""):
+        TensorflowRecordFactory.from_datasets([dataset], prefix)
 
 
-def _example_from_image_annotation(image_annotation: ImageAnnotation) -> tf.train.Example:
-    image_name = image_annotation.image_path.name
-    encoded_jpg = image_annotation.image_path.read_bytes()
+def _example_from_image_annotation(image_path: Path, annotation: ROCOAnnotation) -> tf.train.Example:
+    image_name = image_path.name
+    encoded_jpg = image_path.read_bytes()
     key = hashlib.sha256(encoded_jpg).hexdigest()
 
-    width, height = image_annotation.width, image_annotation.height
+    width, height = annotation.w, annotation.h
 
     xmin, ymin, xmax, ymax, classes, classes_text = [], [], [], [], [], []
 
-    for obj in image_annotation.objects:
+    for obj in annotation.objects:
         xmin.append(float(obj.box.x1) / width)
         ymin.append(float(obj.box.y1) / height)
         xmax.append(float(obj.box.x2) / width)
diff --git a/common/research/common/datasets/image_dataset.py b/common/research/common/datasets/image_dataset.py
index 70f8fb2..a730661 100644
--- a/common/research/common/datasets/image_dataset.py
+++ b/common/research/common/datasets/image_dataset.py
@@ -15,6 +15,11 @@ class ImageDataset(Generic[TargetT], Iterable[Tuple[Image, TargetT]]):
     def __iter__(self) -> Iterator[Tuple[Image, TargetT]]:
         return zip(self.images, self.targets)
 
+    def __len__(self):
+        if not self._is_loaded:
+            self._load_data()
+        return len(self.images)
+
     def __str__(self):
         return f"<{self.__class__.__name__} {self.name}>"
 
@@ -33,7 +38,10 @@ class ImageDataset(Generic[TargetT], Iterable[Tuple[Image, TargetT]]):
     def _load_data(self):
         if self._is_loaded:
             return
-        self._images, self._targets = map(list, zip(*self))
+        self._images, self._targets = [], []
+        for image, target in self:
+            self._images.append(image)
+            self._targets.append(target)
         self._check_consistency()
 
     def _check_consistency(self):
diff --git a/common/research/common/datasets/roco/directory_roco_dataset.py b/common/research/common/datasets/roco/directory_roco_dataset.py
index ba8f09e..260b5d4 100644
--- a/common/research/common/datasets/roco/directory_roco_dataset.py
+++ b/common/research/common/datasets/roco/directory_roco_dataset.py
@@ -33,15 +33,13 @@ class DirectoryROCODataset(ROCODataset):
     def __len__(self):
         return ilen(self.annotation_paths)
 
-    def __iter__(self) -> Iterable[Tuple[Image, ROCOAnnotation]]:
+    def unloaded_items(self) -> Iterable[Tuple[Path, ROCOAnnotation]]:
         for annotation_file in self.annotation_paths:
-            yield self._load_from_annotation_file(annotation_file)
+            yield self.images_dir_path / f"{annotation_file.stem}.jpg", ROCOAnnotation.from_xml_file(annotation_file)
 
-    def _load_from_annotation_file(self, annotation_file: Path) -> Tuple[Image, ROCOAnnotation]:
-        return (
-            Image.from_path(self.images_dir_path / f"{annotation_file.stem}.jpg"),
-            ROCOAnnotation.from_xml_file(annotation_file),
-        )
+    def __iter__(self) -> Iterable[Tuple[Image, ROCOAnnotation]]:
+        for image_path, roco_annotation in self.unloaded_items():
+            yield Image.from_path(image_path), roco_annotation
 
     def save_one(self, image: Image, annotation: ROCOAnnotation):
         Image.save(image, self.images_dir_path / f"{annotation.name}.jpg")
diff --git a/common/research/common/scripts/create_tensorflow_records.py b/common/research/common/scripts/create_tensorflow_records.py
index 9ca8456..6983ff3 100644
--- a/common/research/common/scripts/create_tensorflow_records.py
+++ b/common/research/common/scripts/create_tensorflow_records.py
@@ -1,52 +1,51 @@
 from itertools import chain
 
-from research.common.dataset.dji.dji_roco_datasets import DJIROCODataset
-from research.common.dataset.dji.dji_roco_zoomed_datasets import DJIROCOZoomedDataset
 from research.common.dataset.tensorflow_record import TensorflowRecordFactory
-from research.common.dataset.twitch.twitch_roco_datasets import TwitchROCODataset
-from research.common.dataset.union_dataset import UnionDataset
+from research.common.datasets.roco.zoo.dji import DJIROCODatasets
+from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
+from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
+from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
 
 
 def create_one_record_per_roco_dset():
-    for roco_set in chain(DJIROCODataset, DJIROCOZoomedDataset, TwitchROCODataset):
+    for roco_set in chain(*(datasets for datasets in ROCODatasetsZoo())):
         TensorflowRecordFactory.from_dataset(roco_set)
 
 
 def create_twitch_records():
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(
-            TwitchROCODataset.TWITCH_470149568,
-            TwitchROCODataset.TWITCH_470150052,
-            TwitchROCODataset.TWITCH_470151286,
-            TwitchROCODataset.TWITCH_470152289,
-            TwitchROCODataset.TWITCH_470152730,
-        )
+    TensorflowRecordFactory.from_datasets(
+        [
+            TwitchROCODatasets.TWITCH_470149568,
+            TwitchROCODatasets.TWITCH_470150052,
+            TwitchROCODatasets.TWITCH_470151286,
+            TwitchROCODatasets.TWITCH_470152289,
+            TwitchROCODatasets.TWITCH_470152730,
+        ],
+        "Twitch_Train_",
     )
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(
-            TwitchROCODataset.TWITCH_470152838, TwitchROCODataset.TWITCH_470153081, TwitchROCODataset.TWITCH_470158483,
-        )
+    TensorflowRecordFactory.from_datasets(
+        [TwitchROCODatasets.TWITCH_470152838, TwitchROCODatasets.TWITCH_470153081, TwitchROCODatasets.TWITCH_470158483],
+        "Twitch_Test_",
     )
 
 
 def create_dji_records():
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(DJIROCODataset.CentralChina, DJIROCODataset.NorthChina, DJIROCODataset.SouthChina)
+    TensorflowRecordFactory.from_datasets(
+        [DJIROCODatasets.CentralChina, DJIROCODatasets.NorthChina, DJIROCODatasets.SouthChina], "DJI_Train_"
     )
-    TensorflowRecordFactory.from_dataset(DJIROCODataset.Final)
+    TensorflowRecordFactory.from_dataset(DJIROCODatasets.Final, "DJI_Test_")
 
 
 def create_dji_zoomed_records():
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(
-            DJIROCOZoomedDataset.CentralChina, DJIROCOZoomedDataset.NorthChina, DJIROCOZoomedDataset.SouthChina
-        )
+    TensorflowRecordFactory.from_datasets(
+        [DJIROCOZoomedDatasets.CentralChina, DJIROCOZoomedDatasets.NorthChina, DJIROCOZoomedDatasets.SouthChina],
+        "DJIZoomedV2_Train_",
     )
-    TensorflowRecordFactory.from_dataset(DJIROCOZoomedDataset.Final)
+    TensorflowRecordFactory.from_dataset(DJIROCOZoomedDatasets.Final, "DJIZoomedV2_Test_")
 
 
 if __name__ == "__main__":
-    create_one_record_per_roco_dset()
+    # create_one_record_per_roco_dset()
     create_twitch_records()
     create_dji_records()
     create_dji_zoomed_records()
-- 
GitLab