diff --git a/polystar_cv/polystar/common/dependency_injection.py b/polystar_cv/polystar/common/dependency_injection.py index 5482b1ce49695a0612cea7fa0a6b41d71af0aca7..097eab071296cfef7f377eac12aaa0ae82fb3d16 100644 --- a/polystar_cv/polystar/common/dependency_injection.py +++ b/polystar_cv/polystar/common/dependency_injection.py @@ -16,15 +16,14 @@ from polystar.common.target_pipeline.armors_descriptors.armors_color_descriptor from polystar.common.target_pipeline.armors_descriptors.armors_descriptor_abc import ArmorsDescriptorABC from polystar.common.target_pipeline.armors_descriptors.armors_digit_descriptor import ArmorsDigitDescriptor from polystar.common.target_pipeline.detected_objects.detected_objects_factory import DetectedObjectFactory -from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot from polystar.common.target_pipeline.object_selectors.closest_object_selector import ClosestObjectSelector from polystar.common.target_pipeline.object_selectors.object_selector_abc import ObjectSelectorABC from polystar.common.target_pipeline.objects_detectors.objects_detector_abc import ObjectsDetectorABC from polystar.common.target_pipeline.objects_detectors.tf_model_objects_detector import TFModelObjectsDetector +from polystar.common.target_pipeline.objects_filters.confidence_object_filter import ConfidenceObjectsFilter +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC from polystar.common.target_pipeline.objects_linker.objects_linker_abs import ObjectsLinkerABC from polystar.common.target_pipeline.objects_linker.simple_objects_linker import SimpleObjectsLinker -from polystar.common.target_pipeline.objects_validators.confidence_object_validator import ConfidenceObjectValidator -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC from polystar.common.target_pipeline.target_factories.ratio_simple_target_factory import RatioSimpleTargetFactory from polystar.common.target_pipeline.target_factories.target_factory_abc import TargetFactoryABC from polystar.common.utils.serialization import pkl_load @@ -79,8 +78,8 @@ class CommonModule(Module): @multiprovider @singleton - def provide_objects_validators(self) -> List[ObjectsValidatorABC[DetectedRobot]]: - return [ConfidenceObjectValidator(0.6)] + def provide_objects_validators(self) -> List[ObjectsFilterABC]: + return [ConfidenceObjectsFilter(0.6)] @provider @singleton diff --git a/polystar_cv/polystar/common/pipeline/keras/classifier.py b/polystar_cv/polystar/common/pipeline/keras/classifier.py index 05d55ba7d464d4babbba1151fe55e64225bfc28d..f49a77b27315890daa0121eb53a4488ca932782e 100644 --- a/polystar_cv/polystar/common/pipeline/keras/classifier.py +++ b/polystar_cv/polystar/common/pipeline/keras/classifier.py @@ -10,7 +10,6 @@ from tensorflow.python.keras.utils.np_utils import to_categorical from polystar.common.models.image import Image from polystar.common.pipeline.classification.classifier_abc import ClassifierABC from polystar.common.pipeline.keras.trainer import KerasTrainer -from polystar.common.settings import settings from polystar.common.utils.registry import registry @@ -32,10 +31,11 @@ class KerasClassifier(ClassifierABC): return self def predict_proba(self, examples: List[Image]) -> Sequence[float]: - if settings.is_prod: # FIXME + try: # FIXME with self.graph.as_default(), self.session.as_default(): return self.model.predict(asarray(examples)) - return self.model.predict(asarray(examples)) + except AttributeError: + return self.model.predict(asarray(examples)) def __getstate__(self) -> Dict: with NamedTemporaryFile(suffix=".hdf5", delete=True) as fd: diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/__init__.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/__init__.py similarity index 100% rename from polystar_cv/polystar/common/target_pipeline/objects_validators/__init__.py rename to polystar_cv/polystar/common/target_pipeline/objects_filters/__init__.py diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/armor_digit_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/armor_digit_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..4feea7592f9100a2e70ccdc14696f591415ac4f9 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/armor_digit_filter.py @@ -0,0 +1,12 @@ +from typing import Iterable + +from polystar.common.filters.filter_abc import FilterABC +from polystar.common.models.object import Armor, Object + + +class KeepArmorsDigitFilter(FilterABC[Object]): + def __init__(self, digits: Iterable[int]): + self.digits = digits + + def validate_single(self, obj: Object) -> bool: + return isinstance(obj, Armor) and obj.number in self.digits diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/confidence_object_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/confidence_object_filter.py similarity index 54% rename from polystar_cv/polystar/common/target_pipeline/objects_validators/confidence_object_validator.py rename to polystar_cv/polystar/common/target_pipeline/objects_filters/confidence_object_filter.py index b17b31671680303f2370fa9f94381fe550590b9e..285e19a55c984ef4e0e026ff46d068070adaa7ee 100644 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/confidence_object_validator.py +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/confidence_object_filter.py @@ -1,14 +1,12 @@ -import numpy as np - +from polystar.common.filters.filter_abc import FilterABC from polystar.common.target_pipeline.detected_objects.detected_object import DetectedObject -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC -class ConfidenceObjectValidator(ObjectsValidatorABC[DetectedObject]): +class ConfidenceObjectsFilter(FilterABC[DetectedObject]): """Keep only objects for which we are confident enough.""" def __init__(self, confidence_threshold: float): self.confidence_threshold = confidence_threshold - def validate_single(self, obj: DetectedObject, image: np.ndarray) -> bool: + def validate_single(self, obj: DetectedObject) -> bool: return obj.confidence >= self.confidence_threshold diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/contains_box_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/contains_box_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..064bf7eb5b41b208866af4828c0d80b162abc1f6 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/contains_box_filter.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from polystar.common.models.box import Box +from polystar.common.models.object import Object +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC + + +@dataclass +class ContainsBoxObjectsFilter(ObjectsFilterABC): + box: Box + min_percentage_intersection: float + + def validate_single(self, obj: Object) -> bool: + return obj.box.contains(self.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/in_box_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/in_box_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..abf3ab38aadacf6a828a7f84114586e73513a926 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/in_box_filter.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from polystar.common.models.box import Box +from polystar.common.models.object import Object +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC + + +@dataclass +class InBoxObjectFilter(ObjectsFilterABC): + box: Box + min_percentage_intersection: float + + def validate_single(self, obj: Object) -> bool: + return self.box.contains(obj.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/objects_filter_abc.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/objects_filter_abc.py new file mode 100644 index 0000000000000000000000000000000000000000..b55acac424898a11d8c4821f8d81afce845d86f6 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/objects_filter_abc.py @@ -0,0 +1,4 @@ +from polystar.common.filters.filter_abc import FilterABC +from polystar.common.models.object import Object + +ObjectsFilterABC = FilterABC[Object] diff --git a/polystar_cv/polystar/common/target_pipeline/objects_filters/type_object_filter.py b/polystar_cv/polystar/common/target_pipeline/objects_filters/type_object_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..709c290f22dd2a3f5251c250e5ac28cdc820ebf8 --- /dev/null +++ b/polystar_cv/polystar/common/target_pipeline/objects_filters/type_object_filter.py @@ -0,0 +1,14 @@ +from typing import Iterable + +from polystar.common.models.object import Object, ObjectType +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC + + +class TypeObjectsFilter(ObjectsFilterABC): + """Keep only the objects of a desired type""" + + def __init__(self, desired_types: Iterable[ObjectType]): + self.desired_types = set(desired_types) + + def validate_single(self, obj: Object) -> bool: + return obj.type in self.desired_types diff --git a/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py b/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py index 86dd41d700442947e2e12d3c070bd5ac664f1d54..93bf08f519ac97378197e34af2e541db89fd7e8b 100644 --- a/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py +++ b/polystar_cv/polystar/common/target_pipeline/objects_linker/simple_objects_linker.py @@ -1,29 +1,23 @@ from typing import Iterable, List from polystar.common.models.image import Image -from polystar.common.models.object import ObjectType from polystar.common.target_pipeline.detected_objects.detected_armor import DetectedArmor from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot, FakeDetectedRobot +from polystar.common.target_pipeline.objects_filters.contains_box_filter import ContainsBoxObjectsFilter from polystar.common.target_pipeline.objects_linker.objects_linker_abs import ObjectsLinkerABC -from polystar.common.target_pipeline.objects_validators.contains_box_validator import ContainsBoxValidator -from polystar.common.target_pipeline.objects_validators.negation_validator import NegationValidator -from polystar.common.target_pipeline.objects_validators.type_object_validator import TypeObjectValidator class SimpleObjectsLinker(ObjectsLinkerABC): def __init__(self, min_percentage_intersection: float): super().__init__() self.min_percentage_intersection = min_percentage_intersection - self.robots_filter = NegationValidator(TypeObjectValidator(ObjectType.ARMOR)) - self.armors_filter = TypeObjectValidator(ObjectType.ARMOR) def link_armors_to_robots( self, robots: List[DetectedRobot], armors: List[DetectedArmor], image: Image ) -> Iterable[DetectedRobot]: for armor in armors: - parents_filter = ContainsBoxValidator[DetectedRobot](armor.box, self.min_percentage_intersection) - parents = parents_filter.filter(robots, image) + parents = ContainsBoxObjectsFilter(armor.box, self.min_percentage_intersection).filter(robots) if len(parents) != 1: yield FakeDetectedRobot(armor) else: diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/armor_digit_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/armor_digit_validator.py deleted file mode 100644 index 060d78c7c3fa6b3949fd384c8ad38b19e6b338d0..0000000000000000000000000000000000000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/armor_digit_validator.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Iterable - -from numpy.core.multiarray import ndarray - -from polystar.common.models.object import Armor -from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC - - -class ArmorDigitValidator(ObjectsValidatorABC[DetectedRobot]): - def __init__(self, digits: Iterable[int]): - self.digits = digits - - def validate_single(self, armor: Armor, image: ndarray) -> bool: - return isinstance(armor, Armor) and armor.number in self.digits diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/contains_box_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/contains_box_validator.py deleted file mode 100644 index 77c2a532ac251fa45339bd5b3017a7b7bb018ff6..0000000000000000000000000000000000000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/contains_box_validator.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from polystar.common.models.box import Box -from polystar.common.models.image import Image -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC, ObjectT - - -@dataclass -class ContainsBoxValidator(ObjectsValidatorABC[ObjectT]): - box: Box - min_percentage_intersection: float - - def validate_single(self, obj: ObjectT, image: Image) -> bool: - return obj.box.contains(self.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/in_box_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/in_box_validator.py deleted file mode 100644 index 800cbd82b1a615adcf17b73ef7e5dded43bf6099..0000000000000000000000000000000000000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/in_box_validator.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from polystar.common.models.box import Box -from polystar.common.models.image import Image -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC, ObjectT - - -@dataclass -class InBoxValidator(ObjectsValidatorABC[ObjectT]): - box: Box - min_percentage_intersection: float - - def validate_single(self, obj: ObjectT, image: Image) -> bool: - return self.box.contains(obj.box, self.min_percentage_intersection) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/negation_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/negation_validator.py deleted file mode 100644 index b660bfa9f92c9f9d935f87db7ee3ddaf74ae47c0..0000000000000000000000000000000000000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/negation_validator.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - -import numpy as np - -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC, ObjectT - - -@dataclass -class NegationValidator(ObjectsValidatorABC): - validator: ObjectsValidatorABC - - def validate_single(self, obj: ObjectT, image: np.ndarray) -> bool: - return not self.validator.validate_single(obj, image) diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/objects_validator_abc.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/objects_validator_abc.py deleted file mode 100644 index 075ce1521f5764bc98ba8fc10ae473514b83d0a8..0000000000000000000000000000000000000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/objects_validator_abc.py +++ /dev/null @@ -1,20 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Generic, List, TypeVar - -import numpy as np - -from polystar.common.models.object import Object - -ObjectT = TypeVar("ObjectT", bound=Object) - - -class ObjectsValidatorABC(Generic[ObjectT], ABC): # FIXME Filter would do here - def filter(self, objects: List[ObjectT], image: np.ndarray) -> List[ObjectT]: - return [obj for obj, is_valid in zip(objects, self.validate(objects, image)) if is_valid] - - def validate(self, objects: List[ObjectT], image: np.ndarray) -> List[bool]: - return [self.validate_single(obj, image) for obj in objects] - - @abstractmethod - def validate_single(self, obj: ObjectT, image: np.ndarray) -> bool: - pass diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/robot_color_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/robot_color_validator.py deleted file mode 100644 index b50c2ed387caf931f963d691bf6fc3b0bb6c7aaa..0000000000000000000000000000000000000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/robot_color_validator.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass - -import numpy as np - -from polystar.common.models.object import ArmorColor -from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC - - -@dataclass -class RobotPercentageColorValidator(ObjectsValidatorABC[DetectedRobot]): - color: ArmorColor - min_percentage: 0.5 - - def validate_single(self, robot: DetectedRobot, image: np.ndarray) -> bool: - good_colors = [armor.color is self.color for armor in robot.armors] - return sum(good_colors) >= len(good_colors) * self.min_percentage diff --git a/polystar_cv/polystar/common/target_pipeline/objects_validators/type_object_validator.py b/polystar_cv/polystar/common/target_pipeline/objects_validators/type_object_validator.py deleted file mode 100644 index eb001d0b00071c3b753ccfa331f583dbace85660..0000000000000000000000000000000000000000 --- a/polystar_cv/polystar/common/target_pipeline/objects_validators/type_object_validator.py +++ /dev/null @@ -1,14 +0,0 @@ -import numpy as np - -from polystar.common.models.object import ObjectType, Object -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC - - -class TypeObjectValidator(ObjectsValidatorABC[Object]): - """Keep only the objects of a desired type""" - - def __init__(self, *desired_types: ObjectType): - self.desired_types = set(desired_types) - - def validate_single(self, obj: Object, image: np.ndarray) -> bool: - return obj.type in self.desired_types diff --git a/polystar_cv/polystar/common/target_pipeline/target_pipeline.py b/polystar_cv/polystar/common/target_pipeline/target_pipeline.py index 9b51450094f407ff103e1d0a5cd76cc41ffaabd7..1dd3bdcec2071222a91b9b87f53bd27f14d1e0fb 100644 --- a/polystar_cv/polystar/common/target_pipeline/target_pipeline.py +++ b/polystar_cv/polystar/common/target_pipeline/target_pipeline.py @@ -9,8 +9,8 @@ from polystar.common.target_pipeline.detected_objects.detected_object import Det from polystar.common.target_pipeline.detected_objects.detected_robot import DetectedRobot from polystar.common.target_pipeline.object_selectors.object_selector_abc import ObjectSelectorABC from polystar.common.target_pipeline.objects_detectors.objects_detector_abc import ObjectsDetectorABC +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC from polystar.common.target_pipeline.objects_linker.objects_linker_abs import ObjectsLinkerABC -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC from polystar.common.target_pipeline.target_abc import TargetABC from polystar.common.target_pipeline.target_factories.target_factory_abc import TargetFactoryABC @@ -25,7 +25,7 @@ class TargetPipeline: objects_detector: ObjectsDetectorABC objects_linker: ObjectsLinkerABC - objects_validators: List[ObjectsValidatorABC[DetectedRobot]] + objects_filters: List[ObjectsFilterABC] object_selector: ObjectSelectorABC target_factory: TargetFactoryABC target_sender: TargetSenderABC @@ -43,16 +43,16 @@ class TargetPipeline: def _get_robots_of_interest(self, image: Image) -> List[DetectedRobot]: robots = self._detect_robots(image) - robots = self._filter_robots(image, robots) + robots = self._filter_robots(robots) if not any(robot.armors for robot in robots): raise NoTargetFoundException() return robots - def _filter_robots(self, image: Image, robots: List[DetectedRobot]) -> List[DetectedRobot]: - for robots_validator in self.objects_validators: - robots = robots_validator.filter(robots, image) + def _filter_robots(self, robots: List[DetectedRobot]) -> List[DetectedRobot]: + for robots_validator in self.objects_filters: + robots = robots_validator.filter(robots) return robots def _detect_robots(self, image: Image) -> List[DetectedRobot]: diff --git a/polystar_cv/research/common/dataset/improvement/zoom.py b/polystar_cv/research/common/dataset/improvement/zoom.py index f2338909b5566344ed75a6360ccb9c1aa78f7660..756dc015ebbc92f191b18398c6a3347d0517d92b 100644 --- a/polystar_cv/research/common/dataset/improvement/zoom.py +++ b/polystar_cv/research/common/dataset/improvement/zoom.py @@ -5,7 +5,7 @@ from typing import Iterable, List, Tuple from polystar.common.models.box import Box from polystar.common.models.image import Image -from polystar.common.target_pipeline.objects_validators.in_box_validator import InBoxValidator +from polystar.common.target_pipeline.objects_filters.in_box_filter import InBoxObjectFilter from polystar.common.view.plt_results_viewer import PltResultViewer from research.common.datasets.roco.roco_annotation import ROCOAnnotation from research.common.datasets.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo @@ -14,7 +14,7 @@ from research.common.datasets.roco.zoo.roco_dataset_zoo import ROCODatasetsZoo def crop_image_annotation( image: Image, annotation: ROCOAnnotation, box: Box, min_coverage: float, name: str ) -> Tuple[Image, ROCOAnnotation, str]: - objects = InBoxValidator(box, min_coverage).filter(annotation.objects, image) + objects = InBoxObjectFilter(box, min_coverage).filter(annotation.objects) objects = [copy(o) for o in objects] for obj in objects: obj.box = Box.from_positions( diff --git a/polystar_cv/research/common/datasets/roco/roco_annotation.py b/polystar_cv/research/common/datasets/roco/roco_annotation.py index 0317bddf201eb3b344f02aab58e412f89142287c..e57b065a33877651e880856cc47d060560689320 100644 --- a/polystar_cv/research/common/datasets/roco/roco_annotation.py +++ b/polystar_cv/research/common/datasets/roco/roco_annotation.py @@ -6,7 +6,8 @@ from xml.dom.minidom import parseString import xmltodict from dicttoxml import dicttoxml -from polystar.common.models.object import Object, ObjectFactory + +from polystar.common.models.object import Armor, Object, ObjectFactory @dataclass @@ -18,6 +19,10 @@ class ROCOAnnotation: w: int h: int + @property + def armors(self) -> List[Armor]: + return [obj for obj in self.objects if isinstance(obj, Armor)] + @staticmethod def from_xml_file(xml_file: Path) -> "ROCOAnnotation": try: diff --git a/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py b/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py index 8b9d61db7f1f819f26a7f778ec427189b12a00be..7495b4a7da773fae07303b3e7d299d15ddb75ba4 100644 --- a/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py +++ b/polystar_cv/research/common/datasets/roco/roco_annotation_filters/roco_annotation_object_filter.py @@ -1,11 +1,11 @@ from polystar.common.filters.filter_abc import FilterABC -from polystar.common.target_pipeline.objects_validators.objects_validator_abc import ObjectsValidatorABC +from polystar.common.target_pipeline.objects_filters.objects_filter_abc import ObjectsFilterABC from research.common.datasets.roco.roco_annotation import ROCOAnnotation class ROCOAnnotationObjectFilter(FilterABC): - def __init__(self, object_validator: ObjectsValidatorABC): - self.object_validator = object_validator + def __init__(self, object_filter: ObjectsFilterABC): + self.object_filter = object_filter def validate_single(self, annotation: ROCOAnnotation) -> bool: - return any(self.object_validator.validate(annotation.objects, None)) + return any(self.object_filter.validate(annotation.objects)) diff --git a/polystar_cv/research/robots/dataset/armor_dataset_factory.py b/polystar_cv/research/robots/dataset/armor_dataset_factory.py index c14eca1d40b5837ffbb897fd24139ee3794088e7..a429e5af90a236aea71ab62fc188331b50537944 100644 --- a/polystar_cv/research/robots/dataset/armor_dataset_factory.py +++ b/polystar_cv/research/robots/dataset/armor_dataset_factory.py @@ -1,11 +1,10 @@ from itertools import islice -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import matplotlib.pyplot as plt from polystar.common.models.image import Image -from polystar.common.models.object import Armor, ObjectType -from polystar.common.target_pipeline.objects_validators.type_object_validator import TypeObjectValidator +from polystar.common.models.object import Armor from research.common.datasets.lazy_dataset import LazyDataset from research.common.datasets.roco.roco_annotation import ROCOAnnotation from research.common.datasets.roco.roco_dataset import LazyROCODataset @@ -23,9 +22,7 @@ class ArmorDataset(LazyDataset[Image, Armor]): @staticmethod def _generate_from_single(image: Image, annotation: ROCOAnnotation, name) -> Iterator[Tuple[Image, Armor, str]]: - armors: List[Armor] = TypeObjectValidator(ObjectType.ARMOR).filter(annotation.objects, image) - - for i, obj in enumerate(armors): + for i, obj in enumerate(annotation.armors): croped_img = image[obj.box.y1 : obj.box.y2, obj.box.x1 : obj.box.x2] yield croped_img, obj, f"{name}-{i}" diff --git a/polystar_cv/research/robots/demos/demo_pipeline.py b/polystar_cv/research/robots/demos/demo_pipeline.py index 44fbbbebfd3460ac59408d7b92be95f994711692..d7678f8a2bf871dae0d27ed5cc665dbf43ab0c10 100644 --- a/polystar_cv/research/robots/demos/demo_pipeline.py +++ b/polystar_cv/research/robots/demos/demo_pipeline.py @@ -2,7 +2,7 @@ from injector import inject from polystar.common.dependency_injection import make_injector from polystar.common.target_pipeline.debug_pipeline import DebugTargetPipeline -from polystar.common.target_pipeline.objects_validators.armor_digit_validator import ArmorDigitValidator +from polystar.common.target_pipeline.objects_filters.armor_digit_filter import KeepArmorsDigitFilter from polystar.common.target_pipeline.target_pipeline import NoTargetFoundException from polystar.common.view.plt_results_viewer import PltResultViewer from research.common.datasets.roco.roco_annotation_filters.roco_annotation_object_filter import ( @@ -17,7 +17,7 @@ def demo_pipeline_on_images(pipeline: DebugTargetPipeline): for builder in ROCODatasetsZoo.DEFAULT_TEST_DATASETS: for image in ( builder.to_images() - .filter_targets(ROCOAnnotationObjectFilter(ArmorDigitValidator((1, 3, 4)))) + .filter_targets(ROCOAnnotationObjectFilter(KeepArmorsDigitFilter((1, 3, 4)))) .shuffle() .cap(15) .build_examples() diff --git a/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py b/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py index f7aaa5276a535769727742c9c2114ad1a59dce6a..5476955122edeab64e40cd77e8fe02aad9616dff 100644 --- a/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py +++ b/polystar_cv/tests/common/unittests/object_validators/test_in_box_validator.py @@ -1,13 +1,13 @@ import unittest from polystar.common.models.box import Box -from polystar.common.models.object import Object -from polystar.common.target_pipeline.objects_validators.in_box_validator import InBoxValidator +from polystar.common.models.object import Object, ObjectType +from polystar.common.target_pipeline.objects_filters.in_box_filter import InBoxObjectFilter -class TestInBoxValidator(unittest.TestCase): +class TestInBoxObjectFilter(unittest.TestCase): def setUp(self) -> None: - self.in_box_validator = InBoxValidator(Box.from_size(2, 2, 6, 4), 0.5) + self.in_box_validator = InBoxObjectFilter(Box.from_size(2, 2, 6, 4), 0.5) def test_fully_inside(self): self._test_obj(3, 3, 2, 2, True) @@ -44,5 +44,5 @@ class TestInBoxValidator(unittest.TestCase): def _test_obj(self, x: int, y: int, w: int, h: int, is_inside: bool): self.assertEqual( - is_inside, self.in_box_validator.validate_single(Object(None, Box.from_size(x, y, w, h)), None) + is_inside, self.in_box_validator.validate_single(Object(ObjectType.CAR, Box.from_size(x, y, w, h))) )