From 3207414514e6f9668f624bcc2ffbafa7bb2bbe64 Mon Sep 17 00:00:00 2001 From: JsPatenaude <jean-sebastien.patenaude@polymtl.ca> Date: Sat, 4 Apr 2020 18:56:56 -0400 Subject: [PATCH] cleanup split train test --- .../armor_digit/baseline_experiments.py | 45 +++++-------------- .../research/dataset/armor_dataset_factory.py | 2 +- 2 files changed, 13 insertions(+), 34 deletions(-) diff --git a/robots-at-robots/research/armor_digit/baseline_experiments.py b/robots-at-robots/research/armor_digit/baseline_experiments.py index 92dcc29..23d89d1 100644 --- a/robots-at-robots/research/armor_digit/baseline_experiments.py +++ b/robots-at-robots/research/armor_digit/baseline_experiments.py @@ -6,11 +6,10 @@ import numpy as np from tensorflow.keras.utils import to_categorical from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D - from sklearn.model_selection import train_test_split IMG_SIZE = 28 -DATASET_SPLIT_FACTOR = 5 +TEST_SIZE_FACTOR = 0.1 N_CLASSES = 10 @@ -41,11 +40,6 @@ def load_digits_img_to_data(): # TODO but before, we use this hardcoded path... path = '.\\..\\..\\..\\dataset\\dji_roco\\robomaster_Final Tournament\\digits_found' digits_found = os.listdir(path) - # features_train = [] - # label_train = [] - # features_test = [] - # label_test = [] - all_features = [] all_labels = [] @@ -54,38 +48,23 @@ def load_digits_img_to_data(): digit_img = cv2.imread(img_file_name) digit_res = cv2.resize(digit_img, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC) digit_res = cv2.cvtColor(digit_res, cv2.COLOR_BGR2GRAY) - - # split datatset into training and test - # TODO use the train_test_split or see https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html - # if i % DATASET_SPLIT_FACTOR == 0: - # features_test.append(digit_res) - # label_test.append(img_file[-5]) - # else: - # features_train.append(digit_res) - # label_train.append(img_file[-5]) - all_features.append(digit_res) all_labels.append(img_file[-5]) - # cv2.imshow('img', digit_res) - # k = cv2.waitKey(0) & 0xFF - # if k == 27: - # break - # cv2.destroyAllWindows() - - features_train, features_test, label_train, label_test = train_test_split( - all_features, all_labels, test_size=0.1) + # split dataset into training and test + features_train, features_test, labels_train, labels_test = train_test_split( + all_features, all_labels, test_size=TEST_SIZE_FACTOR) - return np.array(features_train), np.array(label_train), np.array(features_test), np.array(label_test) + return np.array(features_train), np.array(labels_train), np.array(features_test), np.array(labels_test) -def training_model(features_train, label_train, features_test, label_test): +def training_model(features_train, labels_train, features_test, labels_test): # Source: https://www.sitepoint.com/keras-digit-recognition-tutorial/ # Cleaning features_train = features_train.reshape(features_train.shape[0], IMG_SIZE, IMG_SIZE, 1) features_test = features_test.reshape(features_test.shape[0], IMG_SIZE, IMG_SIZE, 1) - label_train = to_categorical(label_train, N_CLASSES) - label_test = to_categorical(label_test, N_CLASSES) + label_train = to_categorical(labels_train, N_CLASSES) + label_test = to_categorical(labels_test, N_CLASSES) # Design Model model = Sequential() @@ -123,9 +102,9 @@ def training_model(features_train, label_train, features_test, label_test): if __name__ == "__main__": - digits_features_train, digits_label_train, digits_features_test, digits_label_test = load_digits_img_to_data() - for img in digits_label_train: + digits_features_train, digits_labels_train, digits_features_test, digits_labels_test = load_digits_img_to_data() + for img in digits_labels_train: print('train', img) - for img in digits_label_test: + for img in digits_labels_test: print('test', img) - # training_model(digits_features_train, digits_label_train, digits_features_test, digits_label_test) + # training_model(digits_features_train, digits_labels_train, digits_features_test, digits_labels_test) diff --git a/robots-at-robots/research/dataset/armor_dataset_factory.py b/robots-at-robots/research/dataset/armor_dataset_factory.py index 60db019..8d85045 100644 --- a/robots-at-robots/research/dataset/armor_dataset_factory.py +++ b/robots-at-robots/research/dataset/armor_dataset_factory.py @@ -31,7 +31,7 @@ class ArmorDatasetFactory: def save_digits_img(): for j, (digit_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)): - # we ignore the 6, because the format of their pictures is not accurate and we don't need them + # we ignore the 6s, because the format of their pictures is not accurate and we don't need them if n != 6: if j >= 50: break -- GitLab