From 969b746736d219f69a5607138848300bc037abfc Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Mon, 25 Jan 2021 12:23:23 -0500 Subject: [PATCH] [objects] Remove the old validator class, and use the filter class --- .../polystar/common/dependency_injection.py | 9 ++++----- .../common/pipeline/keras/classifier.py | 6 +++--- .../__init__.py | 0 .../objects_filters/armor_digit_filter.py | 12 +++++++++++ .../confidence_object_filter.py} | 8 +++----- .../objects_filters/contains_box_filter.py | 14 +++++++++++++ .../objects_filters/in_box_filter.py | 14 +++++++++++++ .../objects_filters/objects_filter_abc.py | 4 ++++ .../objects_filters/type_object_filter.py | 14 +++++++++++++ .../objects_linker/simple_objects_linker.py | 10 ++-------- .../armor_digit_validator.py | 15 -------------- .../contains_box_validator.py | 14 ------------- .../objects_validators/in_box_validator.py | 14 ------------- .../objects_validators/negation_validator.py | 13 ------------ .../objects_validator_abc.py | 20 ------------------- .../robot_color_validator.py | 17 ---------------- .../type_object_validator.py | 14 ------------- .../common/target_pipeline/target_pipeline.py | 12 +++++------ .../common/dataset/improvement/zoom.py | 4 ++-- .../common/datasets/roco/roco_annotation.py | 7 ++++++- .../roco_annotation_object_filter.py | 8 ++++---- .../robots/dataset/armor_dataset_factory.py | 9 +++------ .../research/robots/demos/demo_pipeline.py | 4 ++-- .../test_in_box_validator.py | 10 +++++----- 24 files changed, 98 insertions(+), 154 deletions(-) rename polystar_cv/polystar/common/target_pipeline/{objects_validators => objects_filters}/__init__.py (100%) create mode 100644 polystar_cv/polystar/common/target_pipeline/objects_filters/armor_digit_filter.py rename polystar_cv/polystar/common/target_pipeline/{objects_validators/confidence_object_validator.py => objects_filters/confidence_object_filter.py} (54%) create mode 100644 polystar_cv/polystar/common/target_pipeline/objects_filters/contains_box_filter.py create mode 100644 polystar_cv/polystar/common/target_pipeline/objects_filters/in_box_filter.py create mode 100644 polystar_cv/polystar/common/target_pipeline/objects_filters/objects_filter_abc.py create mode 100644 polystar_cv/polystar/common/target_pipeline/objects_filters/type_object_filter.py delete mode 100644 polystar_cv/polystar/common/target_pipeline/objects_validators/armor_digit_validator.py delete mode 100644 polystar_cv/polystar/common/target_pipeline/objects_validators/contains_box_validator.py delete mode 100644 polystar_cv/polystar/common/target_pipeline/objects_validators/in_box_validator.py delete mode 100644 polystar_cv/polystar/common/target_pipeline/objects_validators/negation_validator.py delete mode 100644 polystar_cv/polystar/common/target_pipeline/objects_validators/objects_validator_abc.py delete mode 100644 polystar_cv/polystar/common/target_pipeline/objects_validators/robot_color_validator.py delete mode 100644 polystar_cv/polystar/common/target_pipeline/objects_validators/type_object_validator.py diff --git a/polystar_cv/polystar/common/dependency_injection.py b/polystar_cv/polystar/common/dependency_injection.py index 5482b1c..097eab0 100644 --- a/polystar_cv/polystar/common/dependency_injection.py +++ b/polystar_cv/polystar/common/dependency_injection.py @@ -16,15 +16,14 @@ from polystar.common.target_pipeline.armors_descriptors.armors_color_descriptor from polystar.common.target_pipeline.armors_descriptors.armors_descriptor_abc import ArmorsDescriptorABC from polystar.common.target_pipeline.armors_descriptors.armors_digit_descriptor import ArmorsDigitDescriptor from polystar.common.target_pipeline.detected_objects.detected_objects_factory import DetectedObjectFactory -from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot from polystar.common.target_pipeline.object_selectors.closest_object_selector import ClosestObjectSelector from polystar.common.target_pipeline.object_selectors.object_selector_abc import ObjectSelectorABC from polystar.common.target_pipeline.objects_detectors.objects_detector_abc import ObjectsDetectorABC from polystar.common.target_pipeline.objects_detectors.tf_model_objects_detector import TFModelObjectsDetector +from polystar.common.target_pipeline.objects_filters.confidence_object_filter import ConfidenceObjectsFilter +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC from polystar.common.target_pipeline.objects_linker.objects_linker_abs import ObjectsLinkerABC from polystar.common.target_pipeline.objects_linker.simple_objects_linker import SimpleObjectsLinker -from polystar.common.target_pipeline.objects_validators.confidence_object_validator import ConfidenceObjectValidator -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC from polystar.common.target_pipeline.target_factories.ratio_simple_target_factory import RatioSimpleTargetFactory from polystar.common.target_pipeline.target_factories.target_factory_abc import TargetFactoryABC from polystar.common.utils.serialization import pkl_load @@ -79,8 +78,8 @@ class CommonModule(Module): @multiprovider @singleton - def provide_objects_validators(self) -> List[ObjectsValidatorABC[DetectedRobot]]: - return [ConfidenceObjectValidator(0.6)] + def provide_objects_validators(self) -> List[ObjectsFilterABC]: + return [ConfidenceObjectsFilter(0.6)] @provider @singleton diff --git a/polystar_cv/polystar/common/pipeline/keras/classifier.py b/polystar_cv/polystar/common/pipeline/keras/classifier.py index 05d55ba..f49a77b 100644 --- a/polystar_cv/polystar/common/pipeline/keras/classifier.py +++ b/polystar_cv/polystar/common/pipeline/keras/classifier.py @@ -10,7 +10,6 @@ from tensorflow.python.keras.utils.np_utils import to_categorical from polystar.common.models.image import Image from polystar.common.pipeline.classification.classifier_abc import ClassifierABC from polystar.common.pipeline.keras.trainer import KerasTrainer -from polystar.common.settings import settings from polystar.common.utils.registry import registry @@ -32,10 +31,11 @@ class KerasClassifier(ClassifierABC): return self def predict_proba(self, examples: List[Image]) -> Sequence[float]: - if settings.is_prod: # FIXME + try: # FIXME with self.graph.as_default(), self.session.as_default(): return self.model.predict(asarray(examples)) - return self.model.predict(asarray(examples)) + except AttributeError: + return self.model.predict(asarray(examples)) def __getstate__(self) -> Dict: with NamedTemporaryFile(suffix=".hdf5", delete=True) as fd: diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/__init__.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/__init__.py similarity index 100% rename from polystar_cv/polystar/common/target_pipeline/objects_validators/__init__.py rename to polystar_cv/polystar/common/target_pipeline/objects_filters/__init__.py diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/armor_digit_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/armor_digit_filter.py new file mode 100644 index 0000000..4feea75 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/armor_digit_filter.py @@ -0,0 +1,12 @@ +from typing import Iterable + +from polystar.common.filters.filter_abc import FilterABC +from polystar.common.models.object import Armor, Object + + +class KeepArmorsDigitFilter(FilterABC[Object]): + def __init__(self, digits: Iterable[int]): + self.digits = digits + + def validate_single(self, obj: Object) -> bool: + return isinstance(obj, Armor) and obj.number in self.digits diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/confidence_object_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/confidence_object_filter.py similarity index 54% rename from polystar_cv/polystar/common/target_pipeline/objects_validators/confidence_object_validator.py rename to polystar_cv/polystar/common/target_pipeline/objects_filters/confidence_object_filter.py index b17b316..285e19a 100644 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/confidence_object_validator.py +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/confidence_object_filter.py @@ -1,14 +1,12 @@ -import numpy as np - +from polystar.common.filters.filter_abc import FilterABC from polystar.common.target_pipeline.detected_objects.detected_object import DetectedObject -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC -class ConfidenceObjectValidator(ObjectsValidatorABC[DetectedObject]): +class ConfidenceObjectsFilter(FilterABC[DetectedObject]): """Keep only objects for which we are confident enough.""" def __init__(self, confidence_threshold: float): self.confidence_threshold = confidence_threshold - def validate_single(self, obj: DetectedObject, image: np.ndarray) -> bool: + def validate_single(self, obj: DetectedObject) -> bool: return obj.confidence >= self.confidence_threshold diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/contains_box_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/contains_box_filter.py new file mode 100644 index 0000000..064bf7e --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/contains_box_filter.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from polystar.common.models.box import Box +from polystar.common.models.object import Object +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC + + +@dataclass +class ContainsBoxObjectsFilter(ObjectsFilterABC): + box: Box + min_percentage_intersection: float + + def validate_single(self, obj: Object) -> bool: + return obj.box.contains(self.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/in_box_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/in_box_filter.py new file mode 100644 index 0000000..abf3ab3 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/in_box_filter.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from polystar.common.models.box import Box +from polystar.common.models.object import Object +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC + + +@dataclass +class InBoxObjectFilter(ObjectsFilterABC): + box: Box + min_percentage_intersection: float + + def validate_single(self, obj: Object) -> bool: + return self.box.contains(obj.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/objects_filter_abc.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/objects_filter_abc.py new file mode 100644 index 0000000..b55acac --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/objects_filter_abc.py @@ -0,0 +1,4 @@ +from polystar.common.filters.filter_abc import FilterABC +from polystar.common.models.object import Object + +ObjectsFilterABC = FilterABC[Object] diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/type_object_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/type_object_filter.py new file mode 100644 index 0000000..709c290 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/type_object_filter.py @@ -0,0 +1,14 @@ +from typing import Iterable + +from polystar.common.models.object import Object, ObjectType +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC + + +class TypeObjectsFilter(ObjectsFilterABC): + """Keep only the objects of a desired type""" + + def __init__(self, desired_types: Iterable[ObjectType]): + self.desired_types = set(desired_types) + + def validate_single(self, obj: Object) -> bool: + return obj.type in self.desired_types diff --git a/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py b/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py index 86dd41d..93bf08f 100644 --- a/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py +++ b/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py @@ -1,29 +1,23 @@ from typing import Iterable, List from polystar.common.models.image import Image -from polystar.common.models.object import ObjectType from polystar.common.target_pipeline.detected_objects.detected_armor import DetectedArmor from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot, FakeDetectedRobot +from polystar.common.target_pipeline.objects_filters.contains_box_filter import ContainsBoxObjectsFilter from polystar.common.target_pipeline.objects_linker.objects_linker_abs import ObjectsLinkerABC -from polystar.common.target_pipeline.objects_validators.contains_box_validator import ContainsBoxValidator -from polystar.common.target_pipeline.objects_validators.negation_validator import NegationValidator -from polystar.common.target_pipeline.objects_validators.type_object_validator import TypeObjectValidator class SimpleObjectsLinker(ObjectsLinkerABC): def __init__(self, min_percentage_intersection: float): super().__init__() self.min_percentage_intersection = min_percentage_intersection - self.robots_filter = NegationValidator(TypeObjectValidator(ObjectType.ARMOR)) - self.armors_filter = TypeObjectValidator(ObjectType.ARMOR) def link_armors_to_robots( self, robots: List[DetectedRobot], armors: List[DetectedArmor], image: Image ) -> Iterable[DetectedRobot]: for armor in armors: - parents_filter = ContainsBoxValidator[DetectedRobot](armor.box, self.min_percentage_intersection) - parents = parents_filter.filter(robots, image) + parents = ContainsBoxObjectsFilter(armor.box, self.min_percentage_intersection).filter(robots) if len(parents) != 1: yield FakeDetectedRobot(armor) else: diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/armor_digit_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/armor_digit_validator.py deleted file mode 100644 index 060d78c..0000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/armor_digit_validator.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Iterable - -from numpy.core.multiarray import ndarray - -from polystar.common.models.object import Armor -from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC - - -class ArmorDigitValidator(ObjectsValidatorABC[DetectedRobot]): - def __init__(self, digits: Iterable[int]): - self.digits = digits - - def validate_single(self, armor: Armor, image: ndarray) -> bool: - return isinstance(armor, Armor) and armor.number in self.digits diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/contains_box_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/contains_box_validator.py deleted file mode 100644 index 77c2a53..0000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/contains_box_validator.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from polystar.common.models.box import Box -from polystar.common.models.image import Image -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC, ObjectT - - -@dataclass -class ContainsBoxValidator(ObjectsValidatorABC[ObjectT]): - box: Box - min_percentage_intersection: float - - def validate_single(self, obj: ObjectT, image: Image) -> bool: - return obj.box.contains(self.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/in_box_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/in_box_validator.py deleted file mode 100644 index 800cbd8..0000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/in_box_validator.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from polystar.common.models.box import Box -from polystar.common.models.image import Image -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC, ObjectT - - -@dataclass -class InBoxValidator(ObjectsValidatorABC[ObjectT]): - box: Box - min_percentage_intersection: float - - def validate_single(self, obj: ObjectT, image: Image) -> bool: - return self.box.contains(obj.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/negation_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/negation_validator.py deleted file mode 100644 index b660bfa..0000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/negation_validator.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - -import numpy as np - -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC, ObjectT - - -@dataclass -class NegationValidator(ObjectsValidatorABC): - validator: ObjectsValidatorABC - - def validate_single(self, obj: ObjectT, image: np.ndarray) -> bool: - return not self.validator.validate_single(obj, image) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/objects_validator_abc.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/objects_validator_abc.py deleted file mode 100644 index 075ce15..0000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/objects_validator_abc.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Generic, List, TypeVar - -import numpy as np - -from polystar.common.models.object import Object - -ObjectT = TypeVar("ObjectT", bound=Object) - - -class ObjectsValidatorABC(Generic[ObjectT], ABC): # FIXME Filter would do here - def filter(self, objects: List[ObjectT], image: np.ndarray) -> List[ObjectT]: - return [obj for obj, is_valid in zip(objects, self.validate(objects, image)) if is_valid] - - def validate(self, objects: List[ObjectT], image: np.ndarray) -> List[bool]: - return [self.validate_single(obj, image) for obj in objects] - - @abstractmethod - def validate_single(self, obj: ObjectT, image: np.ndarray) -> bool: - pass diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/robot_color_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/robot_color_validator.py deleted file mode 100644 index b50c2ed..0000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/robot_color_validator.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass - -import numpy as np - -from polystar.common.models.object import ArmorColor -from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC - - -@dataclass -class RobotPercentageColorValidator(ObjectsValidatorABC[DetectedRobot]): - color: ArmorColor - min_percentage: 0.5 - - def validate_single(self, robot: DetectedRobot, image: np.ndarray) -> bool: - good_colors = [armor.color is self.color for armor in robot.armors] - return sum(good_colors) >= len(good_colors) * self.min_percentage diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/type_object_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/type_object_validator.py deleted file mode 100644 index eb001d0..0000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/type_object_validator.py +++ /dev/null @@ -1,14 +0,0 @@ -import numpy as np - -from polystar.common.models.object import ObjectType, Object -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC - - -class TypeObjectValidator(ObjectsValidatorABC[Object]): - """Keep only the objects of a desired type""" - - def __init__(self, *desired_types: ObjectType): - self.desired_types = set(desired_types) - - def validate_single(self, obj: Object, image: np.ndarray) -> bool: - return obj.type in self.desired_types diff --git a/polystar_cv/polystar/common/target_pipeline/target_pipeline.py b/polystar_cv/polystar/common/target_pipeline/target_pipeline.py index 9b51450..1dd3bdc 100644 --- a/polystar_cv/polystar/common/target_pipeline/target_pipeline.py +++ b/polystar_cv/polystar/common/target_pipeline/target_pipeline.py @@ -9,8 +9,8 @@ from polystar.common.target_pipeline.detected_objects.detected_object import Det from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot from polystar.common.target_pipeline.object_selectors.object_selector_abc import ObjectSelectorABC from polystar.common.target_pipeline.objects_detectors.objects_detector_abc import ObjectsDetectorABC +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC from polystar.common.target_pipeline.objects_linker.objects_linker_abs import ObjectsLinkerABC -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC from polystar.common.target_pipeline.target_abc import TargetABC from polystar.common.target_pipeline.target_factories.target_factory_abc import TargetFactoryABC @@ -25,7 +25,7 @@ class TargetPipeline: objects_detector: ObjectsDetectorABC objects_linker: ObjectsLinkerABC - objects_validators: List[ObjectsValidatorABC[DetectedRobot]] + objects_filters: List[ObjectsFilterABC] object_selector: ObjectSelectorABC target_factory: TargetFactoryABC target_sender: TargetSenderABC @@ -43,16 +43,16 @@ class TargetPipeline: def _get_robots_of_interest(self, image: Image) -> List[DetectedRobot]: robots = self._detect_robots(image) - robots = self._filter_robots(image, robots) + robots = self._filter_robots(robots) if not any(robot.armors for robot in robots): raise NoTargetFoundException() return robots - def _filter_robots(self, image: Image, robots: List[DetectedRobot]) -> List[DetectedRobot]: - for robots_validator in self.objects_validators: - robots = robots_validator.filter(robots, image) + def _filter_robots(self, robots: List[DetectedRobot]) -> List[DetectedRobot]: + for robots_validator in self.objects_filters: + robots = robots_validator.filter(robots) return robots def _detect_robots(self, image: Image) -> List[DetectedRobot]: diff --git a/polystar_cv/research/common/dataset/improvement/zoom.py b/polystar_cv/research/common/dataset/improvement/zoom.py index f233890..756dc01 100644 --- a/polystar_cv/research/common/dataset/improvement/zoom.py +++ b/polystar_cv/research/common/dataset/improvement/zoom.py @@ -5,7 +5,7 @@ from typing import Iterable, List, Tuple from polystar.common.models.box import Box from polystar.common.models.image import Image -from polystar.common.target_pipeline.objects_validators.in_box_validator import InBoxValidator +from polystar.common.target_pipeline.objects_filters.in_box_filter import InBoxObjectFilter from polystar.common.view.plt_results_viewer import PltResultViewer from research.common.datasets.roco.roco_annotation import ROCOAnnotation from research.common.datasets.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo @@ -14,7 +14,7 @@ from research.common.datasets.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo def crop_image_annotation( image: Image, annotation: ROCOAnnotation, box: Box, min_coverage: float, name: str ) -> Tuple[Image, ROCOAnnotation, str]: - objects = InBoxValidator(box, min_coverage).filter(annotation.objects, image) + objects = InBoxObjectFilter(box, min_coverage).filter(annotation.objects) objects = [copy(o) for o in objects] for obj in objects: obj.box = Box.from_positions( diff --git a/polystar_cv/research/common/datasets/roco/roco_annotation.py b/polystar_cv/research/common/datasets/roco/roco_annotation.py index 0317bdd..e57b065 100644 --- a/polystar_cv/research/common/datasets/roco/roco_annotation.py +++ b/polystar_cv/research/common/datasets/roco/roco_annotation.py @@ -6,7 +6,8 @@ from xml.dom.minidom import parseString import xmltodict from dicttoxml import dicttoxml -from polystar.common.models.object import Object, ObjectFactory + +from polystar.common.models.object import Armor, Object, ObjectFactory @dataclass @@ -18,6 +19,10 @@ class ROCOAnnotation: w: int h: int + @property + def armors(self) -> List[Armor]: + return [obj for obj in self.objects if isinstance(obj, Armor)] + @staticmethod def from_xml_file(xml_file: Path) -> "ROCOAnnotation": try: diff --git a/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py b/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py index 8b9d61d..7495b4a 100644 --- a/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py +++ b/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py @@ -1,11 +1,11 @@ from polystar.common.filters.filter_abc import FilterABC -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC from research.common.datasets.roco.roco_annotation import ROCOAnnotation class ROCOAnnotationObjectFilter(FilterABC): - def __init__(self, object_validator: ObjectsValidatorABC): - self.object_validator = object_validator + def __init__(self, object_filter: ObjectsFilterABC): + self.object_filter = object_filter def validate_single(self, annotation: ROCOAnnotation) -> bool: - return any(self.object_validator.validate(annotation.objects, None)) + return any(self.object_filter.validate(annotation.objects)) diff --git a/polystar_cv/research/robots/dataset/armor_dataset_factory.py b/polystar_cv/research/robots/dataset/armor_dataset_factory.py index c14eca1..a429e5a 100644 --- a/polystar_cv/research/robots/dataset/armor_dataset_factory.py +++ b/polystar_cv/research/robots/dataset/armor_dataset_factory.py @@ -1,11 +1,10 @@ from itertools import islice -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import matplotlib.pyplot as plt from polystar.common.models.image import Image -from polystar.common.models.object import Armor, ObjectType -from polystar.common.target_pipeline.objects_validators.type_object_validator import TypeObjectValidator +from polystar.common.models.object import Armor from research.common.datasets.lazy_dataset import LazyDataset from research.common.datasets.roco.roco_annotation import ROCOAnnotation from research.common.datasets.roco.roco_dataset import LazyROCODataset @@ -23,9 +22,7 @@ class ArmorDataset(LazyDataset[Image, Armor]): @staticmethod def _generate_from_single(image: Image, annotation: ROCOAnnotation, name) -> Iterator[Tuple[Image, Armor, str]]: - armors: List[Armor] = TypeObjectValidator(ObjectType.ARMOR).filter(annotation.objects, image) - - for i, obj in enumerate(armors): + for i, obj in enumerate(annotation.armors): croped_img = image[obj.box.y1 : obj.box.y2, obj.box.x1 : obj.box.x2] yield croped_img, obj, f"{name}-{i}" diff --git a/polystar_cv/research/robots/demos/demo_pipeline.py b/polystar_cv/research/robots/demos/demo_pipeline.py index 44fbbbe..d7678f8 100644 --- a/polystar_cv/research/robots/demos/demo_pipeline.py +++ b/polystar_cv/research/robots/demos/demo_pipeline.py @@ -2,7 +2,7 @@ from injector import inject from polystar.common.dependency_injection import make_injector from polystar.common.target_pipeline.debug_pipeline import DebugTargetPipeline -from polystar.common.target_pipeline.objects_validators.armor_digit_validator import ArmorDigitValidator +from polystar.common.target_pipeline.objects_filters.armor_digit_filter import KeepArmorsDigitFilter from polystar.common.target_pipeline.target_pipeline import NoTargetFoundException from polystar.common.view.plt_results_viewer import PltResultViewer from research.common.datasets.roco.roco_annotation_filters.roco_annotation_object_filter import ( @@ -17,7 +17,7 @@ def demo_pipeline_on_images(pipeline: DebugTargetPipeline): for builder in ROCODatasetsZoo.DEFAULT_TEST_DATASETS: for image in ( builder.to_images() - .filter_targets(ROCOAnnotationObjectFilter(ArmorDigitValidator((1, 3, 4)))) + .filter_targets(ROCOAnnotationObjectFilter(KeepArmorsDigitFilter((1, 3, 4)))) .shuffle() .cap(15) .build_examples() diff --git a/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py b/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py index f7aaa52..5476955 100644 --- a/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py +++ b/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py @@ -1,13 +1,13 @@ import unittest from polystar.common.models.box import Box -from polystar.common.models.object import Object -from polystar.common.target_pipeline.objects_validators.in_box_validator import InBoxValidator +from polystar.common.models.object import Object, ObjectType +from polystar.common.target_pipeline.objects_filters.in_box_filter import InBoxObjectFilter -class TestInBoxValidator(unittest.TestCase): +class TestInBoxObjectFilter(unittest.TestCase): def setUp(self) -> None: - self.in_box_validator = InBoxValidator(Box.from_size(2, 2, 6, 4), 0.5) + self.in_box_validator = InBoxObjectFilter(Box.from_size(2, 2, 6, 4), 0.5) def test_fully_inside(self): self._test_obj(3, 3, 2, 2, True) @@ -44,5 +44,5 @@ class TestInBoxValidator(unittest.TestCase): def _test_obj(self, x: int, y: int, w: int, h: int, is_inside: bool): self.assertEqual( - is_inside, self.in_box_validator.validate_single(Object(None, Box.from_size(x, y, w, h)), None) + is_inside, self.in_box_validator.validate_single(Object(ObjectType.CAR, Box.from_size(x, y, w, h))) ) -- GitLab