From 5de52c6436520508f22327079ed47d9a91d83342 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Mon, 9 Mar 2020 21:22:10 -0400
Subject: [PATCH] [common] (frame generators) FPSVideoFrameGenerator: add
 VideoFrameGenerator and CV2FrameGeneratorABC as base classes

---
 .../cv2_frame_generator_abc.py                | 33 +++++++++++++++++++
 .../fps_video_frame_generator.py              | 33 ++++++++-----------
 .../frame_generators/video_frame_generator.py | 14 ++++++++
 3 files changed, 61 insertions(+), 19 deletions(-)
 create mode 100644 common/polystar/common/frame_generators/cv2_frame_generator_abc.py
 create mode 100644 common/polystar/common/frame_generators/video_frame_generator.py

diff --git a/common/polystar/common/frame_generators/cv2_frame_generator_abc.py b/common/polystar/common/frame_generators/cv2_frame_generator_abc.py
new file mode 100644
index 0000000..89ae7b2
--- /dev/null
+++ b/common/polystar/common/frame_generators/cv2_frame_generator_abc.py
@@ -0,0 +1,33 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Iterable
+
+import cv2
+from cv2.cv2 import VideoCapture
+
+from polystar.common.frame_generators.frames_generator_abc import FrameGeneratorABC
+from polystar.common.models.image import Image
+
+
+@dataclass
+class CV2FrameGeneratorABC(FrameGeneratorABC, ABC):
+
+    _cap: VideoCapture = field(init=False, repr=False)
+
+    def __enter__(self):
+        self._cap = cv2.VideoCapture(*self._capture_params())
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._cap.release()
+
+    def generate(self) -> Iterable[Image]:
+        with self:
+            while 1:
+                is_open, frame = self._cap.read()
+                if not is_open:
+                    return
+                yield frame
+
+    @abstractmethod
+    def _capture_params(self) -> Iterable[Any]:
+        pass
diff --git a/common/polystar/common/frame_generators/fps_video_frame_generator.py b/common/polystar/common/frame_generators/fps_video_frame_generator.py
index 280220c..ade93a8 100644
--- a/common/polystar/common/frame_generators/fps_video_frame_generator.py
+++ b/common/polystar/common/frame_generators/fps_video_frame_generator.py
@@ -1,31 +1,26 @@
-from pathlib import Path
+from dataclasses import dataclass
+from typing import Iterable
 
-import cv2
 import ffmpeg
 
-from polystar.common.frame_generators.frames_generator_abc import FrameGeneratorABC
+from polystar.common.frame_generators.video_frame_generator import VideoFrameGenerator
+from polystar.common.models.image import Image
 
 
-class FPSVideoFrameGenerator(FrameGeneratorABC):
-    def __init__(self, video_path: Path, desired_fps: int):
-        self.video_path: Path = video_path
-        self.desired_fps: int = desired_fps
-        self.video_fps: int = self._get_video_fps()
+@dataclass
+class FPSVideoFrameGenerator(VideoFrameGenerator):
+
+    desired_fps: int
+
+    def __post_init__(self):
+        self.frame_rate: int = self._get_video_fps() // self.desired_fps
 
     def _get_video_fps(self):
         return max(
             int(stream["r_frame_rate"].split("/")[0]) for stream in ffmpeg.probe(str(self.video_path))["streams"]
         )
 
-    def generate(self):
-        video = cv2.VideoCapture(str(self.video_path))
-        frame_rate = self.video_fps // self.desired_fps
-        count = 0
-        while 1:
-            is_unfinished, frame = video.read()
-            if not is_unfinished:
-                video.release()
-                return
-            if not count % frame_rate:
+    def generate(self) -> Iterable[Image]:
+        for i, frame in enumerate(super().generate()):
+            if not i % self.frame_rate:
                 yield frame
-            count += 1
diff --git a/common/polystar/common/frame_generators/video_frame_generator.py b/common/polystar/common/frame_generators/video_frame_generator.py
new file mode 100644
index 0000000..388dcc2
--- /dev/null
+++ b/common/polystar/common/frame_generators/video_frame_generator.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable, Any
+
+from polystar.common.frame_generators.cv2_frame_generator_abc import CV2FrameGeneratorABC
+
+
+@dataclass
+class VideoFrameGenerator(CV2FrameGeneratorABC):
+
+    video_path: Path
+
+    def _capture_params(self) -> Iterable[Any]:
+        return (str(self.video_path),)
-- 
GitLab