diff --git a/common/research/common/dataset/improvement/zoom.py b/common/research/common/dataset/improvement/zoom.py index b40b766dd4e61c847b90c2ba83f80cf574821f6c..567ec7738b65c94918dcb511256f2a1c19c94f87 100644 --- a/common/research/common/dataset/improvement/zoom.py +++ b/common/research/common/dataset/improvement/zoom.py @@ -1,17 +1,20 @@ from copy import copy from dataclasses import dataclass -from time import time from typing import Iterable, List, Tuple from polystar.common.models.box import Box -from polystar.common.models.image_annotation import ImageAnnotation -from polystar.common.target_pipeline.objects_validators.in_box_validator import InBoxValidator +from polystar.common.models.image import Image +from polystar.common.target_pipeline.objects_validators.in_box_validator import \ + InBoxValidator from polystar.common.view.plt_results_viewer import PltResultViewer -from research.common.dataset.dji.dji_roco_datasets import DJIROCODataset +from research.common.datasets.roco.roco_annotation import ROCOAnnotation +from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo -def crop_image_annotation(image_annotation: ImageAnnotation, box: Box, min_coverage: float) -> ImageAnnotation: - objects = InBoxValidator(box, min_coverage).filter(image_annotation.objects, image_annotation.image) +def crop_image_annotation( + image: Image, annotation: ROCOAnnotation, box: Box, min_coverage: float, name: str +) -> Tuple[Image, ROCOAnnotation]: + objects = InBoxValidator(box, min_coverage).filter(annotation.objects, image) objects = [copy(o) for o in objects] for obj in objects: obj.box = Box.from_positions( @@ -20,14 +23,9 @@ def crop_image_annotation(image_annotation: ImageAnnotation, box: Box, min_cover x2=min(box.x2, obj.box.x2 - box.x1), y2=min(box.y2, obj.box.y2 - box.y1), ) - return ImageAnnotation( - image_path=None, - xml_path=None, - width=box.w, - height=box.h, - objects=objects, - has_rune=False, - _image=image_annotation.image[box.y1 : box.y2, box.x1 : box.x2], + return ( + image[box.y1 : box.y2, box.x1 : box.x2], + ROCOAnnotation(w=box.w, h=box.h, objects=objects, has_rune=False, name=name), ) @@ -38,18 +36,21 @@ class Zoomer: max_overlap: float min_coverage: float - def zoom(self, image_annotation: ImageAnnotation) -> Iterable[ImageAnnotation]: - boxes = [obj.box for obj in image_annotation.objects] - boxes = self._create_views_covering(boxes, image_annotation) + def zoom(self, image: Image, annotation: ROCOAnnotation) -> Iterable[Tuple[Image, ROCOAnnotation]]: + boxes = [obj.box for obj in annotation.objects] + boxes = self._create_views_covering(boxes, annotation) boxes = self._remove_overlapping_boxes(boxes) - return (crop_image_annotation(image_annotation, box, self.min_coverage) for box in boxes) + return ( + crop_image_annotation(image, annotation, box, self.min_coverage, name=f"{annotation.name}_zoom_{i}") + for (i, box) in enumerate(boxes, 1) + ) - def _create_views_covering(self, boxes: List[Box], image_annotation: ImageAnnotation) -> List[Box]: + def _create_views_covering(self, boxes: List[Box], annotation: ROCOAnnotation) -> List[Box]: views: List[Box] = [] while boxes: view, boxes = self._find_new_cluster(boxes) - view = self._re_frame_box_with_respect_of(view, views, image_annotation) + view = self._re_frame_box_with_respect_of(view, views, annotation) views.append(view) boxes = self._remove_covered_boxes(boxes, views) @@ -66,7 +67,7 @@ class Zoomer: remaining_boxes.append(box) return cluster, remaining_boxes - def _re_frame_box_with_respect_of(self, box: Box, boxes: List[Box], image_annotation: ImageAnnotation) -> Box: + def _re_frame_box_with_respect_of(self, box: Box, boxes: List[Box], annotation: ROCOAnnotation) -> Box: missing_width = self.w - box.w missing_height = self.h - box.h @@ -80,8 +81,8 @@ class Zoomer: dx = -(missing_width // 2) * (not close_box_on_left) * (1 + close_box_on_right) dy = -(missing_height // 2) * (not close_box_on_top) * (1 + close_box_on_bottom) - x = max(0, min(image_annotation.width - self.w, box.x1 + dx)) - y = max(0, min(image_annotation.height - self.h, box.y1 + dy)) + x = max(0, min(annotation.w - self.w, box.x1 + dx)) + y = max(0, min(annotation.h - self.h, box.y1 + dy)) return Box.from_size(x, y, self.w, self.h) def _remove_covered_boxes(self, boxes: List[Box], views: List[Box]) -> List[Box]: @@ -120,17 +121,11 @@ class Zoomer: if __name__ == "__main__": zoomer = Zoomer(854, 480, 0.15, 0.5) - t = time() - c = 0 - - for i, img in enumerate(DJIROCODataset.CentralChina.image_annotations): + for k, (img, annotation) in enumerate(ROCODatasetsZoo.DJI.NorthChina): viewer = PltResultViewer(f"img {i}") - for res in zoomer.zoom(img): - viewer.display_image_annotation(res) - c += 1 + for (cropped_image, cropped_annotation) in zoomer.zoom(img, annotation): + viewer.display_image_with_objects(cropped_image, cropped_annotation.objects) - if i == 10: + if k == 2: break - - print(time() - t, c) diff --git a/common/research/common/datasets/image_dataset.py b/common/research/common/datasets/image_dataset.py index 6dca7c063f9b1fe93a2931b4edfeff1818b4d224..a9531339b27908fbd248bf9cac9406f5d64932a8 100644 --- a/common/research/common/datasets/image_dataset.py +++ b/common/research/common/datasets/image_dataset.py @@ -6,7 +6,8 @@ TargetT = TypeVar("TargetT") class ImageDataset(Generic[TargetT]): - def __init__(self, images: List[Image] = None, targets: List[TargetT] = None): + def __init__(self, name: str, images: List[Image] = None, targets: List[TargetT] = None): + self.name = name self._targets = targets self._images = images self._check_consistency() @@ -14,6 +15,11 @@ class ImageDataset(Generic[TargetT]): def __iter__(self) -> Iterator[Tuple[Image, TargetT]]: return zip(self.images, self.targets) + def __str__(self): + return f"<{self.__class__.__name__} {self.name}>" + + __repr__ = __str__ + @property def images(self) -> List[Image]: self._load_data() @@ -27,8 +33,7 @@ class ImageDataset(Generic[TargetT]): def _load_data(self): if self._is_loaded: return - images, targets = zip(*self) - self._images, self._targets = list(images), list(targets) + self._images, self._targets = map(list, zip(*self)) self._check_consistency() def _check_consistency(self): diff --git a/common/research/common/datasets/roco/directory_roco_dataset.py b/common/research/common/datasets/roco/directory_roco_dataset.py index b11b707e01699fcd63f2a1801e4b8ded2c294f36..e59163fdd91864804debeebb710d229b4b723df4 100644 --- a/common/research/common/datasets/roco/directory_roco_dataset.py +++ b/common/research/common/datasets/roco/directory_roco_dataset.py @@ -1,14 +1,15 @@ from pathlib import Path from typing import Iterable, Tuple +from more_itertools import ilen from polystar.common.models.image import Image from research.common.datasets.roco.roco_annotation import ROCOAnnotation from research.common.datasets.roco.roco_dataset import ROCODataset class DirectoryROCODataset(ROCODataset): - def __init__(self, dataset_path: Path, dataset_name: str): - self.dataset_name = dataset_name + def __init__(self, dataset_path: Path, name: str): + super().__init__(name) self.dataset_path = dataset_path @property @@ -23,6 +24,9 @@ class DirectoryROCODataset(ROCODataset): def annotations_dir_path(self) -> Path: return self.dataset_path / "image_annotation" + def __len__(self): + return ilen(self.annotation_paths) + def __iter__(self) -> Iterable[Tuple[Image, ROCOAnnotation]]: for annotation_file in self.annotation_paths: yield self._load_from_annotation_file(annotation_file) @@ -32,3 +36,7 @@ class DirectoryROCODataset(ROCODataset): Image.from_path(self.images_dir_path / f"{annotation_file.stem}.jpg"), ROCOAnnotation.from_xml_file(annotation_file), ) + + def save_one(self, image: Image, annotation: ROCOAnnotation): + Image.save(image, self.images_dir_path / f"{annotation.name}.jpg") + (self.annotations_dir_path / f"{annotation.name}.xml").write_text(annotation.to_xml()) diff --git a/common/research/common/datasets/roco/roco_annotation.py b/common/research/common/datasets/roco/roco_annotation.py index 60e069bcdb5093d88bbee7dacc191bc9222da46a..4378679ac7ecc7998bf9e916865944f14ffb7cd5 100644 --- a/common/research/common/datasets/roco/roco_annotation.py +++ b/common/research/common/datasets/roco/roco_annotation.py @@ -2,30 +2,60 @@ import logging from dataclasses import dataclass from pathlib import Path from typing import Dict, List +from xml.dom.minidom import parseString import xmltodict +from dicttoxml import dicttoxml from polystar.common.models.object import Object, ObjectFactory @dataclass class ROCOAnnotation: + name: str + objects: List[Object] has_rune: bool + w: int + h: int + @staticmethod def from_xml_file(xml_file: Path) -> "ROCOAnnotation": try: - return ROCOAnnotation.from_xml_dict(xmltodict.parse(xml_file.read_text())["annotation"]) + return ROCOAnnotation.from_xml_dict(xmltodict.parse(xml_file.read_text())["annotation"], xml_file.stem) except Exception as e: logging.exception(f"Error parsing annotation file {xml_file}") raise e @staticmethod - def from_xml_dict(xml_dict: Dict) -> "ROCOAnnotation": + def from_xml_dict(xml_dict: Dict, name: str) -> "ROCOAnnotation": json_objects = xml_dict.get("object", []) or [] json_objects = json_objects if isinstance(json_objects, list) else [json_objects] roco_json_objects = [obj_json for obj_json in json_objects if not obj_json["name"].startswith("rune")] objects = [ObjectFactory.from_json(obj_json) for obj_json in roco_json_objects] - return ROCOAnnotation(objects=objects, has_rune=len(roco_json_objects) != len(json_objects)) + return ROCOAnnotation( + objects=objects, + has_rune=len(roco_json_objects) != len(json_objects), + w=int(xml_dict["size"]["width"]), + h=int(xml_dict["size"]["height"]), + name=name, + ) + + def to_xml(self) -> str: + return parseString( + dicttoxml( + { + "annotation": { + "size": {"width": self.w, "height": self.h}, + "object": [ObjectFactory.to_json(obj) for obj in self.objects], + } + }, + attr_type=False, + root="annotation", + item_func=lambda x: x, + ) + .replace(b"<object><object>", b"<object>") + .replace(b"</object></object>", b"</object>") + ).toprettyxml() diff --git a/common/research/common/datasets/roco/roco_datasets.py b/common/research/common/datasets/roco/roco_datasets.py index c716d29c0f08d822655415f0135373415b1fc28f..fc0a7d522c1dab294f4a57f76c31940420afa2ed 100644 --- a/common/research/common/datasets/roco/roco_datasets.py +++ b/common/research/common/datasets/roco/roco_datasets.py @@ -1,15 +1,25 @@ -from typing import Any, Tuple +from typing import Any, Iterator, List, Tuple from research.common.dataset.directory_roco_dataset import DirectoryROCODataset class ROCODatasets: - def _make_dataset(dataset_name: str, *args: Any) -> DirectoryROCODataset: + def make_dataset(dataset_name: str, *args: Any) -> DirectoryROCODataset: pass def __init_subclass__(cls, **kwargs): + cls.datasets: List[DirectoryROCODataset] = [] for dataset_name, args in cls.__dict__.items(): - if not callable(args) and not dataset_name.startswith("_"): + if ( + not callable(args) + and not dataset_name.startswith("_") + and dataset_name not in ("make_dataset", "datasets") + ): if not isinstance(args, Tuple): args = (args,) - setattr(cls, dataset_name, cls._make_dataset(dataset_name, *args)) + dataset = cls.make_dataset(dataset_name, *args) + setattr(cls, dataset_name, dataset) + cls.datasets.append(dataset) + + def __iter__(self) -> Iterator[DirectoryROCODataset]: + return self.datasets.__iter__() diff --git a/common/research/common/datasets/roco/zoo/dji.py b/common/research/common/datasets/roco/zoo/dji.py index 4a630e74e4b1209fe187ef1444a1bbb91958cecf..0b7ca9fe2c3b7791706b4c339b66fc9dccc84c19 100644 --- a/common/research/common/datasets/roco/zoo/dji.py +++ b/common/research/common/datasets/roco/zoo/dji.py @@ -11,5 +11,5 @@ class DJIROCODatasets(ROCODatasets): Final = "robomaster_Final Tournament" @staticmethod - def _make_dataset(dataset_name: str, competition_name: str) -> DirectoryROCODataset: + def make_dataset(dataset_name: str, competition_name: str) -> DirectoryROCODataset: return DirectoryROCODataset(DJI_ROCO_DSET_DIR / competition_name, dataset_name) diff --git a/common/research/common/datasets/roco/zoo/dji_zoomed.py b/common/research/common/datasets/roco/zoo/dji_zoomed.py index 009ffac7b105ae978fd96f4259d90205b2157988..39444e3a3726dddf3aadf05ba7a927cfe5cb1df3 100644 --- a/common/research/common/datasets/roco/zoo/dji_zoomed.py +++ b/common/research/common/datasets/roco/zoo/dji_zoomed.py @@ -12,5 +12,5 @@ class DJIROCOZoomedDatasets(ROCODatasets): Final = () @staticmethod - def _make_dataset(dataset_name: str) -> DirectoryROCODataset: + def make_dataset(dataset_name: str) -> DirectoryROCODataset: return DirectoryROCODataset(DJI_ROCO_ZOOMED_DSET_DIR / camel2snake(dataset_name), f"{dataset_name}ZoomedV2") diff --git a/common/research/common/datasets/roco/zoo/twitch.py b/common/research/common/datasets/roco/zoo/twitch.py index dfb4963dfbbf4b673558528e5af30bb90997a381..9d7a0af4f8fa118b2c8a1fa68d170e656eeb2b85 100644 --- a/common/research/common/datasets/roco/zoo/twitch.py +++ b/common/research/common/datasets/roco/zoo/twitch.py @@ -15,6 +15,6 @@ class TwitchROCODatasets(ROCODatasets): TWITCH_470158483 = () @staticmethod - def _make_dataset(dataset_name: str) -> DirectoryROCODataset: + def make_dataset(dataset_name: str) -> DirectoryROCODataset: twitch_id = dataset_name[len("TWITCH_") :] return DirectoryROCODataset(TWITCH_DSET_DIR / "v1" / twitch_id, f"T{twitch_id}") diff --git a/common/research/common/scripts/improve_roco_by_zooming.py b/common/research/common/scripts/improve_roco_by_zooming.py index 9550fa94f4eb301b011d78b1f3f1985819c5200d..0b3d6b1c6163ff38ee59cd5201b049fd16f93b3d 100644 --- a/common/research/common/scripts/improve_roco_by_zooming.py +++ b/common/research/common/scripts/improve_roco_by_zooming.py @@ -1,28 +1,34 @@ -from tqdm import tqdm - -from research.common.dataset.dji.dji_roco_datasets import DJIROCODataset -from research.common.dataset.dji.dji_roco_zoomed_datasets import DJIROCOZoomedDataset from research.common.dataset.improvement.zoom import Zoomer -from research.common.dataset.perturbations.image_modifiers.brightness import BrightnessModifier -from research.common.dataset.perturbations.image_modifiers.contrast import ContrastModifier -from research.common.dataset.perturbations.image_modifiers.saturation import SaturationModifier +from research.common.dataset.perturbations.image_modifiers.brightness import \ + BrightnessModifier +from research.common.dataset.perturbations.image_modifiers.contrast import \ + ContrastModifier +from research.common.dataset.perturbations.image_modifiers.saturation import \ + SaturationModifier from research.common.dataset.perturbations.perturbator import ImagePerturbator +from research.common.datasets.roco.directory_roco_dataset import \ + DirectoryROCODataset +from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets +from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo +from tqdm import tqdm def improve_dji_roco_dataset_by_zooming_and_perturbating( - dset: DJIROCODataset, zoomer: Zoomer, perturbator: ImagePerturbator + dset: DirectoryROCODataset, zoomer: Zoomer, perturbator: ImagePerturbator ): - zoomed_dset: DJIROCOZoomedDataset = DJIROCOZoomedDataset[dset.name] + zoomed_dset = DJIROCOZoomedDatasets.make_dataset(dset.name) zoomed_dset.dataset_path.mkdir(parents=True) + zoomed_dset.images_dir_path.mkdir() + zoomed_dset.annotations_dir_path.mkdir() - for img in tqdm(dset.image_annotations, desc=f"Processing {dset}", unit="image", total=len(dset)): - for i, zoomed_image in enumerate(zoomer.zoom(img), 1): - zoomed_image._image = perturbator.perturbate(zoomed_image.image) - zoomed_image.save_to_dir(zoomed_dset.dataset_path, f"{img.image_path.stem}_zoom_{i}") + for img, annotation in tqdm(dset, desc=f"Processing {dset}", unit="image", total=len(dset)): + for zoomed_image, zoomed_annotation in zoomer.zoom(img, annotation): + zoomed_image = perturbator.perturbate(zoomed_image) + zoomed_dset.save_one(zoomed_image, zoomed_annotation) def improve_all_dji_datasets_by_zooming_and_perturbating(zoomer: Zoomer, perturbator: ImagePerturbator): - for _dset in DJIROCODataset: + for _dset in ROCODatasetsZoo.DJI: improve_dji_roco_dataset_by_zooming_and_perturbating(zoomer=zoomer, dset=_dset, perturbator=perturbator) diff --git a/common/research/common/scripts/visualize_dataset.py b/common/research/common/scripts/visualize_dataset.py index 13254ff04717c3a57dbe41829158f082912b303c..9cc123b5f85f9cc9289f46f9e8a5f14d1fea5ea0 100644 --- a/common/research/common/scripts/visualize_dataset.py +++ b/common/research/common/scripts/visualize_dataset.py @@ -1,17 +1,17 @@ from polystar.common.view.plt_results_viewer import PltResultViewer -from research.common.dataset.dji.dji_roco_zoomed_datasets import DJIROCOZoomedDataset -from research.common.dataset.roco_dataset import ROCODataset +from research.common.datasets.roco.roco_dataset import ROCODataset +from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo def visualize_dataset(dataset: ROCODataset, n_images: int): - viewer = PltResultViewer(dataset.dataset_name) + viewer = PltResultViewer(dataset.name) - for i, image in enumerate(dataset.image_annotations, 1): - viewer.display_image_annotation(image) + for i, (image, annotation) in enumerate(dataset, 1): + viewer.display_image_with_objects(image, annotation.objects) if i == n_images: return if __name__ == "__main__": - visualize_dataset(DJIROCOZoomedDataset.CentralChina, 20) + visualize_dataset(ROCODatasetsZoo.DJI_ZOOMED.CentralChina, 20) diff --git a/common/tests/common/unittests/datasets/test_image_dataset.py b/common/tests/common/unittests/datasets/test_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..86ee99594ba14ccf503754bc41072a807bbc1db7 --- /dev/null +++ b/common/tests/common/unittests/datasets/test_image_dataset.py @@ -0,0 +1,31 @@ +from unittest import TestCase + +from research.common.datasets.image_dataset import ImageDataset + + +class TestImageDataset(TestCase): + def test_iter(self): + dataset = ImageDataset("test", list(range(5)), list(range(3, 8))) + + self.assertEqual([(0, 3), (1, 4), (2, 5), (3, 6), (4, 7)], list(dataset)) + + def test_auto_load(self): + class FakeDataset(ImageDataset): + def __iter__(self): + return [(0, 2), (1, 4)].__iter__() + + dataset = FakeDataset("test") + + self.assertEqual([0, 1], dataset.images) + self.assertEqual([2, 4], dataset.targets) + + def test_assert(self): + with self.assertRaises(AssertionError): + ImageDataset("test") + + def test_assert_child(self): + class FakeDataset(ImageDataset): + pass + + with self.assertRaises(AssertionError): + FakeDataset("test")