From a5bca86e57b862be0f6d8fe8e526b19f450f06e9 Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Sat, 6 Mar 2021 20:02:25 -0500 Subject: [PATCH] [TargetPipeline] Flow from detector --- .../target_pipeline/debug_pipeline.py | 9 +++-- .../objects_detectors/objects_detector_abc.py | 6 +++- .../target_pipeline/target_pipeline.py | 11 ++++++ src/research/scripts/demo_pipeline_camera.py | 35 +++++++++---------- 4 files changed, 36 insertions(+), 25 deletions(-) diff --git a/src/polystar/target_pipeline/debug_pipeline.py b/src/polystar/target_pipeline/debug_pipeline.py index 1bb2cea..b3302d6 100644 --- a/src/polystar/target_pipeline/debug_pipeline.py +++ b/src/polystar/target_pipeline/debug_pipeline.py @@ -12,8 +12,8 @@ from polystar.target_pipeline.target_pipeline import TargetPipeline, _assert_arm @dataclass class DebugInfo: - image: Image = None - detected_robots: List[DetectedRobot] = field(init=False, default_factory=list) + image: Image + detected_robots: List[DetectedRobot] validated_robots: List[DetectedRobot] = field(init=False, default_factory=list) selected_armor: DetectedArmor = field(init=False, default=None) target: TargetABC = field(init=False, default=None) @@ -26,9 +26,8 @@ class DebugTargetPipeline(TargetPipeline): debug_info_: DebugInfo = field(init=False) - def predict_target(self, image: Image) -> TargetABC: - self.debug_info_ = DebugInfo(image) - self.debug_info_.detected_robots = self.robots_detector.detect_robots(image) + def _make_target_from_robots(self, image: Image, robots: List[DetectedRobot]) -> TargetABC: + self.debug_info_ = DebugInfo(image, robots) self.debug_info_.validated_robots = self.robots_filters.filter(self.debug_info_.detected_robots) _assert_armors_detected(self.debug_info_.validated_robots) self.debug_info_.selected_armor = self.object_selector.select(self.debug_info_.validated_robots, image) diff --git a/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py b/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py index 53a9c26..298398f 100644 --- a/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py +++ b/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Tuple +from typing import Iterable, List, Tuple import numpy as np from injector import inject @@ -28,6 +28,10 @@ class RobotsDetector: object_detector: ObjectsDetectorABC objects_linker: ObjectsLinkerABC + def flow_robots(self, image_iterator: Iterable[Image]) -> Iterable[List[DetectedRobot]]: + for image in image_iterator: + yield image, self.detect_robots(image) + def detect_robots(self, image: Image) -> List[DetectedRobot]: objects_params = self.object_detector.detect(image) robots, armors = self.make_robots_and_armors(objects_params, image) diff --git a/src/polystar/target_pipeline/target_pipeline.py b/src/polystar/target_pipeline/target_pipeline.py index 43c82c7..b166286 100644 --- a/src/polystar/target_pipeline/target_pipeline.py +++ b/src/polystar/target_pipeline/target_pipeline.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Iterable, List, Optional from injector import inject @@ -23,8 +24,18 @@ class TargetPipeline: object_selector: ObjectSelectorABC target_factory: TargetFactoryABC + def flow_targets(self, image_iterator: Iterable[Image]) -> Iterable[Optional[TargetABC]]: + for image, robots in self.robots_detector.flow_robots(image_iterator): + try: + yield self._make_target_from_robots(image, robots) + except NoTargetFoundException: + yield None + def predict_target(self, image: Image) -> TargetABC: robots = self.robots_detector.detect_robots(image) + return self._make_target_from_robots(image, robots) + + def _make_target_from_robots(self, image: Image, robots: List[DetectedRobot]) -> TargetABC: robots = self.robots_filters.filter(robots) _assert_armors_detected(robots) selected_armor = self.object_selector.select(robots, image) diff --git a/src/research/scripts/demo_pipeline_camera.py b/src/research/scripts/demo_pipeline_camera.py index 00c6d6c..5a035b5 100644 --- a/src/research/scripts/demo_pipeline_camera.py +++ b/src/research/scripts/demo_pipeline_camera.py @@ -1,12 +1,13 @@ +from typing import Optional + from injector import inject from polystar.communication.cs_link_abc import CSLinkABC from polystar.communication.togglabe_cs_link import TogglableCSLink from polystar.dependency_injection import make_injector from polystar.frame_generators.frames_generator_abc import FrameGeneratorABC -from polystar.models.image import Image from polystar.target_pipeline.debug_pipeline import DebugTargetPipeline -from polystar.target_pipeline.target_pipeline import NoTargetFoundException +from polystar.target_pipeline.target_abc import SimpleTarget from polystar.utils.fps import FPS from polystar.utils.thread import MyThread from polystar.view.cv2_results_viewer import CV2ResultViewer @@ -18,32 +19,28 @@ class CameraPipelineDemo: self.cs_link = TogglableCSLink(cs_link, is_on=False) self.webcam = webcam self.pipeline = pipeline - self.fps, self.pipeline_fps = FPS(), FPS() + self.fps = FPS() self.persistence_last_detection = 0 def run(self): - with CV2ResultViewer("TensorRT demo", key_callbacks={" ": self.cs_link.toggle}) as viewer: - for image in self.webcam: - self.pipeline_fps.skip() - self._detect(image) - self.pipeline_fps.tick(), self.fps.tick() + with CV2ResultViewer("Pipeline demo", key_callbacks={" ": self.cs_link.toggle}) as viewer: + for target in self.pipeline.flow_targets(self.webcam): + self._send_target(target) self._display(viewer) - self.fps.skip() - def _detect(self, image: Image): - try: - target = self.pipeline.predict_target(image) - self.cs_link.send_target(target) + def _send_target(self, target: Optional[SimpleTarget]): + if target is not None: self.persistence_last_detection = 5 - except NoTargetFoundException: - if self.persistence_last_detection: - self.persistence_last_detection -= 1 - else: - self.cs_link.send_no_target() + return self.cs_link.send_target(target) + + if not self.persistence_last_detection: + return self.cs_link.send_no_target() + + self.persistence_last_detection -= 1 def _display(self, viewer: CV2ResultViewer): viewer.add_debug_info(self.pipeline.debug_info_) - viewer.add_text(f"FPS: {self.fps:.1f} / {self.pipeline_fps:.1f}", 10, 10, (0, 0, 0)) + viewer.add_text(f"FPS: {self.fps.tick():.1f}", 10, 10, (0, 0, 0)) viewer.add_text("Communication: " + ("[ON]" if self.cs_link.is_on else "[OFF]"), 10, 30, (0, 0, 0)) viewer.display() -- GitLab