diff --git a/common/polystar/common/filters/__init__.py b/common/polystar/common/filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/common/polystar/common/filters/filter_abc.py b/common/polystar/common/filters/filter_abc.py index 6f5932e60edbff02f41e45622282ba04cf7a60cc..435b1836979021bd096e81d3062400faa97932cb 100644 --- a/common/polystar/common/filters/filter_abc.py +++ b/common/polystar/common/filters/filter_abc.py @@ -37,6 +37,31 @@ class FilterABC(Generic[T], ABC): def validate_single(self, example: T) -> bool: pass + def __or__(self, other: "FilterABC") -> "FilterABC[T]": + return UnionFilter(self, other) + + def __and__(self, other: "FilterABC") -> "FilterABC[T]": + return IntersectionFilter(self, other) + + +class IntersectionFilter(FilterABC[T]): + def __init__(self, *filters: FilterABC[T]): + self.filters = filters + assert self.filters + + def validate_single(self, example: T) -> bool: + return all(f.validate_single(example) for f in self.filters) + + +class UnionFilter(FilterABC[T]): + def __init__(self, *filters: FilterABC[T]): + print(self, filters) + self.filters = filters + assert self.filters + + def validate_single(self, example: T) -> bool: + return any(f.validate_single(example) for f in self.filters) + def _filter_with_siblings_from_preds( are_valid: List[bool], examples: List[T], *siblings: List, expected_value: bool = True diff --git a/common/polystar/common/filters/intersection_filter.py b/common/polystar/common/filters/intersection_filter.py deleted file mode 100644 index 843c1c69025434f2fda3698e49c23530372e5373..0000000000000000000000000000000000000000 --- a/common/polystar/common/filters/intersection_filter.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import List - -from polystar.common.filters.filter_abc import FilterABC, T - - -class IntersectionFilter(FilterABC[T]): - def __init__(self, filters: List[FilterABC[T]]): - self.filters = filters - assert self.filters - - def validate_single(self, example: T) -> bool: - return all(f.validate_single(example) for f in example) diff --git a/common/polystar/common/filters/union_filter.py b/common/polystar/common/filters/union_filter.py deleted file mode 100644 index df3c626dcc7f2ecab910d815e3d5131c575c8097..0000000000000000000000000000000000000000 --- a/common/polystar/common/filters/union_filter.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import List - -from polystar.common.filters.filter_abc import FilterABC, T - - -class UnionFilter(FilterABC[T]): - def __init__(self, filters: List[FilterABC[T]]): - self.filters = filters - assert self.filters - - def validate_single(self, example: T) -> bool: - return any(f.validate_single(example) for f in example) diff --git a/common/tests/common/unittests/filters/test_filters_abc.py b/common/tests/common/unittests/filters/test_filters_abc.py index 06cb81838b77c3fd5b5367b8b6987aba12623f2c..1888e595bb4b454875cb42603d8be79b681ce27d 100644 --- a/common/tests/common/unittests/filters/test_filters_abc.py +++ b/common/tests/common/unittests/filters/test_filters_abc.py @@ -1,6 +1,7 @@ from unittest import TestCase from polystar.common.filters.filter_abc import FilterABC +from polystar.common.filters.keep_filter import KeepFilter class OddFilter(FilterABC[int]): @@ -61,3 +62,17 @@ class TestFilterABC(TestCase): numbers = [1, 2, 3, 4, 5, 6] self.assertEqual(([False, True, False, True, False, True]), f.validate(numbers)) + + def test_or(self): + f = KeepFilter([2, 3]) | KeepFilter([3, 4]) + + numbers = [1, 2, 3, 4, 5, 6] + + self.assertEqual(([False, True, True, True, False, False]), f.validate(numbers)) + + def test_and(self): + f = KeepFilter([2, 3]) & KeepFilter([3, 4]) + + numbers = [1, 2, 3, 4, 5, 6] + + self.assertEqual(([False, False, True, False, False, False]), f.validate(numbers))