diff --git a/common/polystar/common/models/image.py b/common/polystar/common/models/image.py index 90400a2ca25a39c953d51681a15e68dc97dd697d..61af256e9c096f11347a3d419421ea7f20626b5e 100644 --- a/common/polystar/common/models/image.py +++ b/common/polystar/common/models/image.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import Iterable import cv2 - from nptyping import Array diff --git a/common/polystar/common/utils/misc.py b/common/polystar/common/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e55d2b80874791769007c2033e2aea602a81a23b --- /dev/null +++ b/common/polystar/common/utils/misc.py @@ -0,0 +1,7 @@ +from typing import TypeVar + +T = TypeVar("T") + + +def identity(x: T) -> T: + return x diff --git a/common/research/common/dataset/improvement/zoom.py b/common/research/common/dataset/improvement/zoom.py index 567ec7738b65c94918dcb511256f2a1c19c94f87..785ea347b07b7f8c4ce55db07c5c462b7664c0c1 100644 --- a/common/research/common/dataset/improvement/zoom.py +++ b/common/research/common/dataset/improvement/zoom.py @@ -121,10 +121,10 @@ class Zoomer: if __name__ == "__main__": zoomer = Zoomer(854, 480, 0.15, 0.5) - for k, (img, annotation) in enumerate(ROCODatasetsZoo.DJI.NorthChina): + for k, (img, annot) in enumerate(ROCODatasetsZoo.DJI.NorthChina): - viewer = PltResultViewer(f"img {i}") + viewer = PltResultViewer(f"img {k}") - for (cropped_image, cropped_annotation) in zoomer.zoom(img, annotation): + for (cropped_image, cropped_annotation) in zoomer.zoom(img, annot): viewer.display_image_with_objects(cropped_image, cropped_annotation.objects) if k == 2: diff --git a/common/research/common/dataset/roco_dataset_descriptor.py b/common/research/common/dataset/roco_dataset_descriptor.py index f7c3016cc7990eb610268a063252275655d367e1..2c790fdec8db12147191ee0f364ea827b0580493 100644 --- a/common/research/common/dataset/roco_dataset_descriptor.py +++ b/common/research/common/dataset/roco_dataset_descriptor.py @@ -8,6 +8,7 @@ from polystar.common.utils.markdown
import MarkdownFile from research.common.datasets.roco.roco_dataset import ROCODataset from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo from research.common.datasets.union_dataset import UnionDataset +from tqdm import tqdm @dataclass @@ -28,7 +29,7 @@ class ROCODatasetStats: rv.armors_color2num2count = {c: {n: 0 for n in range(10)} for c in colors} for c in colors: rv.armors_color2num2count[c]["total"] = 0 - for annotation in dataset.targets: + for annotation in tqdm(dataset.targets, desc=dataset.name, unit="frame", total=len(dataset)): rv.n_images += 1 rv.n_runes += annotation.has_rune for obj in annotation.objects: @@ -69,4 +70,4 @@ if __name__ == "__main__": for datasets in ROCODatasetsZoo(): make_markdown_dataset_report(UnionDataset(datasets, datasets.name), datasets.directory) for dset in datasets: - make_markdown_dataset_report(dset, dset.dataset_path) + make_markdown_dataset_report(dset, dset.main_dir) diff --git a/common/research/common/dataset/tensorflow_record.py b/common/research/common/dataset/tensorflow_record.py index 762cf2f68371275889cc15031b46e3c06780c38b..7c1e40cbe42badacee69be040f598965710f339d 100644 --- a/common/research/common/dataset/tensorflow_record.py +++ b/common/research/common/dataset/tensorflow_record.py @@ -19,10 +19,8 @@ class TensorflowRecordFactory: name = prefix + "_".join(d.name for d in datasets) writer = python_io.TFRecordWriter(str(TENSORFLOW_RECORDS_DIR / f"{name}.record")) c = 0 - for dataset in tqdm(datasets, desc=name, total=len(datasets)): - for image_path, annotation in tqdm( - dataset.unloaded_items(), desc=dataset.name, total=len(dataset), unit="img", leave=False - ): + for dataset in tqdm(datasets, desc=name, total=len(datasets), unit="dataset"): + for image_path, annotation in tqdm(dataset, desc=dataset.name, total=len(dataset), unit="img", leave=False): writer.write(_example_from_image_annotation(image_path, annotation).SerializeToString()) c += 1 writer.close() diff --git 
a/common/research/common/datasets/dataset.py b/common/research/common/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb9a6e466774e581a9a74b0b50ca1e0e111e314 --- /dev/null +++ b/common/research/common/datasets/dataset.py @@ -0,0 +1,114 @@ +from abc import ABC, abstractmethod +from collections import deque +from typing import Callable, Generic, Iterable, Iterator, Tuple, TypeVar + +from more_itertools import ilen +from polystar.common.utils.misc import identity + +ExampleT = TypeVar("ExampleT") +TargetT = TypeVar("TargetT") +ExampleU = TypeVar("ExampleU") +TargetU = TypeVar("TargetU") + + +class Dataset(Generic[ExampleT, TargetT], Iterable[Tuple[ExampleT, TargetT]], ABC): + def __init__(self, name: str): + self.name = name + + @property + @abstractmethod + def examples(self) -> Iterable[ExampleT]: + pass + + @property + @abstractmethod + def targets(self) -> Iterable[TargetT]: + pass + + @abstractmethod + def __iter__(self) -> Iterator[Tuple[ExampleT, TargetT]]: + pass + + @abstractmethod + def __len__(self): + pass + + def transform_examples(self, example_transformer: Callable[[ExampleT], ExampleU]) -> "Dataset[ExampleU, TargetT]": + return self.transform(example_transformer, identity) + + def transform_targets( + self, target_transformer: Callable[[TargetT], TargetU] = identity + ) -> "Dataset[ExampleT, TargetU]": + return self.transform(identity, target_transformer) + + def transform( + self, example_transformer: Callable[[ExampleT], ExampleU], target_transformer: Callable[[TargetT], TargetU] + ) -> "Dataset[ExampleU, TargetU]": + return GeneratorDataset( + self.name, lambda: ((example_transformer(example), target_transformer(target)) for example, target in self) + ) + + def __str__(self): + return f"<{self.__class__.__name__} {self.name}>" + + __repr__ = __str__ + + def check_consistency(self): + targets, examples = self.targets, self.examples + if isinstance(targets, list) and isinstance(examples, list): + assert 
len(targets) == len(examples) + assert ilen(targets) == ilen(examples) + + +class LazyUnzipper: + def __init__(self, iterator: Iterator[Tuple]): + self._iterator = iterator + self._memory = [deque(), deque()] + + def empty(self, i: int): + return self._iterator is None and not self._memory[i] + + def elements(self, i: int): + while True: + if self._memory[i]: + yield self._memory[i].popleft() + elif self._iterator is None: + return + else: + try: + elements = next(self._iterator) + self._memory[1 - i].append(elements[1 - i]) + yield elements[i] + except StopIteration: + self._iterator = None + return + + +class LazyDataset(Dataset[ExampleT, TargetT], ABC): + def __init__(self, name: str): + super().__init__(name) + self._unzipper = LazyUnzipper(iter(self)) + + @property + def examples(self) -> Iterable[ExampleT]: + if self._unzipper.empty(0): + self._unzipper = LazyUnzipper(iter(self)) + return self._unzipper.elements(0) + + @property + def targets(self) -> Iterable[TargetT]: + if self._unzipper.empty(1): + self._unzipper = LazyUnzipper(iter(self)) + return self._unzipper.elements(1) + + def __len__(self): + return ilen(self) + + +class GeneratorDataset(LazyDataset[ExampleT, TargetT]): + def __init__(self, name: str, generator: Callable[[], Iterator[Tuple[ExampleT, TargetT]]]): + self.generator = generator + super().__init__(name) + + def __iter__(self) -> Iterator[Tuple[ExampleT, TargetT]]: + return self.generator() diff --git a/common/research/common/datasets/image_dataset.py b/common/research/common/datasets/image_dataset.py index a730661b80c03c2fe5ff63dd2101d388838a71fe..9ede6326accbf19365c098f7f1f59b6fb6ed2a8f 100644 --- a/common/research/common/datasets/image_dataset.py +++ b/common/research/common/datasets/image_dataset.py @@ -1,58 +1,45 @@ -from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Iterator, List, Tuple +from memoized_property import
memoized_property +from more_itertools import ilen from polystar.common.models.image import Image +from research.common.datasets.dataset import Dataset, LazyDataset, TargetT -TargetT = TypeVar("TargetT") +ImageDataset = Dataset[Image, TargetT] -class ImageDataset(Generic[TargetT], Iterable[Tuple[Image, TargetT]]): - def __init__(self, name: str, images: List[Image] = None, targets: List[TargetT] = None): - self.name = name - self._targets = targets - self._images = images - self._check_consistency() +class ImageFileDataset(LazyDataset[Path, TargetT], ABC): + def __iter__(self) -> Iterator[Tuple[Path, TargetT]]: + for image_file in self.image_files: + yield image_file, self.target_from_image_file(image_file) - def __iter__(self) -> Iterator[Tuple[Image, TargetT]]: - return zip(self.images, self.targets) + @abstractmethod + def target_from_image_file(self, image_file: Path) -> TargetT: + pass - def __len__(self): - if not self._is_loaded: - self._load_data() - return len(self.images) + @property + @abstractmethod + def image_files(self) -> Iterator[Path]: + pass - def __str__(self): - return f"<{self.__class__.__name__} {self.name}>" + def open(self) -> ImageDataset: + return self.transform_examples(Image.from_path) - __repr__ = __str__ + def __len__(self): + return ilen(self.image_files) - @property - def images(self) -> List[Image]: - self._load_data() - return self._images - @property - def targets(self) -> List[TargetT]: - self._load_data() - return self._targets - - def _load_data(self): - if self._is_loaded: - return - self._images, self._targets = [], [] - for image, target in self: - self._images.append(image) - self._targets.append(target) - self._check_consistency() - - def _check_consistency(self): - assert self._is_loaded or self._has_custom_load - if self._is_loaded: - assert len(self.targets) == len(self.images) +class ImageDirectoryDataset(ImageFileDataset[TargetT], ABC): + def __init__(self, images_dir: Path, name: str, extension: str = "jpg"): + 
super().__init__(name) + self.extension = extension + self.images_dir = images_dir - @property - def _is_loaded(self) -> bool: - return self._images is not None and self._targets is not None + @memoized_property + def image_files(self) -> List[Path]: + return list(sorted(self.images_dir.glob(f"*.{self.extension}"))) - @property - def _has_custom_load(self) -> bool: - return not self.__iter__.__qualname__.startswith("ImageDataset") + def __len__(self): + return len(self.image_files) diff --git a/common/research/common/datasets/roco/directory_roco_dataset.py b/common/research/common/datasets/roco/directory_roco_dataset.py index 260b5d47f6aa547cbfd63c04e0ef8be2b9ed3b72..9857bea0b90228184f7517226ee6c6e199fc066a 100644 --- a/common/research/common/datasets/roco/directory_roco_dataset.py +++ b/common/research/common/datasets/roco/directory_roco_dataset.py @@ -1,46 +1,24 @@ from pathlib import Path -from typing import Iterable, List, Tuple -from more_itertools import ilen from polystar.common.models.image import Image +from research.common.datasets.image_dataset import ImageDirectoryDataset from research.common.datasets.roco.roco_annotation import ROCOAnnotation -from research.common.datasets.roco.roco_dataset import ROCODataset -class DirectoryROCODataset(ROCODataset): +class DirectoryROCODataset(ImageDirectoryDataset[ROCOAnnotation]): def __init__(self, dataset_path: Path, name: str): - super().__init__(name) - self.dataset_path = dataset_path - - @property - def targets(self) -> List[ROCOAnnotation]: - if self._is_loaded: - return super().targets - return list(map(ROCOAnnotation.from_xml_file, self.annotation_paths)) - - @property - def images_dir_path(self) -> Path: - return self.dataset_path / "image" - - @property - def annotation_paths(self) -> Iterable[Path]: - return sorted(self.annotations_dir_path.glob("*.xml")) - - @property - def annotations_dir_path(self) -> Path: - return self.dataset_path / "image_annotation" - - def __len__(self): - return
ilen(self.annotation_paths) - - def unloaded_items(self) -> Iterable[Tuple[Path, ROCOAnnotation]]: - for annotation_file in self.annotation_paths: - yield self.images_dir_path / f"{annotation_file.stem}.jpg", ROCOAnnotation.from_xml_file(annotation_file) - - def __iter__(self) -> Iterable[Tuple[Image, ROCOAnnotation]]: - for image_path, roco_annotation in self.unloaded_items(): - yield Image.from_path(image_path), roco_annotation - - def save_one(self, image: Image, annotation: ROCOAnnotation): - Image.save(image, self.images_dir_path / f"{annotation.name}.jpg") - (self.annotations_dir_path / f"{annotation.name}.xml").write_text(annotation.to_xml()) + super().__init__(dataset_path / "image", name) + self.main_dir = dataset_path + self.annotations_dir: Path = self.main_dir / "image_annotation" + + def target_from_image_file(self, image_file: Path) -> ROCOAnnotation: + return ROCOAnnotation.from_xml_file(self.annotations_dir / f"{image_file.stem}.xml") + + def create(self): + self.main_dir.mkdir(parents=True) + self.images_dir.mkdir() + self.annotations_dir.mkdir() + + def add(self, image: Image, annotation: ROCOAnnotation): + Image.save(image, self.images_dir / f"{annotation.name}.jpg") + (self.annotations_dir / f"{annotation.name}.xml").write_text(annotation.to_xml()) diff --git a/common/research/common/datasets/roco/roco_dataset.py b/common/research/common/datasets/roco/roco_dataset.py index 57abf823f73ecfde30768410363b030b44580fc0..f2fb206c43460631eac2761e300a51f259b86ba5 100644 --- a/common/research/common/datasets/roco/roco_dataset.py +++ b/common/research/common/datasets/roco/roco_dataset.py @@ -1,4 +1,6 @@ -from research.common.datasets.image_dataset import ImageDataset +from research.common.datasets.image_dataset import (ImageDataset, + ImageFileDataset) from research.common.datasets.roco.roco_annotation import ROCOAnnotation ROCODataset = ImageDataset[ROCOAnnotation] +ROCOFileDataset =
ImageFileDataset[ROCOAnnotation] diff --git a/common/research/common/datasets/simple_dataset.py b/common/research/common/datasets/simple_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1209ae9311ca7da8c8df91c6ff745fc889b969f1 --- /dev/null +++ b/common/research/common/datasets/simple_dataset.py @@ -0,0 +1,25 @@ +from typing import Iterable, Iterator, List, Tuple + +from research.common.datasets.dataset import Dataset, ExampleT, TargetT + + +class SimpleDataset(Dataset[ExampleT, TargetT]): + def __init__(self, examples: Iterable[ExampleT], targets: Iterable[TargetT], name: str): + super().__init__(name) + self._examples = list(examples) + self._targets = list(targets) + self.check_consistency() + + @property + def examples(self) -> List[ExampleT]: + return self._examples + + @property + def targets(self) -> List[TargetT]: + return self._targets + + def __iter__(self) -> Iterator[Tuple[ExampleT, TargetT]]: + return zip(self.examples, self.targets) + + def __len__(self): + return len(self.examples) diff --git a/common/research/common/datasets/union_dataset.py b/common/research/common/datasets/union_dataset.py index 623887862ae88553bacb24aa47ef428100b65c4f..24f314d2ef35dab56bb420bf1ddac671f572f59b 100644 --- a/common/research/common/datasets/union_dataset.py +++ b/common/research/common/datasets/union_dataset.py @@ -1,11 +1,11 @@ -from itertools import chain -from typing import Iterable, Iterator, List, Tuple +from typing import Iterable, Iterator, Tuple from polystar.common.models.image import Image -from research.common.datasets.image_dataset import ImageDataset, TargetT +from research.common.datasets.dataset import ExampleT, LazyDataset, TargetT +from research.common.datasets.image_dataset import ImageDataset -class UnionDataset(ImageDataset[TargetT]): +class UnionDataset(LazyDataset[ExampleT, TargetT]): def __init__(self, datasets: Iterable[ImageDataset[TargetT]], name: str): super().__init__(name) self.datasets = list(datasets) @@ 
-14,6 +14,5 @@ class UnionDataset(ImageDataset[TargetT]): for dataset in self.datasets: yield from dataset - @property - def targets(self) -> List[TargetT]: - return list(chain(*(dataset.targets for dataset in self.datasets))) + def __len__(self): + return sum(map(len, self.datasets)) diff --git a/common/research/common/scripts/improve_roco_by_zooming.py b/common/research/common/scripts/improve_roco_by_zooming.py index 0b3d6b1c6163ff38ee59cd5201b049fd16f93b3d..d38ddf1b411107cc98405b478730f653b944027c 100644 --- a/common/research/common/scripts/improve_roco_by_zooming.py +++ b/common/research/common/scripts/improve_roco_by_zooming.py @@ -6,25 +6,22 @@ from research.common.dataset.perturbations.image_modifiers.contrast import \ from research.common.dataset.perturbations.image_modifiers.saturation import \ SaturationModifier from research.common.dataset.perturbations.perturbator import ImagePerturbator -from research.common.datasets.roco.directory_roco_dataset import \ - DirectoryROCODataset +from research.common.datasets.roco.roco_dataset import ROCOFileDataset from research.common.datasets.roco.zoo.dji_zoomed import DJIROCOZoomedDatasets from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo from tqdm import tqdm def improve_dji_roco_dataset_by_zooming_and_perturbating( - dset: DirectoryROCODataset, zoomer: Zoomer, perturbator: ImagePerturbator + dset: ROCOFileDataset, zoomer: Zoomer, perturbator: ImagePerturbator ): zoomed_dset = DJIROCOZoomedDatasets.make_dataset(dset.name) - zoomed_dset.dataset_path.mkdir(parents=True) - zoomed_dset.images_dir_path.mkdir() - zoomed_dset.annotations_dir_path.mkdir() + zoomed_dset.create() - for img, annotation in tqdm(dset, desc=f"Processing {dset}", unit="image", total=len(dset)): + for img, annotation in tqdm(dset.open(), desc=f"Processing {dset}", unit="image", total=len(dset)): for zoomed_image, zoomed_annotation in zoomer.zoom(img, annotation): zoomed_image = perturbator.perturbate(zoomed_image) - 
zoomed_dset.save_one(zoomed_image, zoomed_annotation) + zoomed_dset.add(zoomed_image, zoomed_annotation) def improve_all_dji_datasets_by_zooming_and_perturbating(zoomer: Zoomer, perturbator: ImagePerturbator): diff --git a/common/research/common/scripts/visualize_dataset.py b/common/research/common/scripts/visualize_dataset.py index 9cc123b5f85f9cc9289f46f9e8a5f14d1fea5ea0..1bdb43a71d8b39107999552d76f692257f752304 100644 --- a/common/research/common/scripts/visualize_dataset.py +++ b/common/research/common/scripts/visualize_dataset.py @@ -14,4 +14,4 @@ def visualize_dataset(dataset: ROCODataset, n_images: int): if __name__ == "__main__": - visualize_dataset(ROCODatasetsZoo.DJI_ZOOMED.CentralChina, 20) + visualize_dataset(ROCODatasetsZoo.DJI_ZOOMED.CentralChina.open(), 20) diff --git a/common/tests/common/unittests/datasets/roco/test_directory_dataset_zoo.py b/common/tests/common/unittests/datasets/roco/test_directory_dataset_zoo.py index a25ff21a19a37fa73ed4d438085293e500402e21..69ad22403d4d6acb5445a98506b42f829269da83 100644 --- a/common/tests/common/unittests/datasets/roco/test_directory_dataset_zoo.py +++ b/common/tests/common/unittests/datasets/roco/test_directory_dataset_zoo.py @@ -2,18 +2,44 @@ from pathlib import Path from tempfile import TemporaryDirectory from unittest import TestCase +from numpy import asarray, float32 +from numpy.testing import assert_array_almost_equal +from polystar.common.models.image import Image from research.common.datasets.roco.directory_roco_dataset import \ DirectoryROCODataset from research.common.datasets.roco.roco_annotation import ROCOAnnotation class TestDirectoryROCODataset(TestCase): - def test_lazy_targets(self): + def test_targets(self): with TemporaryDirectory() as dataset_dir: dataset = DirectoryROCODataset(Path(dataset_dir), "fake") - dataset.annotations_dir_path.mkdir() annotation = ROCOAnnotation("frame_1", objects=[], has_rune=False, w=160, h=90) - (dataset.annotations_dir_path / 
"frame_1.xml").write_text(annotation.to_xml()) - self.assertEqual([annotation], dataset) + dataset.annotations_dir.mkdir() + dataset.images_dir.mkdir() + (dataset.annotations_dir / "frame_1.xml").write_text(annotation.to_xml()) + (dataset.images_dir / "frame_1.jpg").write_text("") + + self.assertEqual([annotation], list(dataset.targets)) + self.assertEqual([dataset.images_dir / "frame_1.jpg"], list(dataset.examples)) + + def test_open(self): + with TemporaryDirectory() as dataset_dir: + dataset = DirectoryROCODataset(Path(dataset_dir), "fake") + + annotation = ROCOAnnotation("frame_1", objects=[], has_rune=False, w=160, h=90) + image = asarray([[[250, 0, 0], [250, 0, 0]], [[250, 0, 0], [250, 0, 0]]]).astype(float32) + + dataset.annotations_dir.mkdir() + dataset.images_dir.mkdir() + (dataset.annotations_dir / "frame_1.xml").write_text(annotation.to_xml()) + Image.save(image, dataset.images_dir / "frame_1.jpg") + + image_dataset = dataset.open() + + self.assertEqual([annotation], list(image_dataset.targets)) + images = list(image_dataset.examples) + self.assertEqual(1, len(images)) + assert_array_almost_equal(image / 256, images[0] / 256, decimal=2) # jpeg precision diff --git a/common/tests/common/unittests/datasets/test_dataset.py b/common/tests/common/unittests/datasets/test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fff885d3cf565a489fa783dfadf865173811d0b7 --- /dev/null +++ b/common/tests/common/unittests/datasets/test_dataset.py @@ -0,0 +1,85 @@ +from unittest import TestCase +from unittest.mock import MagicMock + +from research.common.datasets.dataset import Dataset, LazyDataset +from research.common.datasets.simple_dataset import SimpleDataset + + +class TestDataset(TestCase): + def test_transform(self): + dataset = _make_fake_dataset() + + str_str_dataset: Dataset[str, str] = dataset.transform(str, str) + + self.assertEqual([("0", "8"), ("1", "9"), ("2", "10"), ("3", "11")], list(str_str_dataset)) + + def 
test_transform_examples(self): + dataset = _make_fake_dataset() + + str_int_dataset: Dataset[str, int] = dataset.transform_examples(str) + + self.assertEqual([("0", 8), ("1", 9), ("2", 10), ("3", 11)], list(str_int_dataset)) + + def test_transform_not_exhaustible(self): + dataset = _make_fake_dataset() + + str_int_dataset: Dataset[str, float] = dataset.transform_examples(str) + + self.assertEqual([("0", 8), ("1", 9), ("2", 10), ("3", 11)], list(str_int_dataset)) + self.assertEqual([("0", 8), ("1", 9), ("2", 10), ("3", 11)], list(str_int_dataset)) + self.assertEqual([("0", 8), ("1", 9), ("2", 10), ("3", 11)], list(str_int_dataset)) + + +class TestSimpleDataset(TestCase): + def test_properties(self): + dataset = _make_fake_dataset() + + self.assertEqual([0, 1, 2, 3], dataset.examples) + self.assertEqual([8, 9, 10, 11], dataset.targets) + + def test_iter(self): + dataset = _make_fake_dataset() + + self.assertEqual([(0, 8), (1, 9), (2, 10), (3, 11)], list(dataset)) + + def test_len(self): + dataset = _make_fake_dataset() + + self.assertEqual(4, len(dataset)) + + def test_consistency(self): + with self.assertRaises(AssertionError): + SimpleDataset([0, 1], [8, 9, 10, 11], "fake") + + +class FakeLazyDataset(LazyDataset): + def __init__(self): + super().__init__("fake") + + __iter__ = MagicMock(side_effect=lambda *args: iter([(1, 1), (2, 4), (3, 9)])) + + +class TestLazyDataset(TestCase): + def test_properties(self): + dataset = FakeLazyDataset() + + self.assertEqual([1, 2, 3], list(dataset.examples)) + self.assertEqual([1, 4, 9], list(dataset.targets)) + self.assertEqual([(1, 1), (2, 4), (3, 9)], list(zip(dataset.examples, dataset.targets))) + + def test_properties_laziness(self): + FakeLazyDataset.__iter__.reset_mock() + dataset = FakeLazyDataset() + + list(dataset.examples) + list(dataset.targets) + FakeLazyDataset.__iter__.assert_called_once() + + FakeLazyDataset.__iter__.reset_mock() + + list(zip(dataset.examples, dataset.targets)) + 
FakeLazyDataset.__iter__.assert_called_once() + + +def _make_fake_dataset() -> Dataset[int, int]: + return SimpleDataset([0, 1, 2, 3], [8, 9, 10, 11], "fake") diff --git a/common/tests/common/unittests/datasets/test_image_dataset.py b/common/tests/common/unittests/datasets/test_image_dataset.py deleted file mode 100644 index 86ee99594ba14ccf503754bc41072a807bbc1db7..0000000000000000000000000000000000000000 --- a/common/tests/common/unittests/datasets/test_image_dataset.py +++ /dev/null @@ -1,31 +0,0 @@ -from unittest import TestCase - -from research.common.datasets.image_dataset import ImageDataset - - -class TestImageDataset(TestCase): - def test_iter(self): - dataset = ImageDataset("test", list(range(5)), list(range(3, 8))) - - self.assertEqual([(0, 3), (1, 4), (2, 5), (3, 6), (4, 7)], list(dataset)) - - def test_auto_load(self): - class FakeDataset(ImageDataset): - def __iter__(self): - return [(0, 2), (1, 4)].__iter__() - - dataset = FakeDataset("test") - - self.assertEqual([0, 1], dataset.images) - self.assertEqual([2, 4], dataset.targets) - - def test_assert(self): - with self.assertRaises(AssertionError): - ImageDataset("test") - - def test_assert_child(self): - class FakeDataset(ImageDataset): - pass - - with self.assertRaises(AssertionError): - FakeDataset("test")