Skip to content
Snippets Groups Projects
Commit 07b9cde8 authored by Mathieu Beligon's avatar Mathieu Beligon
Browse files

[common] (TargetPipeline) Add fist tracker

parent 6f6448f0
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ class DetectedObject(Object):
confidence: float
previous_occurrences: Deque["DetectedObject"] = field(init=False, default_factory=deque)
step_of_detection: int = -1
def __str__(self) -> str:
return f"{self.type.name} ({self.confidence:.1%})"
from collections import deque
from typing import Optional
from dataclasses import dataclass
from polystar.common.target_pipeline.detected_objects.detected_object import DetectedObject
@dataclass
class ObjectTrack:
new_object: Optional[DetectedObject]
previous_object: Optional[DetectedObject]
def merge(self) -> DetectedObject:
if not self.previous_object:
return self.new_object
if not self.new_object:
return self.previous_object
self.new_object.previous_occurrences = self.previous_object.previous_occurrences
self.previous_object.previous_occurrences = deque()
self.new_object.previous_occurrences.append(self.previous_object)
from abc import ABC, abstractmethod
from typing import List, Iterable
from dataclasses import dataclass, field
from polystar.common.models.image import Image
from polystar.common.target_pipeline.detected_objects.detected_object import DetectedObject
from polystar.common.target_pipeline.objects_trackers.object_track import ObjectTrack
@dataclass
class ObjectsTrackerABC(ABC):
n_steps_to_track: int
tracked_objects: List[DetectedObject] = field(init=False, default_factory=list)
_step: int = field(init=False, default=0)
def reset(self):
self._step = 0
def track_objects(self, objects: List[DetectedObject], image: Image) -> List[DetectedObject]:
self._set_steps_of_new_objects(objects)
self._loose_too_old_objects()
tracks = self._link_to_previous_objects(objects, image)
self.update_from_tracks(tracks)
return self.tracked_objects
def _set_steps_of_new_objects(self, objects):
self._step += 1
for obj in objects:
obj.step_of_detection = self._step
def _loose_too_old_objects(self):
min_step_required = self._step - self.n_steps_to_track
self.tracked_objects = [obj for obj in self.tracked_objects if obj.step_of_detection >= min_step_required]
for obj in self.tracked_objects:
if obj.previous_occurrences and obj.previous_occurrences[0].step_of_detection < min_step_required:
obj.previous_occurrences.popleft()
@abstractmethod
def _link_to_previous_objects(self, objects: List[DetectedObject], image: Image) -> Iterable[ObjectTrack]:
pass
def update_from_tracks(self, tracks: Iterable[ObjectTrack]):
self.tracked_objects = [track.merge() for track in tracks]
from dataclasses import dataclass
from typing import List
from polystar.common.target_pipeline.detected_objects.detected_object import DetectedObject
from polystar.common.target_pipeline.objects_trackers.objects_tracker_abc import ObjectsTrackerABC
from polystar.common.target_pipeline.target_pipeline import TargetPipeline
@dataclass
class TrackingTargetPipeline(TargetPipeline):
tracker: ObjectsTrackerABC
def _detect_all_objects(self, image) -> List[DetectedObject]:
return self.tracker.track_objects(super()._detect_all_objects(image), image)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment