From a5bca86e57b862be0f6d8fe8e526b19f450f06e9 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Sat, 6 Mar 2021 20:02:25 -0500
Subject: [PATCH] [TargetPipeline] Flow from detector

---
 .../target_pipeline/debug_pipeline.py         |  9 +++--
 .../objects_detectors/objects_detector_abc.py |  6 +++-
 .../target_pipeline/target_pipeline.py        | 11 ++++++
 src/research/scripts/demo_pipeline_camera.py  | 35 +++++++++----------
 4 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/src/polystar/target_pipeline/debug_pipeline.py b/src/polystar/target_pipeline/debug_pipeline.py
index 1bb2cea..b3302d6 100644
--- a/src/polystar/target_pipeline/debug_pipeline.py
+++ b/src/polystar/target_pipeline/debug_pipeline.py
@@ -12,8 +12,8 @@ from polystar.target_pipeline.target_pipeline import TargetPipeline, _assert_arm
 
 @dataclass
 class DebugInfo:
-    image: Image = None
-    detected_robots: List[DetectedRobot] = field(init=False, default_factory=list)
+    image: Image
+    detected_robots: List[DetectedRobot]
     validated_robots: List[DetectedRobot] = field(init=False, default_factory=list)
     selected_armor: DetectedArmor = field(init=False, default=None)
     target: TargetABC = field(init=False, default=None)
@@ -26,9 +26,8 @@ class DebugTargetPipeline(TargetPipeline):
 
     debug_info_: DebugInfo = field(init=False)
 
-    def predict_target(self, image: Image) -> TargetABC:
-        self.debug_info_ = DebugInfo(image)
-        self.debug_info_.detected_robots = self.robots_detector.detect_robots(image)
+    def _make_target_from_robots(self, image: Image, robots: List[DetectedRobot]) -> TargetABC:
+        self.debug_info_ = DebugInfo(image, robots)
         self.debug_info_.validated_robots = self.robots_filters.filter(self.debug_info_.detected_robots)
         _assert_armors_detected(self.debug_info_.validated_robots)
         self.debug_info_.selected_armor = self.object_selector.select(self.debug_info_.validated_robots, image)
diff --git a/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py b/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py
index 53a9c26..298398f 100644
--- a/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py
+++ b/src/polystar/target_pipeline/objects_detectors/objects_detector_abc.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import Iterable, List, Tuple
 
 import numpy as np
 from injector import inject
@@ -28,6 +28,10 @@ class RobotsDetector:
     object_detector: ObjectsDetectorABC
     objects_linker: ObjectsLinkerABC
 
+    def flow_robots(self, image_iterator: Iterable[Image]) -> Iterable[List[DetectedRobot]]:
+        for image in image_iterator:
+            yield image, self.detect_robots(image)
+
     def detect_robots(self, image: Image) -> List[DetectedRobot]:
         objects_params = self.object_detector.detect(image)
         robots, armors = self.make_robots_and_armors(objects_params, image)
diff --git a/src/polystar/target_pipeline/target_pipeline.py b/src/polystar/target_pipeline/target_pipeline.py
index 43c82c7..b166286 100644
--- a/src/polystar/target_pipeline/target_pipeline.py
+++ b/src/polystar/target_pipeline/target_pipeline.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from typing import Iterable, List, Optional
 
 from injector import inject
 
@@ -23,8 +24,18 @@ class TargetPipeline:
     object_selector: ObjectSelectorABC
     target_factory: TargetFactoryABC
 
+    def flow_targets(self, image_iterator: Iterable[Image]) -> Iterable[Optional[TargetABC]]:
+        for image, robots in self.robots_detector.flow_robots(image_iterator):
+            try:
+                yield self._make_target_from_robots(image, robots)
+            except NoTargetFoundException:
+                yield None
+
     def predict_target(self, image: Image) -> TargetABC:
         robots = self.robots_detector.detect_robots(image)
+        return self._make_target_from_robots(image, robots)
+
+    def _make_target_from_robots(self, image: Image, robots: List[DetectedRobot]) -> TargetABC:
         robots = self.robots_filters.filter(robots)
         _assert_armors_detected(robots)
         selected_armor = self.object_selector.select(robots, image)
diff --git a/src/research/scripts/demo_pipeline_camera.py b/src/research/scripts/demo_pipeline_camera.py
index 00c6d6c..5a035b5 100644
--- a/src/research/scripts/demo_pipeline_camera.py
+++ b/src/research/scripts/demo_pipeline_camera.py
@@ -1,12 +1,13 @@
+from typing import Optional
+
 from injector import inject
 
 from polystar.communication.cs_link_abc import CSLinkABC
 from polystar.communication.togglabe_cs_link import TogglableCSLink
 from polystar.dependency_injection import make_injector
 from polystar.frame_generators.frames_generator_abc import FrameGeneratorABC
-from polystar.models.image import Image
 from polystar.target_pipeline.debug_pipeline import DebugTargetPipeline
-from polystar.target_pipeline.target_pipeline import NoTargetFoundException
+from polystar.target_pipeline.target_abc import SimpleTarget
 from polystar.utils.fps import FPS
 from polystar.utils.thread import MyThread
 from polystar.view.cv2_results_viewer import CV2ResultViewer
@@ -18,32 +19,28 @@ class CameraPipelineDemo:
         self.cs_link = TogglableCSLink(cs_link, is_on=False)
         self.webcam = webcam
         self.pipeline = pipeline
-        self.fps, self.pipeline_fps = FPS(), FPS()
+        self.fps = FPS()
         self.persistence_last_detection = 0
 
     def run(self):
-        with CV2ResultViewer("TensorRT demo", key_callbacks={" ": self.cs_link.toggle}) as viewer:
-            for image in self.webcam:
-                self.pipeline_fps.skip()
-                self._detect(image)
-                self.pipeline_fps.tick(), self.fps.tick()
+        with CV2ResultViewer("Pipeline demo", key_callbacks={" ": self.cs_link.toggle}) as viewer:
+            for target in self.pipeline.flow_targets(self.webcam):
+                self._send_target(target)
                 self._display(viewer)
-                self.fps.skip()
 
-    def _detect(self, image: Image):
-        try:
-            target = self.pipeline.predict_target(image)
-            self.cs_link.send_target(target)
+    def _send_target(self, target: Optional[SimpleTarget]):
+        if target is not None:
             self.persistence_last_detection = 5
-        except NoTargetFoundException:
-            if self.persistence_last_detection:
-                self.persistence_last_detection -= 1
-            else:
-                self.cs_link.send_no_target()
+            return self.cs_link.send_target(target)
+
+        if not self.persistence_last_detection:
+            return self.cs_link.send_no_target()
+
+        self.persistence_last_detection -= 1
 
     def _display(self, viewer: CV2ResultViewer):
         viewer.add_debug_info(self.pipeline.debug_info_)
-        viewer.add_text(f"FPS: {self.fps:.1f} / {self.pipeline_fps:.1f}", 10, 10, (0, 0, 0))
+        viewer.add_text(f"FPS: {self.fps.tick():.1f}", 10, 10, (0, 0, 0))
         viewer.add_text("Communication: " + ("[ON]" if self.cs_link.is_on else "[OFF]"), 10, 30, (0, 0, 0))
         viewer.display()
 
-- 
GitLab