diff --git a/common/research_common/dataset/split.py b/common/research_common/dataset/split.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e69bb22f255631d695a094e8616fcaadcc5e9e --- /dev/null +++ b/common/research_common/dataset/split.py @@ -0,0 +1,14 @@ +from enum import Enum +from pathlib import Path + +from research_common.dataset.dataset import Dataset + + +class Split(Enum): + Val = "val" + Train = "train" + Test = "test" + TrainVal = "trainval" + + def get_split_file(self, dataset: Dataset) -> Path: + return (dataset.dataset_path / self.value).with_suffix(".txt") diff --git a/common/research_common/dataset/split_dataset.py b/common/research_common/dataset/split_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d4cb2771bbde51d8903225e89dce8f214c29a7d2 --- /dev/null +++ b/common/research_common/dataset/split_dataset.py @@ -0,0 +1,33 @@ +from pathlib import Path +from typing import Iterable, List + +from research_common.dataset.dataset import Dataset +from research_common.dataset.roco.roco_datasets import ROCODataset +from research_common.dataset.split import Split + + +class SplitDataset(Dataset): + def __init__(self, root_dataset: Dataset, split: Split): + super().__init__(root_dataset.dataset_path, f"{root_dataset.dataset_name}_{split.name}") + self._load_file_names(split) + + def _load_file_names(self, split: Split): + self._file_names = split.get_split_file(self).read_text().strip().split("\n") + + @property + def image_paths(self) -> Iterable[Path]: + return self._generate_file_paths(self.images_dir_path, ".jpg") + + @property + def annotation_paths(self) -> Iterable[Path]: + return self._generate_file_paths(self.annotations_dir_path, ".xml") + + def _generate_file_paths(self, dir_path: Path, suffix: str) -> List[Path]: + return [(dir_path / file_name).with_suffix(suffix) for file_name in self._file_names] + + +if __name__ == "__main__": + print( + set(SplitDataset(ROCODataset.CentralChina, Split.Train).annotation_paths) + & set(SplitDataset(ROCODataset.CentralChina, Split.Val).annotation_paths) + )