diff --git a/src/polystar/frame_generators/cv2_frame_generator_abc.py b/src/polystar/frame_generators/cv2_frame_generator_abc.py index 6f6033988c5a81a6b0cd742be391cf7ece64e8f5..5a19f94aebdee523980a66c8a6e003c3220915df 100644 --- a/src/polystar/frame_generators/cv2_frame_generator_abc.py +++ b/src/polystar/frame_generators/cv2_frame_generator_abc.py @@ -10,29 +10,21 @@ from polystar.models.image import Image @dataclass class CV2FrameGeneratorABC(FrameGeneratorABC, ABC): - - _cap: cv2.VideoCapture = field(init=False, repr=False) - - def __enter__(self): - self._cap = cv2.VideoCapture(*self._capture_params()) - # self._cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) - assert self._cap.isOpened() - self._post_opening_operation() - - def __exit__(self, exc_type, exc_val, exc_tb): - self._cap.release() - def generate(self) -> Iterable[Image]: - with self: - while 1: - is_open, frame = self._cap.read() - if not is_open: - return - yield frame + _cap = self._open() + while 1: + is_open, frame = _cap.read() + if not is_open: + break + yield frame + _cap.release() + + def _open(self) -> cv2.VideoCapture: + _cap = cv2.VideoCapture(*self._capture_params()) + _cap.set(cv2.CAP_PROP_BUFFERSIZE, 0) + assert _cap.isOpened() + return _cap @abstractmethod def _capture_params(self) -> Iterable[Any]: pass - - def _post_opening_operation(self): - pass diff --git a/src/polystar/frame_generators/video_frame_generator.py b/src/polystar/frame_generators/video_frame_generator.py index 22701b54effe603b8f43d0089924db7929793147..001eb0fe6cb7a5f15d0d0422dbf1f673b51f14cf 100644 --- a/src/polystar/frame_generators/video_frame_generator.py +++ b/src/polystar/frame_generators/video_frame_generator.py @@ -1,3 +1,5 @@ +import cv2 + from dataclasses import dataclass from pathlib import Path from typing import Any, Iterable, Optional @@ -18,10 +20,6 @@ class VideoFrameGenerator(CV2FrameGeneratorABC): def _capture_params(self) -> Iterable[Any]: return (str(self.video_path),) - def _post_opening_operation(self): - if self.offset_seconds: - self._cap.set(CAP_PROP_POS_FRAMES, self._video_fps * self.offset_seconds - 2) - @memoized_property def _video_fps(self) -> int: streams_info = ffmpeg.probe(str(self.video_path))["streams"] @@ -30,3 +28,9 @@ class VideoFrameGenerator(CV2FrameGeneratorABC): continue return round(eval(stream_info["avg_frame_rate"])) raise ValueError(f"No fps found for video {self.video_path.name}") + + def _open(self) -> cv2.VideoCapture: + _cap = super()._open() + if self.offset_seconds: + _cap.set(CAP_PROP_POS_FRAMES, self._video_fps * self.offset_seconds - 2) + return _cap