Skip to content
Snippets Groups Projects
Commit c9a41b71 authored by Mathieu Beligon's avatar Mathieu Beligon
Browse files

[robots@robots] (rename) ArmorDigitDatasetGenerator -> ArmorDigitDatasetCache

parent 3efe77c4
No related branches found
No related tags found
No related merge requests found
......@@ -11,7 +11,7 @@ from polystar.common.models.image import Image, load_images
from research.common.datasets.roco.directory_roco_dataset import \
DirectoryROCODataset
from research.robots_at_robots.dataset.armor_value_dataset import \
ArmorValueDatasetGenerator
ArmorValueDatasetCache
from sklearn.metrics import classification_report, confusion_matrix
......@@ -50,16 +50,16 @@ class ImagePipelineEvaluator:
self,
train_roco_datasets: List[DirectoryROCODataset],
test_roco_datasets: List[DirectoryROCODataset],
image_dataset_generator: ArmorValueDatasetGenerator,
image_dataset_cache: ArmorValueDatasetCache,
):
logging.info("Loading data")
self.train_roco_datasets = train_roco_datasets
self.test_roco_datasets = test_roco_datasets
(self.train_images_paths, self.train_images, self.train_labels, self.train_dataset_sizes) = load_datasets(
train_roco_datasets, image_dataset_generator
train_roco_datasets, image_dataset_cache
)
(self.test_images_paths, self.test_images, self.test_labels, self.test_dataset_sizes) = load_datasets(
test_roco_datasets, image_dataset_generator
test_roco_datasets, image_dataset_cache
)
def evaluate_pipelines(self, pipelines: Iterable[ImagePipeline]) -> Dict[str, ClassificationResults]:
......@@ -86,9 +86,9 @@ class ImagePipelineEvaluator:
def load_datasets(
roco_datasets: List[DirectoryROCODataset], image_dataset_generator: ArmorValueDatasetGenerator,
roco_datasets: List[DirectoryROCODataset], image_dataset_cache: ArmorValueDatasetCache,
) -> Tuple[List[Path], List[Image], List[Any], List[int]]:
dataset = image_dataset_generator.from_roco_datasets(roco_datasets)
dataset = image_dataset_cache.from_roco_datasets(roco_datasets)
dataset_sizes = [len(d) for d in dataset.datasets]
paths, targets = list(dataset.examples), list(dataset.targets)
......
......@@ -7,7 +7,7 @@ from research.common.datasets.dataset import Dataset
from research.common.datasets.image_dataset import open_file_dataset
from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
from research.robots_at_robots.dataset.armor_value_dataset import (
ArmorValueDatasetGenerator, ArmorValueDirectoryDataset)
ArmorValueDatasetCache, ArmorValueDirectoryDataset)
class ArmorColorDirectoryDataset(ArmorValueDirectoryDataset[str]):
......@@ -16,7 +16,7 @@ class ArmorColorDirectoryDataset(ArmorValueDirectoryDataset[str]):
return label
class ArmorColorDatasetGenerator(ArmorValueDatasetGenerator[str]):
class ArmorColorDatasetCache(ArmorValueDatasetCache[str]):
def __init__(self):
super().__init__("colors")
......@@ -28,7 +28,7 @@ class ArmorColorDatasetGenerator(ArmorValueDatasetGenerator[str]):
if __name__ == "__main__":
_dataset = open_file_dataset(ArmorColorDatasetGenerator().from_roco_dataset(ROCODatasetsZoo.TWITCH.T470150052))
_dataset = open_file_dataset(ArmorColorDatasetCache().from_roco_dataset(ROCODatasetsZoo.TWITCH.T470150052))
for _image, _value, _name in islice(_dataset, 40, 50):
print(_value)
......
......@@ -6,7 +6,7 @@ from research.common.image_pipeline_evaluation.image_pipeline_evaluation_reporte
from research.common.image_pipeline_evaluation.image_pipeline_evaluator import \
ImagePipelineEvaluator
from research.robots_at_robots.armor_color.armor_color_dataset import \
ArmorColorDatasetGenerator
ArmorColorDatasetCache
class ArmorColorPipelineReporterFactory:
......@@ -18,7 +18,7 @@ class ArmorColorPipelineReporterFactory:
evaluator=ImagePipelineEvaluator(
train_roco_datasets=train_roco_datasets,
test_roco_datasets=test_roco_datasets,
image_dataset_generator=ArmorColorDatasetGenerator(),
image_dataset_cache=ArmorColorDatasetCache(),
),
evaluation_project="armor-color",
)
......@@ -10,7 +10,7 @@ from research.common.datasets.filtered_dataset import FilteredTargetsDataset
from research.common.datasets.image_dataset import open_file_dataset
from research.common.datasets.roco.zoo.roco_datasets_zoo import ROCODatasetsZoo
from research.robots_at_robots.dataset.armor_value_dataset import (
ArmorValueDatasetGenerator, ArmorValueDirectoryDataset)
ArmorValueDatasetCache, ArmorValueDirectoryDataset)
class ArmorDigitDirectoryDataset(ArmorValueDirectoryDataset[int]):
......@@ -19,7 +19,7 @@ class ArmorDigitDirectoryDataset(ArmorValueDirectoryDataset[int]):
return int(label)
class ArmorDigitDatasetGenerator(ArmorValueDatasetGenerator[str]):
class ArmorDigitDatasetCache(ArmorValueDatasetCache[str]):
def __init__(self, acceptable_digits: Iterable[int]):
super().__init__("digits")
self.acceptable_digits = acceptable_digits
......@@ -34,7 +34,7 @@ class ArmorDigitDatasetGenerator(ArmorValueDatasetGenerator[str]):
if __name__ == "__main__":
_dataset = open_file_dataset(
ArmorDigitDatasetGenerator((1, 2, 3, 4, 5, 7)).from_roco_dataset(ROCODatasetsZoo.TWITCH.T470150052)
ArmorDigitDatasetCache((1, 2, 3, 4, 5, 7)).from_roco_dataset(ROCODatasetsZoo.TWITCH.T470150052)
)
for _image, _value, _name in islice(_dataset, 40, 50):
......
......@@ -7,7 +7,7 @@ from research.common.image_pipeline_evaluation.image_pipeline_evaluation_reporte
from research.common.image_pipeline_evaluation.image_pipeline_evaluator import \
ImagePipelineEvaluator
from research.robots_at_robots.armor_digit.armor_digit_dataset import \
ArmorDigitDatasetGenerator
ArmorDigitDatasetCache
class ArmorDigitPipelineReporterFactory:
......@@ -21,7 +21,7 @@ class ArmorDigitPipelineReporterFactory:
evaluator=ImagePipelineEvaluator(
train_roco_datasets=train_roco_datasets,
test_roco_datasets=test_roco_datasets,
image_dataset_generator=ArmorDigitDatasetGenerator(acceptable_digits),
image_dataset_cache=ArmorDigitDatasetCache(acceptable_digits),
),
evaluation_project="armor-digit",
)
......@@ -11,11 +11,9 @@ from polystar.common.utils.time import create_time_id
from polystar.common.utils.tqdm import smart_tqdm
from research.common.datasets.dataset import Dataset
from research.common.datasets.image_dataset import ImageDirectoryDataset
from research.common.datasets.roco.directory_roco_dataset import \
DirectoryROCODataset
from research.common.datasets.roco.directory_roco_dataset import DirectoryROCODataset
from research.common.datasets.union_dataset import UnionDataset
from research.robots_at_robots.dataset.armor_dataset_factory import \
ArmorDatasetFactory
from research.robots_at_robots.dataset.armor_dataset_factory import ArmorDatasetFactory
ValueT = TypeVar("ValueT")
......@@ -35,7 +33,7 @@ class WrongVersionException(Exception):
expected: str
class ArmorValueDatasetGenerator(Generic[ValueT], ABC):
class ArmorValueDatasetCache(Generic[ValueT], ABC):
VERSION: ClassVar[str] = "2.0"
def __init__(self, task_name: str):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment