Skip to content
Snippets Groups Projects
Commit 4dde1c37 authored by Mathieu Beligon's avatar Mathieu Beligon
Browse files

[common] (ArmorsColorDescriptor) Add class to add the color of the armor amoung predictions

parent c6f67dfc
No related branches found
No related tags found
No related merge requests found
from dataclasses import dataclass
from typing import List, Tuple
from typing import Any, List, Tuple
import numpy as np
from polystar.common.image_pipeline.models.absolute_classifier_model_abc import AbsoluteClassifierModelABC
......@@ -11,6 +13,13 @@ class RedBlueComparisonModel(AbsoluteClassifierModelABC):
red_channel_id: int = 0
blue_channel_id: int = 2
def __post_init__(self):
self.labels_ = np.asarray(sorted(["Red", "Grey", "Blue"]))
self.label2index_ = {label: i for i, label in enumerate(self.labels_)}
def fit(self, features: List[Any], labels: List[Any]) -> "RedBlueComparisonModel":
return self
def predict(self, features: List[Tuple[float, float, float]]) -> List[str]:
return [
"Red" if feature[self.red_channel_id] >= feature[self.blue_channel_id] else "Blue" for feature in features
......
from dataclasses import dataclass
from enum import auto
from typing import Any, Dict, NewType
from dataclasses import dataclass
from polystar.common.models.box import Box
from polystar.common.utils.no_case_enum import NoCaseEnum
......@@ -18,6 +17,9 @@ class ArmorColor(NoCaseEnum):
Unknown = auto()
ORDERED_ARMOR_COLORS = [ArmorColor.Blue, ArmorColor.Grey, ArmorColor.Red]
class ObjectType(NoCaseEnum):
Car = auto()
Watcher = auto()
......
from dataclasses import dataclass
from typing import List
from polystar.common.image_pipeline.classifier_image_pipeline import ClassifierImagePipeline
from polystar.common.models.image import Image
from polystar.common.target_pipeline.armors_descriptors.armors_descriptor_abc import ArmorsDescriptorABC
from polystar.common.target_pipeline.detected_objects.detected_armor import DetectedArmor
@dataclass
class ArmorsColorDescriptor(ArmorsDescriptorABC):
image_pipeline: ClassifierImagePipeline
def _describe_armors_from_images(self, armors_images: List[Image], armors: List[DetectedArmor]):
colors_predictions = self.image_pipeline.predict_proba(armors_images)
for colors_proba, armor in zip(colors_predictions, armors):
armor.colors_proba = colors_proba
from dataclasses import dataclass, field
from typing import List
import numpy as np
from numpy import argmax
from polystar.common.models.object import ArmorColor, ObjectType
from polystar.common.models.object import ORDERED_ARMOR_COLORS, ArmorColor, ObjectType
from polystar.common.target_pipeline.detected_objects.detected_object import DetectedObject
......@@ -12,8 +12,8 @@ class DetectedArmor(DetectedObject):
def __post_init__(self):
assert self.type == ObjectType.Armor
colors_proba: List[float] = field(init=False, default=None)
numbers_proba: List[float] = field(init=False, default=None)
colors_proba: np.ndarray = field(init=False, default=None)
numbers_proba: np.ndarray = field(init=False, default=None)
_color: ArmorColor = field(init=False, default=None)
_number: int = field(init=False, default=None)
......@@ -23,8 +23,8 @@ class DetectedArmor(DetectedObject):
if self._color is not None:
return self._color
if self.colors_proba:
self._color = max(zip(self.colors_proba, ArmorColor))[1]
if self.colors_proba is not None:
self._color = ORDERED_ARMOR_COLORS[self.colors_proba.argmax()]
return self._color
return ArmorColor.Unknown
......@@ -35,6 +35,7 @@ class DetectedArmor(DetectedObject):
return self._number
if self.numbers_proba:
# FIXME: We skip some of the numbers at training...
self._number = 1 + argmax(self.colors_proba)
return self._number
......
import cv2
from polystar.common.communication.print_target_sender import PrintTargetSender
from polystar.common.image_pipeline.classifier_image_pipeline import ClassifierImagePipeline
from polystar.common.image_pipeline.image_featurizer.mean_rgb_channels_featurizer import MeanChannelsFeaturizer
from polystar.common.image_pipeline.models.red_blue_channels_comparison_model import RedBlueComparisonModel
from polystar.common.models.camera import Camera
from polystar.common.models.label_map import LabelMap
from polystar.common.target_pipeline.armors_descriptors.armors_color_descriptor import ArmorsColorDescriptor
from polystar.common.target_pipeline.debug_pipeline import DebugTargetPipeline
from polystar.common.target_pipeline.object_selectors.closest_object_selector import ClosestObjectSelector
from polystar.common.target_pipeline.objects_detectors.tf_model_objects_detector import TFModelObjectsDetector
......@@ -25,7 +29,11 @@ if __name__ == "__main__":
pipeline = DebugTargetPipeline(
objects_detector=TFModelObjectsDetector(load_tf_model(), injector.get(LabelMap)),
armors_descriptors=[],
armors_descriptors=[
ArmorsColorDescriptor(
ClassifierImagePipeline(image_featurizer=MeanChannelsFeaturizer(), model=RedBlueComparisonModel())
)
],
objects_validators=[ConfidenceObjectValidator(0.6)],
object_selector=ClosestObjectSelector(),
target_factory=RatioSimpleTargetFactory(injector.get(Camera), 300, 100),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment