diff --git a/common/research/common/dataset/tensorflow_record.py b/common/research/common/dataset/tensorflow_record.py
index 7e835380b534ec17c2230b17d30fb2110c1652d5..762cf2f68371275889cc15031b46e3c06780c38b 100644
--- a/common/research/common/dataset/tensorflow_record.py
+++ b/common/research/common/dataset/tensorflow_record.py
@@ -1,44 +1,48 @@
 import hashlib
+from pathlib import Path
 from shutil import move
-from typing import Iterable
+from typing import List
 
 import tensorflow as tf
-from tensorflow_core.python.lib.io import python_io
-from tqdm import tqdm
-
-from polystar.common.models.image_annotation import ImageAnnotation
 from polystar.common.models.label_map import label_map
 from research.common.constants import TENSORFLOW_RECORDS_DIR
-from research.common.dataset.roco_dataset import ROCODataset
+from research.common.datasets.roco.directory_roco_dataset import \
+    DirectoryROCODataset
+from research.common.datasets.roco.roco_annotation import ROCOAnnotation
+from tensorflow_core.python.lib.io import python_io
+from tqdm import tqdm
 
 
 class TensorflowRecordFactory:
     @staticmethod
-    def from_datasets(datasets: Iterable[ROCODataset], name: str):
+    def from_datasets(datasets: List[DirectoryROCODataset], prefix: str = ""):
+        name = prefix + "_".join(d.name for d in datasets)
         writer = python_io.TFRecordWriter(str(TENSORFLOW_RECORDS_DIR / f"{name}.record"))
         c = 0
-        for dataset in datasets:
-            for image_annotation in tqdm(dataset.image_annotations, desc=dataset.dataset_name, total=len(dataset)):
-                writer.write(_example_from_image_annotation(image_annotation).SerializeToString())
+        for dataset in tqdm(datasets, desc=name, total=len(datasets)):
+            for image_path, annotation in tqdm(
+                dataset.unloaded_items(), desc=dataset.name, total=len(dataset), unit="img", leave=False
+            ):
+                writer.write(_example_from_image_annotation(image_path, annotation).SerializeToString())
                 c += 1
         writer.close()
         move(str(TENSORFLOW_RECORDS_DIR / f"{name}.record"), str(TENSORFLOW_RECORDS_DIR / f"{name}_{c}_imgs.record"))
 
     @staticmethod
-    def from_dataset(dataset: ROCODataset):
-        TensorflowRecordFactory.from_datasets([dataset], name=dataset.dataset_name)
+    def from_dataset(dataset: DirectoryROCODataset, prefix: str = ""):
+        TensorflowRecordFactory.from_datasets([dataset], prefix)
 
 
-def _example_from_image_annotation(image_annotation: ImageAnnotation) -> tf.train.Example:
-    image_name = image_annotation.image_path.name
-    encoded_jpg = image_annotation.image_path.read_bytes()
+def _example_from_image_annotation(image_path: Path, annotation: ROCOAnnotation) -> tf.train.Example:
+    image_name = image_path.name
+    encoded_jpg = image_path.read_bytes()
     key = hashlib.sha256(encoded_jpg).hexdigest()
 
-    width, height = image_annotation.width, image_annotation.height
+    width, height = annotation.w, annotation.h
 
     xmin, ymin, xmax, ymax, classes, classes_text = [], [], [], [], [], []
 
-    for obj in image_annotation.objects:
+    for obj in annotation.objects:
         xmin.append(float(obj.box.x1) / width)
         ymin.append(float(obj.box.y1) / height)
         xmax.append(float(obj.box.x2) / width)
diff --git a/common/research/common/datasets/image_dataset.py b/common/research/common/datasets/image_dataset.py
index 70f8fb2715e2da72134167e71ad985d4f909336e..a730661b80c03c2fe5ff63dd2101d388838a71fe 100644
--- a/common/research/common/datasets/image_dataset.py
+++ b/common/research/common/datasets/image_dataset.py
@@ -15,6 +15,11 @@ class ImageDataset(Generic[TargetT], Iterable[Tuple[Image, TargetT]]):
     def __iter__(self) -> Iterator[Tuple[Image, TargetT]]:
         return zip(self.images, self.targets)
 
+    def __len__(self):
+        if not self._is_loaded:
+            self._load_data()
+        return len(self.images)
+
     def __str__(self):
         return f"<{self.__class__.__name__} {self.name}>"
 
@@ -33,7 +38,10 @@ class ImageDataset(Generic[TargetT], Iterable[Tuple[Image, TargetT]]):
     def _load_data(self):
         if self._is_loaded:
             return
-        self._images, self._targets = map(list, zip(*self))
+        self._images, self._targets = [], []
+        for image, target in self:
+            self._images.append(image)
+            self._targets.append(target)
         self._check_consistency()
 
     def _check_consistency(self):
diff --git a/common/research/common/datasets/roco/directory_roco_dataset.py b/common/research/common/datasets/roco/directory_roco_dataset.py
index ba8f09ec04145f60088a73486921b9c2752f09fb..260b5d47f6aa547cbfd63c04e0ef8be2b9ed3b72 100644
--- a/common/research/common/datasets/roco/directory_roco_dataset.py
+++ b/common/research/common/datasets/roco/directory_roco_dataset.py
@@ -33,15 +33,13 @@ class DirectoryROCODataset(ROCODataset):
     def __len__(self):
         return ilen(self.annotation_paths)
 
-    def __iter__(self) -> Iterable[Tuple[Image, ROCOAnnotation]]:
+    def unloaded_items(self) -> Iterable[Tuple[Path, ROCOAnnotation]]:
         for annotation_file in self.annotation_paths:
-            yield self._load_from_annotation_file(annotation_file)
+            yield self.images_dir_path / f"{annotation_file.stem}.jpg", ROCOAnnotation.from_xml_file(annotation_file)
 
-    def _load_from_annotation_file(self, annotation_file: Path) -> Tuple[Image, ROCOAnnotation]:
-        return (
-            Image.from_path(self.images_dir_path / f"{annotation_file.stem}.jpg"),
-            ROCOAnnotation.from_xml_file(annotation_file),
-        )
+    def __iter__(self) -> Iterable[Tuple[Image, ROCOAnnotation]]:
+        for image_path, roco_annotation in self.unloaded_items():
+            yield Image.from_path(image_path), roco_annotation
 
     def save_one(self, image: Image, annotation: ROCOAnnotation):
         Image.save(image, self.images_dir_path / f"{annotation.name}.jpg")
diff --git a/common/research/common/scripts/create_tensorflow_records.py b/common/research/common/scripts/create_tensorflow_records.py
index 9ca8456ae221659545d270e34a418205bcb4c305..6983ff3302e9e18c6c571e0372b05b07f35054a8 100644
--- a/common/research/common/scripts/create_tensorflow_records.py
+++ b/common/research/common/scripts/create_tensorflow_records.py
@@ -1,52 +1,51 @@
 from itertools import chain
 
-from research.common.dataset.dji.dji_roco_datasets import DJIROCODataset
-from research.common.dataset.dji.dji_roco_zoomed_datasets import DJIROCOZoomedDataset
 from research.common.dataset.tensorflow_record import TensorflowRecordFactory
-from research.common.dataset.twitch.twitch_roco_datasets import TwitchROCODataset
-from research.common.dataset.union_dataset import UnionDataset
+from research.common.datasets.roco.zoo.dji import DJIROCODatasets
+from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets
+from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
+from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets
 
 
 def create_one_record_per_roco_dset():
-    for roco_set in chain(DJIROCODataset, DJIROCOZoomedDataset, TwitchROCODataset):
+    for roco_set in chain(*(datasets for datasets in ROCODatasetsZoo())):
         TensorflowRecordFactory.from_dataset(roco_set)
 
 
 def create_twitch_records():
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(
-            TwitchROCODataset.TWITCH_470149568,
-            TwitchROCODataset.TWITCH_470150052,
-            TwitchROCODataset.TWITCH_470151286,
-            TwitchROCODataset.TWITCH_470152289,
-            TwitchROCODataset.TWITCH_470152730,
-        )
+    TensorflowRecordFactory.from_datasets(
+        [
+            TwitchROCODatasets.TWITCH_470149568,
+            TwitchROCODatasets.TWITCH_470150052,
+            TwitchROCODatasets.TWITCH_470151286,
+            TwitchROCODatasets.TWITCH_470152289,
+            TwitchROCODatasets.TWITCH_470152730,
+        ],
+        "Twitch_Train_",
     )
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(
-            TwitchROCODataset.TWITCH_470152838, TwitchROCODataset.TWITCH_470153081, TwitchROCODataset.TWITCH_470158483,
-        )
+    TensorflowRecordFactory.from_datasets(
+        [TwitchROCODatasets.TWITCH_470152838, TwitchROCODatasets.TWITCH_470153081, TwitchROCODatasets.TWITCH_470158483],
+        "Twitch_Test_",
     )
 
 
 def create_dji_records():
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(DJIROCODataset.CentralChina, DJIROCODataset.NorthChina, DJIROCODataset.SouthChina)
+    TensorflowRecordFactory.from_datasets(
+        [DJIROCODatasets.CentralChina, DJIROCODatasets.NorthChina, DJIROCODatasets.SouthChina], "DJI_Train_"
     )
-    TensorflowRecordFactory.from_dataset(DJIROCODataset.Final)
+    TensorflowRecordFactory.from_dataset(DJIROCODatasets.Final, "DJI_Test_")
 
 
 def create_dji_zoomed_records():
-    TensorflowRecordFactory.from_dataset(
-        UnionDataset(
-            DJIROCOZoomedDataset.CentralChina, DJIROCOZoomedDataset.NorthChina, DJIROCOZoomedDataset.SouthChina
-        )
+    TensorflowRecordFactory.from_datasets(
+        [DJIROCOZoomedDatasets.CentralChina, DJIROCOZoomedDatasets.NorthChina, DJIROCOZoomedDatasets.SouthChina],
+        "DJIZoomedV2_Train_",
     )
-    TensorflowRecordFactory.from_dataset(DJIROCOZoomedDataset.Final)
+    TensorflowRecordFactory.from_dataset(DJIROCOZoomedDatasets.Final, "DJIZoomedV2_Test_")
 
 
 if __name__ == "__main__":
-    create_one_record_per_roco_dset()
+    # create_one_record_per_roco_dset()
     create_twitch_records()
     create_dji_records()
    create_dji_zoomed_records()
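
Usage note (editor's addition, not part of the patch): a minimal sketch of how the reworked factory might be called after this change. Module paths and dataset members are taken from the diff above; the output directory and the exact record file name follow the rename logic in from_datasets.

# Hypothetical usage sketch, assuming the imports resolve as in the diff above
# and that TENSORFLOW_RECORDS_DIR already exists on disk.
from research.common.dataset.tensorflow_record import TensorflowRecordFactory
from research.common.datasets.roco.zoo.twitch import TwitchROCODatasets

# Single dataset: delegates to the list form with the same prefix.
TensorflowRecordFactory.from_dataset(TwitchROCODatasets.TWITCH_470149568, "Twitch_")

# Several datasets merged into one record; the file is written under
# TENSORFLOW_RECORDS_DIR and then renamed to "<prefix><joined names>_<image count>_imgs.record".
TensorflowRecordFactory.from_datasets(
    [TwitchROCODatasets.TWITCH_470149568, TwitchROCODatasets.TWITCH_470150052],
    "Twitch_Train_",
)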