From 36dcb34255d2a6ed83986809b96cb240625d9ed9 Mon Sep 17 00:00:00 2001 From: JsPatenaude <jean-sebastien.patenaude@polymtl.ca> Date: Sat, 4 Apr 2020 17:17:17 -0400 Subject: [PATCH] starting keras digit recognition --- dataset/dji_roco/.gitignore | 1 + .../armor_digit/baseline_experiments.py | 129 ++++++++++++++++-- .../research/dataset/armor_dataset_factory.py | 28 ++-- 3 files changed, 137 insertions(+), 21 deletions(-) diff --git a/dataset/dji_roco/.gitignore b/dataset/dji_roco/.gitignore index f174da5..837840c 100644 --- a/dataset/dji_roco/.gitignore +++ b/dataset/dji_roco/.gitignore @@ -1,2 +1,3 @@ **/*.xml **/*.jpg +**/*.png diff --git a/robots-at-robots/research/armor_digit/baseline_experiments.py b/robots-at-robots/research/armor_digit/baseline_experiments.py index 9b9cc35..8768336 100644 --- a/robots-at-robots/research/armor_digit/baseline_experiments.py +++ b/robots-at-robots/research/armor_digit/baseline_experiments.py @@ -1,19 +1,122 @@ import logging -from polystar.common.image_pipeline.image_pipeline import ImagePipeline -from polystar.common.image_pipeline.models.random_model import RandomModel -from research.armor_digit.armor_digit_pipeline_reporter_factory import ArmorDigitPipelineReporterFactory -from research_common.dataset.twitch.twitch_roco_datasets import TwitchROCODataset -from research_common.dataset.union_dataset import UnionDataset +import os +import cv2 +import numpy as np +from tensorflow.keras.utils import to_categorical +from tensorflow.keras.models import Sequential +from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D -if __name__ == "__main__": - logging.getLogger().setLevel("INFO") +IMG_SIZE = 28 +DATASET_SPLIT_FACTOR = 5 +N_CLASSES = 10 + + +# from polystar.common.image_pipeline.image_pipeline import ImagePipeline +# from polystar.common.image_pipeline.models.random_model import RandomModel +# from research.armor_digit.armor_digit_pipeline_reporter_factory import ArmorDigitPipelineReporterFactory +# from research_common.dataset.twitch.twitch_roco_datasets import TwitchROCODataset +# from research_common.dataset.union_dataset import UnionDataset + +# if __name__ == "__main__": +# logging.getLogger().setLevel("INFO") +# +# reporter = ArmorDigitPipelineReporterFactory.from_roco_datasets( +# train_roco_dataset=UnionDataset(TwitchROCODataset.TWITCH_470151286, TwitchROCODataset.TWITCH_470150052), +# test_roco_dataset=TwitchROCODataset.TWITCH_470152289, +# ) +# +# random_pipeline = ImagePipeline(model=RandomModel(), custom_name="random") +# +# reporter.report([random_pipeline], evaluation_short_name="baseline") + + +def load_digits_img_to_data(): + # TODO add a model for digits_img in common/polystar/common/models + # TODO add a digits_img in roco_dataset + # TODO create a from_dataset (see armor_dataset_factory) to read digits_img from digits_found folder + # TODO change the way to append digit label to label_train and label_test (with the info from model and dataset) + # TODO but before, we use this hardcoded path... + path = '.\\..\\..\\..\\dataset\\dji_roco\\robomaster_Final Tournament\\digits_found' + digits_found = os.listdir(path) + features_train = [] + label_train = [] + features_test = [] + label_test = [] + + for i, img_file in enumerate(digits_found): + img_file_name = os.path.join(path, img_file) + digit_img = cv2.imread(img_file_name) + digit_res = cv2.resize(digit_img, dsize=(IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC) + digit_res = cv2.cvtColor(digit_res, cv2.COLOR_BGR2GRAY) + + # split datatset into training and test + # TODO use the train_test_split or see https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html + if i % DATASET_SPLIT_FACTOR == 0: + features_test.append(digit_res) + label_test.append(img_file[-5]) + else: + features_train.append(digit_res) + label_train.append(img_file[-5]) - reporter = ArmorDigitPipelineReporterFactory.from_roco_datasets( - train_roco_dataset=UnionDataset(TwitchROCODataset.TWITCH_470151286, TwitchROCODataset.TWITCH_470150052), - test_roco_dataset=TwitchROCODataset.TWITCH_470152289, - ) + # cv2.imshow('img', digit_res) + # k = cv2.waitKey(0) & 0xFF + # if k == 27: + # break + # cv2.destroyAllWindows() + return np.array(features_train), np.array(label_train), np.array(features_test), np.array(label_test) - random_pipeline = ImagePipeline(model=RandomModel(), custom_name="random") - reporter.report([random_pipeline], evaluation_short_name="baseline") +def training_model(features_train, label_train, features_test, label_test): + # Source: https://www.sitepoint.com/keras-digit-recognition-tutorial/ + # Cleaning + features_train = features_train.reshape(features_train.shape[0], IMG_SIZE, IMG_SIZE, 1) + features_test = features_test.reshape(features_test.shape[0], IMG_SIZE, IMG_SIZE, 1) + label_train = to_categorical(label_train, N_CLASSES) + label_test = to_categorical(label_test, N_CLASSES) + + # Design Model + model = Sequential() + model.add(Conv2D(32, kernel_size=(3, 3), + activation='relu', + input_shape=(IMG_SIZE, IMG_SIZE, 1))) + model.add(Conv2D(64, (3, 3), activation='relu')) + model.add(MaxPooling2D(pool_size=(2, 2))) + # TODO maybe dont drop or less if not enough data in dataset + # model.add(Dropout(0.25)) + model.add(Flatten()) + model.add(Dense(128, activation='relu')) + # TODO maybe dont drop or less if not enough data in dataset + # model.add((Dropout(0.5))) + # because we have a number of pre-decided classes + model.add(Dense(N_CLASSES, activation='softmax')) + + # Compile and Train Model + model.compile(loss='categorical_crossentropy', + optimizer='adam', + metrics=['accuracy']) + + batch_size = 128 + epochs = 10 + + model.fit(features_train, label_train, + batch_size=batch_size, + epochs=epochs, + verbose=1, + validation_data=(features_test, label_test)) + score = model.evaluate(features_test, label_test, verbose=0) + print('Test loss:', score[0]) + print('Test accuracy:', score[1]) + model.save("test_model.h5") + + +if __name__ == "__main__": + digits_features_train, digits_label_train, digits_features_test, digits_label_test = load_digits_img_to_data() + # TODO fct to split data into train and test + for img in digits_label_train: + # plt.imshow(img, cmap='Greys') + # plt.show() + print('train', img) + for img in digits_label_test: + print('test', img) + training_model(digits_features_train, digits_label_train, digits_features_test, digits_label_test) diff --git a/robots-at-robots/research/dataset/armor_dataset_factory.py b/robots-at-robots/research/dataset/armor_dataset_factory.py index 7940046..5fc5289 100644 --- a/robots-at-robots/research/dataset/armor_dataset_factory.py +++ b/robots-at-robots/research/dataset/armor_dataset_factory.py @@ -14,12 +14,12 @@ from research_common.dataset.roco_dataset import ROCODataset class ArmorDatasetFactory: @staticmethod def from_image_annotation( - image_annotation: ImageAnnotation, + image_annotation: ImageAnnotation, ) -> Iterable[Tuple[Image, ArmorColor, ArmorNumber, int, Path]]: img = image_annotation.image armors: List[Armor] = TypeObjectValidator(ObjectType.Armor).filter(image_annotation.objects, img) for i, obj in enumerate(armors): - croped_img = img[obj.y : obj.y + obj.h, obj.x : obj.x + obj.w] + croped_img = img[obj.y: obj.y + obj.h, obj.x: obj.x + obj.w] yield croped_img, obj.color, obj.numero, i, image_annotation.image_path @staticmethod @@ -29,12 +29,24 @@ class ArmorDatasetFactory: yield rv -if __name__ == "__main__": - for i, (armor_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.CentralChina)): - print(c, n, k, "in", p) - plt.imshow(armor_img) - plt.show() +def save_digits_img(): + for j, (digit_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)): + print(c, n, k, 'in', p) + plt.imshow(digit_img) + plt.savefig(str(p.parents[1]) + '\\digits_found\\digit' + str(j) + '_' + c.name + '_' + str(n) + '.png') plt.clf() - if i == 50: + if j == 50: break + + +if __name__ == "__main__": + save_digits_img() + # for i, (armor_img, c, n, k, p) in enumerate(ArmorDatasetFactory.from_dataset(DJIROCODataset.Final)): + # print(c, n, k, "in", p) + # plt.imshow(armor_img) + # plt.show() + # plt.clf() + # + # if i == 50: + # break -- GitLab