From 0a03a66b1ec703e7e835b8deab701523bfa9cc5d Mon Sep 17 00:00:00 2001 From: JsPatenaude <jean-sebastien.patenaude@polymtl.ca> Date: Sat, 4 Apr 2020 18:47:33 -0400 Subject: [PATCH] ignore th 6s in dataset --- .../armor_digit/baseline_experiments.py | 37 ++++++++++++------- .../research/dataset/armor_dataset_factory.py | 25 +++++-------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/robots-at-robots/research/armor_digit/baseline_experiments.py b/robots-at-robots/research/armor_digit/baseline_experiments.py index 8768336..92dcc29 100644 --- a/robots-at-robots/research/armor_digit/baseline_experiments.py +++ b/robots-at-robots/research/armor_digit/baseline_experiments.py @@ -7,6 +7,8 @@ from tensorflow.keras.utils import to_categorical from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D +from sklearn.model_selection import train_test_split + IMG_SIZE = 28 DATASET_SPLIT_FACTOR = 5 N_CLASSES = 10 @@ -39,10 +41,13 @@ def load_digits_img_to_data(): # TODO but before, we use this hardcoded path... path = '.\\..\\..\\..\\dataset\\dji_roco\\robomaster_Final Tournament\\digits_found' digits_found = os.listdir(path) - features_train = [] - label_train = [] - features_test = [] - label_test = [] + # features_train = [] + # label_train = [] + # features_test = [] + # label_test = [] + + all_features = [] + all_labels = [] for i, img_file in enumerate(digits_found): img_file_name = os.path.join(path, img_file) @@ -52,18 +57,25 @@ def load_digits_img_to_data(): # split datatset into training and test # TODO use the train_test_split or see https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html - if i % DATASET_SPLIT_FACTOR == 0: - features_test.append(digit_res) - label_test.append(img_file[-5]) - else: - features_train.append(digit_res) - label_train.append(img_file[-5]) + # if i % DATASET_SPLIT_FACTOR == 0: + # features_test.append(digit_res) + # label_test.append(img_file[-5]) + # else: + # features_train.append(digit_res) + # label_train.append(img_file[-5]) + + all_features.append(digit_res) + all_labels.append(img_file[-5]) # cv2.imshow('img', digit_res) # k = cv2.waitKey(0) & 0xFF # if k == 27: # break # cv2.destroyAllWindows() + + features_train, features_test, label_train, label_test = train_test_split( + all_features, all_labels, test_size=0.1) + return np.array(features_train), np.array(label_train), np.array(features_test), np.array(label_test) @@ -112,11 +124,8 @@ def training_model(features_train, label_train, features_test, label_test): if __name__ == "__main__": digits_features_train, digits_label_train, digits_features_test, digits_label_test = load_digits_img_to_data() - # TODO fct to split data into train and test for img in digits_label_train: - # plt.imshow(img, cmap='Greys') - # plt.show() print('train', img) for img in digits_label_test: print('test', img) - training_model(digits_features_train, digits_label_train, digits_features_test, digits_label_test) + # training_model(digits_features_train, digits_label_train, digits_features_test, digits_label_test) diff --git a/robots-at-robots/research/dataset/armor_dataset_factory.py b/robots-at-robots/research/dataset/armor_dataset_factory.py index 5fc5289..60db019 100644 --- a/robots-at-robots/research/dataset/armor_dataset_factory.py +++ b/robots-at-robots/research/dataset/armor_dataset_factory.py @@ -14,7 +14,7 @@ from research_common.dataset.roco_dataset import ROCODataset class ArmorDatasetFactory: @staticmethod def from_image_annotation( - image_annotation: ImageAnnotation, + image_annotation: ImageAnnotation, ) -> Iterable[Tuple[Image, ArmorColor, ArmorNumber, int, Path]]: img = image_annotation.image armors: List[Armor] = TypeObjectValidator(ObjectType.Armor).filter(image_annotation.objects, img) @@ -31,22 +31,15 @@ class ArmorDatasetFactory: def save_digits_img(): for j, (digit_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)): - print(c, n, k, 'in', p) - plt.imshow(digit_img) - plt.savefig(str(p.parents[1]) + '\\digits_found\\digit' + str(j) + '_' + c.name + '_' + str(n) + '.png') - plt.clf() - - if j == 50: - break + # we ignore the 6, because the format of their pictures is not accurate and we don't need them + if n != 6: + if j >= 50: + break + print(c, n, k, 'in', p) + plt.imshow(digit_img) + plt.savefig(str(p.parents[1]) + '\\digits_found\\digit' + str(j) + '_' + c.name + '_' + str(n) + '.png') + plt.clf() if __name__ == "__main__": save_digits_img() - # for i, (armor_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)): - # print(c, n, k, "in", p) - # plt.imshow(armor_img) - # plt.show() - # plt.clf() - # - # if i == 50: - # break -- GitLab