From 398d5827e1b4e2d08cd9f8da01f8022d80c2cebc Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Sun, 19 Jan 2020 11:24:35 -0500 Subject: [PATCH] [common] (scripts) Add a script to create tensorflow records from VOC dset --- .../dataset/tensorflow_record.py | 91 +++++++++++++++++++ .../scripts/create_tensorflow_records.py | 22 +++++ 2 files changed, 113 insertions(+) create mode 100644 common/research_common/dataset/tensorflow_record.py create mode 100644 common/research_common/scripts/create_tensorflow_records.py diff --git a/common/research_common/dataset/tensorflow_record.py b/common/research_common/dataset/tensorflow_record.py new file mode 100644 index 0000000..2b4ffc1 --- /dev/null +++ b/common/research_common/dataset/tensorflow_record.py @@ -0,0 +1,91 @@ +import hashlib +from pathlib import Path +from typing import Dict, Any, Iterable + +import tensorflow as tf +from lxml import etree +from tensorflow_core.python.lib.io import python_io +from tqdm import tqdm + +from object_detection.utils.dataset_util import ( + float_list_feature, + bytes_feature, + int64_feature, + bytes_list_feature, + int64_list_feature, + recursive_parse_xml_to_dict, +) +from object_detection.utils.label_map_util import get_label_map_dict +from research_common.constants import TENSORFLOW_RECORDS_DIR +from research_common.dataset.dataset import Dataset +from research_common.tensorflow_utils import patch_tf_v2 + +patch_tf_v2() # FIXME: Needed for version compatibility + + +class TensorflowExampleFactory: + def __init__(self, dataset: Dataset): + self.dataset = dataset + self.label_map = get_label_map_dict(str(TENSORFLOW_RECORDS_DIR / "label_map.pbtxt")) + + def from_annotation_path(self, annotation_path: Path) -> tf.train.Example: + annotation = self._load_annotation(annotation_path) + return self.from_annotation(annotation, annotation_path.stem) + + def from_annotation(self, annotation: Dict[str, Any], img_name: str) -> tf.train.Example: + full_path = (self.dataset.images_dir_path / img_name).with_suffix(".jpg") + encoded_jpg = full_path.read_bytes() + key = hashlib.sha256(encoded_jpg).hexdigest() + + width = int(annotation["size"]["width"]) + height = int(annotation["size"]["height"]) + + xmin = [] + ymin = [] + xmax = [] + ymax = [] + classes = [] + classes_text = [] + for obj in annotation.get("object", []): + xmin.append(float(obj["bndbox"]["xmin"]) / width) + ymin.append(float(obj["bndbox"]["ymin"]) / height) + xmax.append(float(obj["bndbox"]["xmax"]) / width) + ymax.append(float(obj["bndbox"]["ymax"]) / height) + classes_text.append(obj["name"].encode("utf8")) + classes.append(self.label_map[obj["name"]]) + + return tf.train.Example( + features=tf.train.Features( + feature={ + "image/height": int64_feature(height), + "image/width": int64_feature(width), + "image/key/sha256": bytes_feature(key.encode("utf8")), + "image/encoded": bytes_feature(encoded_jpg), + "image/format": bytes_feature("jpeg".encode("utf8")), + "image/object/bbox/xmin": float_list_feature(xmin), + "image/object/bbox/xmax": float_list_feature(xmax), + "image/object/bbox/ymin": float_list_feature(ymin), + "image/object/bbox/ymax": float_list_feature(ymax), + "image/object/class/text": bytes_list_feature(classes_text), + "image/object/class/label": int64_list_feature(classes), + } + ) + ) + + @staticmethod + def _load_annotation(annotation_path: Path) -> Dict[str, Any]: + xml = etree.fromstring(annotation_path.read_text()) + return recursive_parse_xml_to_dict(xml)["annotation"] + + +def create_tf_record_from_datasets(datasets: Iterable[Dataset], name: str): + writer = python_io.TFRecordWriter(str(TENSORFLOW_RECORDS_DIR / f"{name}.record")) + for dataset in datasets: + example_factory = TensorflowExampleFactory(dataset) + for annotation_path in tqdm(dataset.annotation_paths, desc=dataset.dataset_name): + writer.write(example_factory.from_annotation_path(annotation_path).SerializeToString()) + writer.close() + + +def create_tf_record_from_dataset(dataset: Dataset): + create_tf_record_from_datasets([dataset], name=dataset.dataset_name) diff --git a/common/research_common/scripts/create_tensorflow_records.py b/common/research_common/scripts/create_tensorflow_records.py new file mode 100644 index 0000000..9c35c53 --- /dev/null +++ b/common/research_common/scripts/create_tensorflow_records.py @@ -0,0 +1,22 @@ +from research_common.dataset.roco.roco_datasets import ROCODataset +from research_common.dataset.split import Split +from research_common.dataset.split_dataset import SplitDataset +from research_common.dataset.tensorflow_record import create_tf_record_from_dataset, create_tf_record_from_datasets + + +def create_one_record_per_roco_dset(): + for roco_set in ROCODataset: + for split in Split: + create_tf_record_from_dataset(SplitDataset(roco_set, split)) + + +def create_one_roco_record(): + for split in Split: + create_tf_record_from_datasets( + [SplitDataset(roco_dset, split) for roco_dset in ROCODataset], f"DJI_ROCO_{split.name}" + ) + + +if __name__ == "__main__": + create_one_record_per_roco_dset() + create_one_roco_record() -- GitLab