diff --git a/common/research/common/dataset/roco_dataset_descriptor.py b/common/research/common/dataset/roco_dataset_descriptor.py index 1fb8c289bf02f81ab6dd34260d556eaf41dd40e4..f957eeaf5a21afcf55b962e0615996ca0845783a 100644 --- a/common/research/common/dataset/roco_dataset_descriptor.py +++ b/common/research/common/dataset/roco_dataset_descriptor.py @@ -4,13 +4,10 @@ from pathlib import Path from typing import Dict from pandas import DataFrame - from polystar.common.models.object import Armor, ObjectType from polystar.common.utils.markdown import MarkdownFile -from research.common.dataset.dji.dji_roco_datasets import DJIROCODataset -from research.common.dataset.dji.dji_roco_zoomed_datasets import DJIROCOZoomedDataset -from research.common.dataset.roco_dataset import ROCODataset -from research.common.dataset.twitch.twitch_roco_datasets import TwitchROCODataset +from research.common.datasets.roco.roco_dataset import ROCODataset +from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo @dataclass @@ -31,7 +28,7 @@ class ROCODatasetStats: rv.armors_color2num2count = {c: {n: 0 for n in range(10)} for c in colors} for c in colors: rv.armors_color2num2count[c]["total"] = 0 - for annotation in dataset.image_annotations: + for annotation in dataset.targets: rv.n_images += 1 rv.n_runes += annotation.has_rune for obj in annotation.objects: @@ -49,12 +46,12 @@ class ROCODatasetStats: def make_markdown_dataset_report(dataset: ROCODataset, report_dir: Path): - report_path = report_dir / f"dset_{dataset.dataset_name}_report.md" + report_path = report_dir / f"dset_{dataset.name}_report.md" stats = ROCODatasetStats.from_dataset(dataset) with MarkdownFile(report_path) as mf: - mf.title(f"Dataset {dataset.dataset_name}") + mf.title(f"Dataset {dataset.name}") mf.paragraph(f"{stats.n_images} images, with:") mf.list( @@ -69,5 +66,5 @@ def make_markdown_dataset_report(dataset: ROCODataset, report_dir: Path): if __name__ == "__main__": - for dset in chain(TwitchROCODataset, DJIROCOZoomedDataset, DJIROCODataset): + for dset in chain(*ROCODatasetsZoo()): make_markdown_dataset_report(dset, dset.dataset_path) diff --git a/common/research/common/datasets/roco/directory_roco_dataset.py b/common/research/common/datasets/roco/directory_roco_dataset.py index e59163fdd91864804debeebb710d229b4b723df4..ba8f09ec04145f60088a73486921b9c2752f09fb 100644 --- a/common/research/common/datasets/roco/directory_roco_dataset.py +++ b/common/research/common/datasets/roco/directory_roco_dataset.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Iterable, Tuple +from typing import Iterable, List, Tuple from more_itertools import ilen from polystar.common.models.image import Image @@ -12,6 +12,12 @@ class DirectoryROCODataset(ROCODataset): super().__init__(name) self.dataset_path = dataset_path + @property + def targets(self) -> List[ROCOAnnotation]: + if self._is_loaded: + return super().targets + return list(map(ROCOAnnotation.from_xml_file, self.annotation_paths)) + @property def images_dir_path(self) -> Path: return self.dataset_path / "image" diff --git a/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py b/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py index a11d78b5dab37d8ba561b44025bb2228529aa937..cfeab521b01bdc016dcfa740b4bbe8313cb9d7a4 100644 --- a/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py +++ b/common/research/common/datasets/roco/zoo/roco_datasets_zoo.py @@ -7,3 +7,6 @@ class ROCODatasetsZoo: DJI_ZOOMED = DJIROCOZoomedDatasets() DJI = DJIROCODatasets() TWITCH = TwitchROCODatasets() + + def __iter__(self): + return (self.DJI, self.DJI_ZOOMED, self.TWITCH).__iter__() diff --git a/common/tests/common/unittests/datasets/roco/test_directory_dataset_zoo.py b/common/tests/common/unittests/datasets/roco/test_directory_dataset_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..a25ff21a19a37fa73ed4d438085293e500402e21 --- /dev/null +++ b/common/tests/common/unittests/datasets/roco/test_directory_dataset_zoo.py @@ -0,0 +1,19 @@ +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import TestCase + +from research.common.datasets.roco.directory_roco_dataset import \ + DirectoryROCODataset +from research.common.datasets.roco.roco_annotation import ROCOAnnotation + + +class TestDirectoryROCODataset(TestCase): + def test_lazy_targets(self): + with TemporaryDirectory() as dataset_dir: + dataset = DirectoryROCODataset(Path(dataset_dir), "fake") + dataset.annotations_dir_path.mkdir() + + annotation = ROCOAnnotation("frame_1", objects=[], has_rune=False, w=160, h=90) + (dataset.annotations_dir_path / "frame_1.xml").write_text(annotation.to_xml()) + + self.assertEqual([annotation], dataset)