From 9f1c1e9985b5b2326ca905b496bd7942d9bfae09 Mon Sep 17 00:00:00 2001
From: Mathieu Beligon <mathieu@feedly.com>
Date: Sun, 19 Jan 2020 11:18:11 -0500
Subject: [PATCH] [common] (scrits) add a script to split the datasets into
 train/val/test

---
 .../scripts/train_val_test_split.py           | 39 +++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 common/research_common/scripts/train_val_test_split.py

diff --git a/common/research_common/scripts/train_val_test_split.py b/common/research_common/scripts/train_val_test_split.py
new file mode 100644
index 0000000..6e1dc56
--- /dev/null
+++ b/common/research_common/scripts/train_val_test_split.py
@@ -0,0 +1,39 @@
+from pathlib import Path
+from typing import Iterable
+
+from sklearn.model_selection import train_test_split
+
+from research_common.dataset.dataset import Dataset
+from research_common.dataset.roco.roco_datasets import ROCODataset
+from research_common.dataset.split import Split
+
+
+def _check_for_previous_split(dataset: Dataset):
+    for split in Split:
+        error_message = f"split {split.value}.txt already exists. Forbidden to overwrite for results consistency."
+        assert not split.get_split_file(dataset).exists(), error_message
+
+
+def _save_set_split(dataset: Dataset, split: Split, file_paths: Iterable[Path]):
+    split.get_split_file(dataset).write_text("\n".join(file.stem for file in file_paths))
+
+
+def _create_splits(dataset: Dataset):
+    file_paths = list(dataset.image_paths)
+    train_val_files, test_files = train_test_split(file_paths, test_size=0.2, random_state=424242)
+    train_files, val_files = train_test_split(train_val_files, test_size=0.2, random_state=424242)
+
+    _save_set_split(dataset, Split.Test, test_files)
+    _save_set_split(dataset, Split.Train, train_files)
+    _save_set_split(dataset, Split.Val, val_files)
+    _save_set_split(dataset, Split.TrainVal, train_val_files)
+
+
+def train_test_split_set(dataset: Dataset):
+    _check_for_previous_split(dataset)
+    _create_splits(dataset)
+
+
+if __name__ == "__main__":
+    for _roco_set in ROCODataset:
+        train_test_split_set(_roco_set)
-- 
GitLab