From 0a03a66b1ec703e7e835b8deab701523bfa9cc5d Mon Sep 17 00:00:00 2001
From: JsPatenaude <jean-sebastien.patenaude@polymtl.ca>
Date: Sat, 4 Apr 2020 18:47:33 -0400
Subject: [PATCH] ignore th 6s in dataset

---
 .../armor_digit/baseline_experiments.py       | 37 ++++++++++++-------
 .../research/dataset/armor_dataset_factory.py | 25 +++++--------
 2 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/robots-at-robots/research/armor_digit/baseline_experiments.py b/robots-at-robots/research/armor_digit/baseline_experiments.py
index 8768336..92dcc29 100644
--- a/robots-at-robots/research/armor_digit/baseline_experiments.py
+++ b/robots-at-robots/research/armor_digit/baseline_experiments.py
@@ -7,6 +7,8 @@ from tensorflow.keras.utils import to_categorical
 from tensorflow.keras.models import Sequential
 from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
 
+from sklearn.model_selection import train_test_split
+
 IMG_SIZE = 28
 DATASET_SPLIT_FACTOR = 5
 N_CLASSES = 10
@@ -39,10 +41,13 @@ def load_digits_img_to_data():
     # TODO but before, we use this hardcoded path...
     path = '.\\..\\..\\..\\dataset\\dji_roco\\robomaster_Final Tournament\\digits_found'
     digits_found = os.listdir(path)
-    features_train = []
-    label_train = []
-    features_test = []
-    label_test = []
+    # features_train = []
+    # label_train = []
+    # features_test = []
+    # label_test = []
+
+    all_features = []
+    all_labels = []
 
     for i, img_file in enumerate(digits_found):
         img_file_name = os.path.join(path, img_file)
@@ -52,18 +57,25 @@ def load_digits_img_to_data():
 
         # split datatset into training and test
         # TODO use the train_test_split or see https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
-        if i % DATASET_SPLIT_FACTOR == 0:
-            features_test.append(digit_res)
-            label_test.append(img_file[-5])
-        else:
-            features_train.append(digit_res)
-            label_train.append(img_file[-5])
+        # if i % DATASET_SPLIT_FACTOR == 0:
+        #     features_test.append(digit_res)
+        #     label_test.append(img_file[-5])
+        # else:
+        #     features_train.append(digit_res)
+        #     label_train.append(img_file[-5])
+
+        all_features.append(digit_res)
+        all_labels.append(img_file[-5])
 
     #     cv2.imshow('img', digit_res)
     #     k = cv2.waitKey(0) & 0xFF
     #     if k == 27:
     #         break
     # cv2.destroyAllWindows()
+
+    features_train, features_test, label_train, label_test = train_test_split(
+        all_features, all_labels, test_size=0.1)
+
     return np.array(features_train), np.array(label_train), np.array(features_test), np.array(label_test)
 
 
@@ -112,11 +124,8 @@ def training_model(features_train, label_train, features_test, label_test):
 
 if __name__ == "__main__":
     digits_features_train, digits_label_train, digits_features_test, digits_label_test = load_digits_img_to_data()
-    # TODO fct to split data into train and test
     for img in digits_label_train:
-        # plt.imshow(img, cmap='Greys')
-        # plt.show()
         print('train', img)
     for img in digits_label_test:
         print('test', img)
-    training_model(digits_features_train, digits_label_train, digits_features_test, digits_label_test)
+    # training_model(digits_features_train, digits_label_train, digits_features_test, digits_label_test)
diff --git a/robots-at-robots/research/dataset/armor_dataset_factory.py b/robots-at-robots/research/dataset/armor_dataset_factory.py
index 5fc5289..60db019 100644
--- a/robots-at-robots/research/dataset/armor_dataset_factory.py
+++ b/robots-at-robots/research/dataset/armor_dataset_factory.py
@@ -14,7 +14,7 @@ from research_common.dataset.roco_dataset import ROCODataset
 class ArmorDatasetFactory:
     @staticmethod
     def from_image_annotation(
-            image_annotation: ImageAnnotation,
+        image_annotation: ImageAnnotation,
     ) -> Iterable[Tuple[Image, ArmorColor, ArmorNumber, int, Path]]:
         img = image_annotation.image
         armors: List[Armor] = TypeObjectValidator(ObjectType.Armor).filter(image_annotation.objects, img)
@@ -31,22 +31,15 @@ class ArmorDatasetFactory:
 
 def save_digits_img():
     for j, (digit_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)):
-        print(c, n, k, 'in', p)
-        plt.imshow(digit_img)
-        plt.savefig(str(p.parents[1]) + '\\digits_found\\digit' + str(j) + '_' + c.name + '_' + str(n) + '.png')
-        plt.clf()
-
-        if j == 50:
-            break
+        # we ignore the 6, because the format of their pictures is not accurate and we don't need them
+        if n != 6:
+            if j >= 50:
+                break
+            print(c, n, k, 'in', p)
+            plt.imshow(digit_img)
+            plt.savefig(str(p.parents[1]) + '\\digits_found\\digit' + str(j) + '_' + c.name + '_' + str(n) + '.png')
+            plt.clf()
 
 
 if __name__ == "__main__":
     save_digits_img()
-    # for i, (armor_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)):
-    #     print(c, n, k, "in", p)
-    #     plt.imshow(armor_img)
-    #     plt.show()
-    #     plt.clf()
-    #
-    #     if i == 50:
-    #         break
-- 
GitLab