From a98b4b0772d4ab6f598fb88d3abd53dd74674f9e Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <mathieu@feedly.com> Date: Sat, 14 Mar 2020 13:22:08 -0400 Subject: [PATCH] [robots] (demo) Use debug pipeline, and add exception handling --- .../research/demos/demo_pipeline.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/robots-at-robots/research/demos/demo_pipeline.py b/robots-at-robots/research/demos/demo_pipeline.py index ca58ed5..ca3832c 100644 --- a/robots-at-robots/research/demos/demo_pipeline.py +++ b/robots-at-robots/research/demos/demo_pipeline.py @@ -3,25 +3,26 @@ import cv2 from polystar.common.models.camera import Camera from polystar.common.models.label_map import LabelMap from polystar.common.models.object import ObjectType +from polystar.common.pipeline.debug_pipeline import DebugPipeline from polystar.common.pipeline.object_selectors.closest_object_selector import ClosestObjectSelector from polystar.common.pipeline.objects_detectors.tf_model_objects_detector import TFModelObjectsDetector from polystar.common.pipeline.objects_validators.confidence_object_validator import ConfidenceObjectValidator from polystar.common.pipeline.objects_validators.type_object_validator import TypeObjectValidator -from polystar.common.pipeline.pipeline import Pipeline +from polystar.common.pipeline.pipeline import NoTargetFound from polystar.common.pipeline.target_factories.ratio_simple_target_factory import RatioSimpleTargetFactory from polystar.common.utils.tensorflow import patch_tf_v2 from polystar.common.view.plt_results_viewer import PltResultViewer from polystar.robots_at_robots.dependency_injection import make_injector from research.demos.utils import load_tf_model -from research_common.dataset.dji.dji_roco_datasets import DJIROCODataset from research_common.dataset.split import Split from research_common.dataset.split_dataset import SplitDataset +from research_common.dataset.twitch.twitch_roco_datasets import TwitchROCODataset if __name__ == "__main__": patch_tf_v2() injector = make_injector() - pipeline = Pipeline( + pipeline = DebugPipeline( objects_detector=TFModelObjectsDetector(load_tf_model(), injector.get(LabelMap)), objects_validators=[ConfidenceObjectValidator(0.6), TypeObjectValidator(ObjectType.Armor)], object_selector=ClosestObjectSelector(), @@ -29,11 +30,17 @@ if __name__ == "__main__": ) with PltResultViewer("Demo of tf model") as viewer: - for i, image_path in enumerate(SplitDataset(DJIROCODataset.CentralChina, Split.Test).image_paths): - image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) - obj = pipeline.predict_best_object(image) + for i, image_path in enumerate(SplitDataset(TwitchROCODataset.TWITCH_470150052, Split.Test).image_paths): + try: + image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + target = pipeline.predict_target(image) - viewer.display_image_with_objects(image, [obj]) + viewer.new(image) + viewer.add_objects(pipeline.debug_info_.validated_objects, forced_color=(0.3, 0.3, 0.3)) + viewer.add_object(pipeline.debug_info_.selected_object) + viewer.display() + except NoTargetFound: + pass if i == 5: break -- GitLab