From d80297fd08b2e6489fb679a04637f3e54b3c60e4 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Mon, 1 Mar 2021 17:53:34 -0500
Subject: [PATCH] [demo] add toggle for communication

---
 .../communication/togglabe_cs_link.py         | 16 +++++
 src/research/scripts/demo_pipeline_camera.py  | 63 ++++++++++++-------
 2 files changed, 55 insertions(+), 24 deletions(-)
 create mode 100644 src/polystar/communication/togglabe_cs_link.py

diff --git a/src/polystar/communication/togglabe_cs_link.py b/src/polystar/communication/togglabe_cs_link.py
new file mode 100644
index 0000000..2cc7a6d
--- /dev/null
+++ b/src/polystar/communication/togglabe_cs_link.py
@@ -0,0 +1,16 @@
+from polystar.communication.command import Command
+from polystar.communication.cs_link_abc import CSLinkABC
+
+
+class TogglableCSLink(CSLinkABC):
+    def __init__(self, cs_link: CSLinkABC, is_on: bool):
+        self.is_on = is_on
+        self.cs_link = cs_link
+
+    def send_command(self, command: Command):
+        if not self.is_on:
+            return
+        self.cs_link.send_command(command)
+
+    def toggle(self):
+        self.is_on = not self.is_on
diff --git a/src/research/scripts/demo_pipeline_camera.py b/src/research/scripts/demo_pipeline_camera.py
index 06f72f9..c2c79fd 100644
--- a/src/research/scripts/demo_pipeline_camera.py
+++ b/src/research/scripts/demo_pipeline_camera.py
@@ -1,38 +1,53 @@
 from injector import inject
 
 from polystar.communication.cs_link_abc import CSLinkABC
+from polystar.communication.togglabe_cs_link import TogglableCSLink
 from polystar.dependency_injection import make_injector
 from polystar.frame_generators.frames_generator_abc import FrameGeneratorABC
+from polystar.models.image import Image
 from polystar.target_pipeline.debug_pipeline import DebugTargetPipeline
 from polystar.target_pipeline.target_pipeline import NoTargetFoundException
 from polystar.utils.fps import FPS
 from polystar.view.cv2_results_viewer import CV2ResultViewer
 
 
-@inject
-def demo_pipeline_on_camera(pipeline: DebugTargetPipeline, webcam: FrameGeneratorABC, cs_link: CSLinkABC):
-    fps, pipeline_fps = FPS(), FPS()
-    with CV2ResultViewer("TensorRT demo") as viewer:
-        persistence_last_detection: int = 0
-        for image in webcam.generate():
-            pipeline_fps.skip()
-            try:
-                target = pipeline.predict_target(image)
-                cs_link.send_target(target)
-                persistence_last_detection = 5
-            except NoTargetFoundException:
-                if persistence_last_detection:
-                    persistence_last_detection -= 1
-                else:
-                    cs_link.send_no_target()
-            pipeline_fps.tick(), fps.tick()
-            viewer.add_debug_info(pipeline.debug_info_)
-            viewer.add_text(f"FPS: {fps:.1f} / {pipeline_fps:.1f}", 10, 10, (0, 0, 0))
-            viewer.display()
-            fps.skip()
-            if viewer.finished:
-                return
+class CameraPipelineDemo:
+    @inject
+    def __init__(self, pipeline: DebugTargetPipeline, webcam: FrameGeneratorABC, cs_link: CSLinkABC):
+        self.cs_link = TogglableCSLink(cs_link, is_on=False)
+        self.webcam = webcam
+        self.pipeline = pipeline
+        self.fps, self.pipeline_fps = FPS(), FPS()
+        self.persistence_last_detection = 0
+
+    def run(self):
+        with CV2ResultViewer("TensorRT demo", key_callbacks={" ": self.cs_link.toggle}) as viewer:
+            for image in self.webcam.generate():
+                self.pipeline_fps.skip()
+                self._detect(image)
+                self.pipeline_fps.tick(), self.fps.tick()
+                self._display(viewer)
+                self.fps.skip()
+                if viewer.finished:
+                    return
+
+    def _detect(self, image: Image):
+        try:
+            target = self.pipeline.predict_target(image)
+            self.cs_link.send_target(target)
+            self.persistence_last_detection = 5
+        except NoTargetFoundException:
+            if self.persistence_last_detection:
+                self.persistence_last_detection -= 1
+            else:
+                self.cs_link.send_no_target()
+
+    def _display(self, viewer: CV2ResultViewer):
+        viewer.add_debug_info(self.pipeline.debug_info_)
+        viewer.add_text(f"FPS: {self.fps:.1f} / {self.pipeline_fps:.1f}", 10, 10, (0, 0, 0))
+        viewer.add_text("Communication: " + ("[ON]" if self.cs_link.is_on else "[OFF]"), 10, 30, (0, 0, 0))
+        viewer.display()
 
 
 if __name__ == "__main__":
-    make_injector().call_with_injection(demo_pipeline_on_camera)
+    make_injector().get(CameraPipelineDemo).run()
-- 
GitLab