From 3207414514e6f9668f624bcc2ffbafa7bb2bbe64 Mon Sep 17 00:00:00 2001
From: JsPatenaude <jean-sebastien.patenaude@polymtl.ca>
Date: Sat, 4 Apr 2020 18:56:56 -0400
Subject: [PATCH] Clean up train/test split

---
 .../armor_digit/baseline_experiments.py       | 45 +++++--------------
 .../research/dataset/armor_dataset_factory.py |  2 +-
 2 files changed, 13 insertions(+), 34 deletions(-)

diff --git a/robots-at-robots/research/armor_digit/baseline_experiments.py b/robots-at-robots/research/armor_digit/baseline_experiments.py
index 92dcc29..23d89d1 100644
--- a/robots-at-robots/research/armor_digit/baseline_experiments.py
+++ b/robots-at-robots/research/armor_digit/baseline_experiments.py
@@ -6,11 +6,10 @@ import numpy as np
 from tensorflow.keras.utils import to_categorical
 from tensorflow.keras.models import Sequential
 from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
-
 from sklearn.model_selection import train_test_split
 
 IMG_SIZE = 28
-DATASET_SPLIT_FACTOR = 5
+TEST_SIZE_FACTOR = 0.1
 N_CLASSES = 10
 
 
@@ -41,11 +40,6 @@ def load_digits_img_to_data():
     # TODO but before, we use this hardcoded path...
     path = '.\\..\\..\\..\\dataset\\dji_roco\\robomaster_Final Tournament\\digits_found'
     digits_found = os.listdir(path)
-    # features_train = []
-    # label_train = []
-    # features_test = []
-    # label_test = []
-
     all_features = []
     all_labels = []
 
@@ -54,38 +48,23 @@ def load_digits_img_to_data():
         digit_img = cv2.imread(img_file_name)
         digit_res = cv2.resize(digit_img, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)
         digit_res = cv2.cvtColor(digit_res, cv2.COLOR_BGR2GRAY)
-
-        # split datatset into training and test
-        # TODO use the train_test_split or see https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
-        # if i % DATASET_SPLIT_FACTOR == 0:
-        #     features_test.append(digit_res)
-        #     label_test.append(img_file[-5])
-        # else:
-        #     features_train.append(digit_res)
-        #     label_train.append(img_file[-5])
-
         all_features.append(digit_res)
         all_labels.append(img_file[-5])
 
-    #     cv2.imshow('img', digit_res)
-    #     k = cv2.waitKey(0) & 0xFF
-    #     if k == 27:
-    #         break
-    # cv2.destroyAllWindows()
-
-    features_train, features_test, label_train, label_test = train_test_split(
-        all_features, all_labels, test_size=0.1)
+    # Split the dataset into training and test sets
+    features_train, features_test, labels_train, labels_test = train_test_split(
+        all_features, all_labels, test_size=TEST_SIZE_FACTOR)
 
-    return np.array(features_train), np.array(label_train), np.array(features_test), np.array(label_test)
+    return np.array(features_train), np.array(labels_train), np.array(features_test), np.array(labels_test)
 
 
-def training_model(features_train, label_train, features_test, label_test):
+def training_model(features_train, labels_train, features_test, labels_test):
     # Source: https://www.sitepoint.com/keras-digit-recognition-tutorial/
     # Cleaning
     features_train = features_train.reshape(features_train.shape[0], IMG_SIZE, IMG_SIZE, 1)
     features_test = features_test.reshape(features_test.shape[0], IMG_SIZE, IMG_SIZE, 1)
-    label_train = to_categorical(label_train, N_CLASSES)
-    label_test = to_categorical(label_test, N_CLASSES)
+    label_train = to_categorical(labels_train, N_CLASSES)
+    label_test = to_categorical(labels_test, N_CLASSES)
 
     # Design Model
     model = Sequential()
@@ -123,9 +102,9 @@ def training_model(features_train, label_train, features_test, label_test):
 
 
 if __name__ == "__main__":
-    digits_features_train, digits_label_train, digits_features_test, digits_label_test = load_digits_img_to_data()
-    for img in digits_label_train:
+    digits_features_train, digits_labels_train, digits_features_test, digits_labels_test = load_digits_img_to_data()
+    for img in digits_labels_train:
         print('train', img)
-    for img in digits_label_test:
+    for img in digits_labels_test:
         print('test', img)
-    # training_model(digits_features_train, digits_label_train, digits_features_test, digits_label_test)
+    # training_model(digits_features_train, digits_labels_train, digits_features_test, digits_labels_test)
diff --git a/robots-at-robots/research/dataset/armor_dataset_factory.py b/robots-at-robots/research/dataset/armor_dataset_factory.py
index 60db019..8d85045 100644
--- a/robots-at-robots/research/dataset/armor_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_dataset_factory.py
@@ -31,7 +31,7 @@ class ArmorDatasetFactory:
 
 def save_digits_img():
     for j, (digit_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)):
-        # we ignore the 6, because the format of their pictures is not accurate and we don't need them
+        # We ignore the 6s: the format of their pictures is not accurate and we don't need them
         if n != 6:
             if j >= 50:
                 break
-- 
GitLab