diff --git a/src/research/common/datasets/dataset_builder.py b/src/research/common/datasets/dataset_builder.py
index 68b595a61a424c3ece0d6adae108100f3143b19b..d96cebe01ade94da1e0210ddb81ad924e2453b34 100644
--- a/src/research/common/datasets/dataset_builder.py
+++ b/src/research/common/datasets/dataset_builder.py
@@ -9,6 +9,7 @@ from research.common.datasets.lazy_dataset import ExampleT, LazyDataset, TargetT
 from research.common.datasets.shuffle_dataset import ShuffleDataset
 from research.common.datasets.slice_dataset import SliceDataset
 from research.common.datasets.transform_dataset import TransformDataset
+from research.common.datasets.union_dataset import UnionLazyDataset
 
 
 class DatasetBuilder(Generic[ExampleT, TargetT], Iterable[Tuple[ExampleT, TargetT, str]]):
@@ -71,6 +72,23 @@ class DatasetBuilder(Generic[ExampleT, TargetT], Iterable[Tuple[ExampleT, Target
         self.dataset = SliceDataset(self.dataset, start=n)
         return self
 
+    def __or__(
+        self, other: "DatasetBuilder[ExampleT, TargetT]"
+    ) -> "DatasetBuilder[ExampleT, TargetT]":
+        """Union this builder's dataset with ``other``'s via ``UnionLazyDataset``.
+
+        Like the neighbouring builder steps, this replaces ``self.dataset`` in
+        place and returns ``self``, so ``a | b`` modifies ``a``.
+
+        Returns ``NotImplemented`` when ``other`` is not a ``DatasetBuilder`` so
+        that Python raises ``TypeError`` (after trying any reflected operator)
+        instead of failing with ``AttributeError`` on ``other.dataset``.
+        """
+        if not isinstance(other, DatasetBuilder):
+            return NotImplemented
+        self.dataset = UnionLazyDataset((self.dataset, other.dataset))
+        return self
+
     @property
     def name(self) -> str:
         return self.dataset.name