From 821c2a4b949167a013c757ec8a7a2fe5b49f951e Mon Sep 17 00:00:00 2001 From: Mathieu Beligon <beligonmathieu@gmail.com> Date: Thu, 4 Mar 2021 18:37:13 -0500 Subject: [PATCH] [CV2FrameGenerator] simplify class --- .../cv2_frame_generator_abc.py | 34 +++++++------------ .../frame_generators/video_frame_generator.py | 12 ++++--- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/src/polystar/frame_generators/cv2_frame_generator_abc.py b/src/polystar/frame_generators/cv2_frame_generator_abc.py index 6f60339..5a19f94 100644 --- a/src/polystar/frame_generators/cv2_frame_generator_abc.py +++ b/src/polystar/frame_generators/cv2_frame_generator_abc.py @@ -10,29 +10,21 @@ from polystar.models.image import Image @dataclass class CV2FrameGeneratorABC(FrameGeneratorABC, ABC): - - _cap: cv2.VideoCapture = field(init=False, repr=False) - - def __enter__(self): - self._cap = cv2.VideoCapture(*self._capture_params()) - # self._cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) - assert self._cap.isOpened() - self._post_opening_operation() - - def __exit__(self, exc_type, exc_val, exc_tb): - self._cap.release() - def generate(self) -> Iterable[Image]: - with self: - while 1: - is_open, frame = self._cap.read() - if not is_open: - return - yield frame + _cap = self._open() + while 1: + is_open, frame = _cap.read() + if not is_open: + break + yield frame + _cap.release() + + def _open(self) -> cv2.VideoCapture: + _cap = cv2.VideoCapture(*self._capture_params()) + _cap.set(cv2.CAP_PROP_BUFFERSIZE, 0) + assert _cap.isOpened() + return _cap @abstractmethod def _capture_params(self) -> Iterable[Any]: pass - - def _post_opening_operation(self): - pass diff --git a/src/polystar/frame_generators/video_frame_generator.py b/src/polystar/frame_generators/video_frame_generator.py index 22701b5..001eb0f 100644 --- a/src/polystar/frame_generators/video_frame_generator.py +++ b/src/polystar/frame_generators/video_frame_generator.py @@ -1,3 +1,5 @@ +import cv2 + from dataclasses import dataclass from pathlib import Path from typing import Any, Iterable, Optional @@ -18,10 +20,6 @@ class VideoFrameGenerator(CV2FrameGeneratorABC): def _capture_params(self) -> Iterable[Any]: return (str(self.video_path),) - def _post_opening_operation(self): - if self.offset_seconds: - self._cap.set(CAP_PROP_POS_FRAMES, self._video_fps * self.offset_seconds - 2) - @memoized_property def _video_fps(self) -> int: streams_info = ffmpeg.probe(str(self.video_path))["streams"] @@ -30,3 +28,9 @@ class VideoFrameGenerator(CV2FrameGeneratorABC): continue return round(eval(stream_info["avg_frame_rate"])) raise ValueError(f"No fps found for video {self.video_path.name}") + + def _open(self) -> cv2.VideoCapture: + _cap = super()._open() + if self.offset_seconds: + _cap.set(CAP_PROP_POS_FRAMES, self._video_fps * self.offset_seconds - 2) + return _cap -- GitLab