Commit 88db3bad authored by Mathieu Beligon's avatar Mathieu Beligon
Browse files

[fine-tuning] (script) add first draft of the script to fine tune

parent 8daf9a24
import re
import subprocess
import tarfile
from argparse import ArgumentParser
from enum import Enum
from pathlib import Path
from urllib.request import urlretrieve
ROOT_DIR: Path = Path(__file__).parent
TF_RECORDS_DIR: Path = ROOT_DIR / "dataset"
PRE_TRAINED_MODELS_DIR: Path = ROOT_DIR / "pre-trained-models"
TRAINING_DIR: Path = ROOT_DIR / "training"
CONFIGS_DIR: Path = ROOT_DIR / "models" / "research" / "object_detection" / "samples" / "configs/"
LABEL_MAP_PATH: Path = ROOT_DIR / "label_map.pbtxt"
class Records(Enum):
def __init__(self, train_file: str, val_file: str):
self.train_path = TF_RECORDS_DIR / train_file
self.test_path = TF_RECORDS_DIR / val_file
CENTRAL_CHINA = ("CentralChina_Train.record", "CentralChina_Val.record")
DJI_ROCO = ("DJI_ROCO_Train.record", "DJI_ROCO_Val.record")
SMALL = ("small_train.record", "small_test.record")
class Models(Enum):
SSD_MOBILENET_V2 = (
"ssd_mobilenet_v2_coco_2018_03_29",
"ssd_mobilenet_v2_coco.config",
12,
)
FASTER_RCNN_INCEPTION_V2 = (
"faster_rcnn_inception_v2_coco_2018_01_28",
"faster_rcnn_inception_v2_pets.config",
12,
)
RFCN_RESENET101 = (
"rfcn_resnet101_coco_2018_01_28",
"rfcn_resnet101_pets.config",
8,
)
def __init__(self, file_name: str, config_name: str, batch_size: int):
self.file_name = file_name
self.config_name = config_name
self.batch_size = batch_size
@property
def pre_trained_dir(self) -> Path:
return PRE_TRAINED_MODELS_DIR / self.file_name
@property
def checkpoint_path(self) -> Path:
return self.pre_trained_dir / "model.ckpt"
@property
def config_path(self) -> Path:
return CONFIGS_DIR / self.config_name
@property
def training_dir(self) -> Path:
return TRAINING_DIR / self.file_name
def download(self):
if self.training_dir.exists():
print(f"model {self.file_name} already downloaded")
return
zip_file = f"{self.training_dir}.tar.gz"
# fetch
urlretrieve(
f"http://download.tensorflow.org/models/object_detection/{self.file_name}.tar.gz",
zip_file,
)
# unzip
tar = tarfile.open(zip_file)
tar.extractall(PRE_TRAINED_MODELS_DIR)
tar.close()
def configure(self, record: Records):
config = self.config_path.read_text()
# fine_tune_checkpoint
config = re.sub(
'fine_tune_checkpoint: ".*?"',
f'fine_tune_checkpoint: "{self.checkpoint_path}"',
config,
)
# tfrecord files train and test.
config = re.sub(
'(input_path: ".*?)(train.record)(.*?")',
f'input_path: "{record.train_path}"',
config,
)
config = re.sub(
'(input_path: ".*?)(val.record)(.*?")',
f'input_path: "{record.test_path}"',
config,
)
# label_map_path
config = re.sub(
'label_map_path: ".*?"', f'label_map_path: "{LABEL_MAP_PATH}"', config
)
# Set training batch_size.
config = re.sub("batch_size: [0-9]+", f"batch_size: {self.batch_size}", config)
# Set training steps, num_steps
config = re.sub("num_steps: [0-9]+", f"num_steps: {10_000}", config)
# Set number of classes num_classes.
config = re.sub("num_classes: [0-9]+", f"num_classes: {5}", config)
self.config_path.write_text(config)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("model_name")
parser.add_argument("dset_name")
model = Models.SSD_MOBILENET_V2
model.download()
model.configure(Records.CENTRAL_CHINA)
subprocess.check_output(
[
"poetry",
"run",
"python",
f"{ROOT_DIR}/models/research/object_detection/model_main.py",
f"--pipeline_config_path={model.config_path}",
f"--model_dir={model.training_dir}",
f"--sample_1_of_n_eval_samples=1",
f"--alsologtostderr",
],
stderr=subprocess.STDOUT,
)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment