Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
PolySTAR
RoboMaster
Computer Vision
Tensorflow Object Detection finetuning
Commits
88db3bad
Commit
88db3bad
authored
Mar 08, 2020
by
Mathieu Beligon
Browse files
[fine-tuning] (script) add first draft of the script to fine tune
parent
8daf9a24
Changes
1
Hide whitespace changes
Inline
Side-by-side
fine_tune_model.py
0 → 100644
View file @
88db3bad
import
re
import
subprocess
import
tarfile
from
argparse
import
ArgumentParser
from
enum
import
Enum
from
pathlib
import
Path
from
urllib.request
import
urlretrieve
ROOT_DIR
:
Path
=
Path
(
__file__
).
parent
TF_RECORDS_DIR
:
Path
=
ROOT_DIR
/
"dataset"
PRE_TRAINED_MODELS_DIR
:
Path
=
ROOT_DIR
/
"pre-trained-models"
TRAINING_DIR
:
Path
=
ROOT_DIR
/
"training"
CONFIGS_DIR
:
Path
=
ROOT_DIR
/
"models"
/
"research"
/
"object_detection"
/
"samples"
/
"configs/"
LABEL_MAP_PATH
:
Path
=
ROOT_DIR
/
"label_map.pbtxt"
class
Records
(
Enum
):
def
__init__
(
self
,
train_file
:
str
,
val_file
:
str
):
self
.
train_path
=
TF_RECORDS_DIR
/
train_file
self
.
test_path
=
TF_RECORDS_DIR
/
val_file
CENTRAL_CHINA
=
(
"CentralChina_Train.record"
,
"CentralChina_Val.record"
)
DJI_ROCO
=
(
"DJI_ROCO_Train.record"
,
"DJI_ROCO_Val.record"
)
SMALL
=
(
"small_train.record"
,
"small_test.record"
)
class
Models
(
Enum
):
SSD_MOBILENET_V2
=
(
"ssd_mobilenet_v2_coco_2018_03_29"
,
"ssd_mobilenet_v2_coco.config"
,
12
,
)
FASTER_RCNN_INCEPTION_V2
=
(
"faster_rcnn_inception_v2_coco_2018_01_28"
,
"faster_rcnn_inception_v2_pets.config"
,
12
,
)
RFCN_RESENET101
=
(
"rfcn_resnet101_coco_2018_01_28"
,
"rfcn_resnet101_pets.config"
,
8
,
)
def
__init__
(
self
,
file_name
:
str
,
config_name
:
str
,
batch_size
:
int
):
self
.
file_name
=
file_name
self
.
config_name
=
config_name
self
.
batch_size
=
batch_size
@
property
def
pre_trained_dir
(
self
)
->
Path
:
return
PRE_TRAINED_MODELS_DIR
/
self
.
file_name
@
property
def
checkpoint_path
(
self
)
->
Path
:
return
self
.
pre_trained_dir
/
"model.ckpt"
@
property
def
config_path
(
self
)
->
Path
:
return
CONFIGS_DIR
/
self
.
config_name
@
property
def
training_dir
(
self
)
->
Path
:
return
TRAINING_DIR
/
self
.
file_name
def
download
(
self
):
if
self
.
training_dir
.
exists
():
print
(
f
"model
{
self
.
file_name
}
already downloaded"
)
return
zip_file
=
f
"
{
self
.
training_dir
}
.tar.gz"
# fetch
urlretrieve
(
f
"http://download.tensorflow.org/models/object_detection/
{
self
.
file_name
}
.tar.gz"
,
zip_file
,
)
# unzip
tar
=
tarfile
.
open
(
zip_file
)
tar
.
extractall
(
PRE_TRAINED_MODELS_DIR
)
tar
.
close
()
def
configure
(
self
,
record
:
Records
):
config
=
self
.
config_path
.
read_text
()
# fine_tune_checkpoint
config
=
re
.
sub
(
'fine_tune_checkpoint: ".*?"'
,
f
'fine_tune_checkpoint: "
{
self
.
checkpoint_path
}
"'
,
config
,
)
# tfrecord files train and test.
config
=
re
.
sub
(
'(input_path: ".*?)(train.record)(.*?")'
,
f
'input_path: "
{
record
.
train_path
}
"'
,
config
,
)
config
=
re
.
sub
(
'(input_path: ".*?)(val.record)(.*?")'
,
f
'input_path: "
{
record
.
test_path
}
"'
,
config
,
)
# label_map_path
config
=
re
.
sub
(
'label_map_path: ".*?"'
,
f
'label_map_path: "
{
LABEL_MAP_PATH
}
"'
,
config
)
# Set training batch_size.
config
=
re
.
sub
(
"batch_size: [0-9]+"
,
f
"batch_size:
{
self
.
batch_size
}
"
,
config
)
# Set training steps, num_steps
config
=
re
.
sub
(
"num_steps: [0-9]+"
,
f
"num_steps:
{
10_000
}
"
,
config
)
# Set number of classes num_classes.
config
=
re
.
sub
(
"num_classes: [0-9]+"
,
f
"num_classes:
{
5
}
"
,
config
)
self
.
config_path
.
write_text
(
config
)
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
()
parser
.
add_argument
(
"model_name"
)
parser
.
add_argument
(
"dset_name"
)
model
=
Models
.
SSD_MOBILENET_V2
model
.
download
()
model
.
configure
(
Records
.
CENTRAL_CHINA
)
subprocess
.
check_output
(
[
"poetry"
,
"run"
,
"python"
,
f
"
{
ROOT_DIR
}
/models/research/object_detection/model_main.py"
,
f
"--pipeline_config_path=
{
model
.
config_path
}
"
,
f
"--model_dir=
{
model
.
training_dir
}
"
,
f
"--sample_1_of_n_eval_samples=1"
,
f
"--alsologtostderr"
,
],
stderr
=
subprocess
.
STDOUT
,
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment