diff --git a/src/polystar/view/cv2_results_viewer.py b/src/polystar/view/cv2_results_viewer.py index ef93a3d092a97346eba85c9ecb31eaea53f50b35..98594c241c1c625fdb25e9ecc3a0aa360f4e1fc2 100644 --- a/src/polystar/view/cv2_results_viewer.py +++ b/src/polystar/view/cv2_results_viewer.py @@ -1,3 +1,6 @@ +from contextlib import suppress +from typing import Any, Callable, Dict + import cv2 import numpy as np @@ -23,10 +26,12 @@ COLORS = [ (23, 190, 207), ] # seaborn.color_palette() * 255 +Callback = Callable[[], Any] + class CV2ResultViewer(ResultViewerABC): - def __init__(self, name: str, delay: int = 1, end_keys: str = "q"): - self.end_keys = [ord(c) for c in end_keys] + def __init__(self, name: str, delay: int = 1, end_key: str = "q", key_callbacks: Dict[str, Callback] = None): + self.keycode_callbacks = self._make_keycode_callbacks(end_key, key_callbacks or {}) self.delay = delay self.name = name self._current_image: Image = None @@ -70,4 +75,13 @@ class CV2ResultViewer(ResultViewerABC): def display(self): cv2.imshow(self.name, self._current_image) - self.finished = cv2.waitKey(self.delay) & 0xFF in self.end_keys + keycode = cv2.waitKey(self.delay) & 0xFF + with suppress(KeyError): + self.keycode_callbacks[keycode]() + + def stop(self): + self.finished = True + + def _make_keycode_callbacks(self, end_key: str, key_callbacks: Dict[str, Callback]) -> Dict[int, Callback]: + key_callbacks[end_key] = self.stop + return {ord(k): f for k, f in key_callbacks.items()}