From 958baa8f27b05619415be996fae680dcf8840046 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Mon, 16 Mar 2020 23:49:24 -0400
Subject: [PATCH] [robots] (demo) improve trt demo

---
 .../research/demos/demo_pipeline_camera.py    | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/robots-at-robots/research/demos/demo_pipeline_camera.py b/robots-at-robots/research/demos/demo_pipeline_camera.py
index 33f1570..10485ff 100644
--- a/robots-at-robots/research/demos/demo_pipeline_camera.py
+++ b/robots-at-robots/research/demos/demo_pipeline_camera.py
@@ -4,10 +4,16 @@ import pycuda.autoinit  # This is needed for initializing CUDA driver
 
 from polystar.common.constants import MODELS_DIR
 from polystar.common.frame_generators.camera_frame_generator import CameraFrameGenerator
+from polystar.common.models.camera import Camera
 from polystar.common.models.label_map import LabelMap
+from polystar.common.models.object import ObjectType
 from polystar.common.models.trt_model import TRTModel
+from polystar.common.pipeline.debug_pipeline import DebugPipeline
+from polystar.common.pipeline.object_selectors.closest_object_selector import ClosestObjectSelector
 from polystar.common.pipeline.objects_detectors.trt_model_object_detector import TRTModelObjectsDetector
 from polystar.common.pipeline.objects_validators.confidence_object_validator import ConfidenceObjectValidator
+from polystar.common.pipeline.objects_validators.type_object_validator import TypeObjectValidator
+from polystar.common.pipeline.target_factories.ratio_simple_target_factory import RatioSimpleTargetFactory
 from polystar.common.utils.tensorflow import patch_tf_v2
 from polystar.common.view.cv2_results_viewer import CV2ResultViewer
 from polystar.robots_at_robots.dependency_injection import make_injector
@@ -23,7 +29,12 @@ if __name__ == "__main__":
     objects_detector = TRTModelObjectsDetector(
         TRTModel(MODELS_DIR / settings.MODEL_NAME, (300, 300)), injector.get(LabelMap)
     )
-    filters = [ConfidenceObjectValidator(confidence_threshold=0.5)]
+    pipeline = DebugPipeline(
+        objects_detector=objects_detector,
+        objects_validators=[ConfidenceObjectValidator(0.6), TypeObjectValidator(ObjectType.Armor)],
+        object_selector=ClosestObjectSelector(),
+        target_factory=RatioSimpleTargetFactory(injector.get(Camera), 300, 100),
+    )
 
     fps = 0
     with CV2ResultViewer("TensorRT demo") as viewer:
@@ -32,15 +43,15 @@ if __name__ == "__main__":
             previous_time = time()
 
             # inference
-            objects = objects_detector.detect(image)
-            for f in filters:
-                objects = f.filter(objects, image)
+            pipeline.predict_target(image)
 
             # display
             fps = 0.9 * fps + 0.1 / (time() - previous_time)
             viewer.new(image)
-            viewer.add_objects(objects)
+            viewer.add_objects(pipeline.debug_info_.validated_objects, forced_color=(0.6, 0.6, 0.6))
+            viewer.add_object(pipeline.debug_info_.selected_object)
             viewer.add_text(f"FPS: {fps:.1f}", 10, 10, (0, 0, 0))
             viewer.display()
+
             if viewer.finished:
                 break
-- 
GitLab