From 958baa8f27b05619415be996fae680dcf8840046 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Mon, 16 Mar 2020 23:49:24 -0400
Subject: [PATCH] [robots] (demo) improve trt demo

---
 .../research/demos/demo_pipeline_camera.py | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/robots-at-robots/research/demos/demo_pipeline_camera.py b/robots-at-robots/research/demos/demo_pipeline_camera.py
index 33f1570..10485ff 100644
--- a/robots-at-robots/research/demos/demo_pipeline_camera.py
+++ b/robots-at-robots/research/demos/demo_pipeline_camera.py
@@ -4,10 +4,16 @@ import pycuda.autoinit  # This is needed for initializing CUDA driver
 
 from polystar.common.constants import MODELS_DIR
 from polystar.common.frame_generators.camera_frame_generator import CameraFrameGenerator
+from polystar.common.models.camera import Camera
 from polystar.common.models.label_map import LabelMap
+from polystar.common.models.object import ObjectType
 from polystar.common.models.trt_model import TRTModel
+from polystar.common.pipeline.debug_pipeline import DebugPipeline
+from polystar.common.pipeline.object_selectors.closest_object_selector import ClosestObjectSelector
 from polystar.common.pipeline.objects_detectors.trt_model_object_detector import TRTModelObjectsDetector
 from polystar.common.pipeline.objects_validators.confidence_object_validator import ConfidenceObjectValidator
+from polystar.common.pipeline.objects_validators.type_object_validator import TypeObjectValidator
+from polystar.common.pipeline.target_factories.ratio_simple_target_factory import RatioSimpleTargetFactory
 from polystar.common.utils.tensorflow import patch_tf_v2
 from polystar.common.view.cv2_results_viewer import CV2ResultViewer
 from polystar.robots_at_robots.dependency_injection import make_injector
@@ -23,7 +29,12 @@ if __name__ == "__main__":
     objects_detector = TRTModelObjectsDetector(
         TRTModel(MODELS_DIR / settings.MODEL_NAME, (300, 300)), injector.get(LabelMap)
     )
-    filters = [ConfidenceObjectValidator(confidence_threshold=0.5)]
+    pipeline = DebugPipeline(
+        objects_detector=objects_detector,
+        objects_validators=[ConfidenceObjectValidator(0.6), TypeObjectValidator(ObjectType.Armor)],
+        object_selector=ClosestObjectSelector(),
+        target_factory=RatioSimpleTargetFactory(injector.get(Camera), 300, 100),
+    )
 
     fps = 0
     with CV2ResultViewer("TensorRT demo") as viewer:
@@ -32,15 +43,15 @@ if __name__ == "__main__":
             previous_time = time()
 
             # inference
-            objects = objects_detector.detect(image)
-            for f in filters:
-                objects = f.filter(objects, image)
+            pipeline.predict_target(image)
 
             # display
             fps = 0.9 * fps + 0.1 / (time() - previous_time)
             viewer.new(image)
-            viewer.add_objects(objects)
+            viewer.add_objects(pipeline.debug_info_.validated_objects, forced_color=(0.6, 0.6, 0.6))
+            viewer.add_object(pipeline.debug_info_.selected_object)
             viewer.add_text(f"FPS: {fps:.1f}", 10, 10, (0, 0, 0))
             viewer.display()
+            if viewer.finished: break
-- 
GitLab
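
Note (not part of the patch): for readability, below is a sketch of the demo loop as it looks once the hunks above are applied. The run_demo wrapper and its parameters (objects_detector, camera, frames) are hypothetical stand-ins for setup lines the diff does not show (injector, settings, CameraFrameGenerator call); everything else is assembled from the added lines.

from time import time

from polystar.common.models.object import ObjectType
from polystar.common.pipeline.debug_pipeline import DebugPipeline
from polystar.common.pipeline.object_selectors.closest_object_selector import ClosestObjectSelector
from polystar.common.pipeline.objects_validators.confidence_object_validator import ConfidenceObjectValidator
from polystar.common.pipeline.objects_validators.type_object_validator import TypeObjectValidator
from polystar.common.pipeline.target_factories.ratio_simple_target_factory import RatioSimpleTargetFactory
from polystar.common.view.cv2_results_viewer import CV2ResultViewer


def run_demo(objects_detector, camera, frames):
    # Hypothetical wrapper: objects_detector is the TRTModelObjectsDetector built in the
    # patch, camera the injected Camera, frames an iterable of images from the camera.

    # Assemble the pipeline: detect -> validate (confidence threshold 0.6, armor objects
    # only) -> select the closest object -> build a target from it.
    pipeline = DebugPipeline(
        objects_detector=objects_detector,
        objects_validators=[ConfidenceObjectValidator(0.6), TypeObjectValidator(ObjectType.Armor)],
        object_selector=ClosestObjectSelector(),
        target_factory=RatioSimpleTargetFactory(camera, 300, 100),
    )

    fps = 0
    with CV2ResultViewer("TensorRT demo") as viewer:
        for image in frames:
            previous_time = time()

            # inference: one call runs detection, validation, selection and target building
            pipeline.predict_target(image)

            # display: validated objects in grey, the selected object drawn on top
            fps = 0.9 * fps + 0.1 / (time() - previous_time)
            viewer.new(image)
            viewer.add_objects(pipeline.debug_info_.validated_objects, forced_color=(0.6, 0.6, 0.6))
            viewer.add_object(pipeline.debug_info_.selected_object)
            viewer.add_text(f"FPS: {fps:.1f}", 10, 10, (0, 0, 0))
            viewer.display()
            if viewer.finished:
                break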