Skip to content

Commit

Permalink
Decoupling model creation from training script. Use of threads for da…
Browse files Browse the repository at this point in the history
…ta generation. Small changes in the run script.
  • Loading branch information
sergiomsilva committed Jan 15, 2019
1 parent e637dd0 commit d6eff71
Show file tree
Hide file tree
Showing 8 changed files with 377 additions and 141 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Use the script "run.sh" to run our ALPR approach. It requires 3 arguments:
* __CSV file:__ specify an output CSV file.

```shellscript
$ bash run.sh samples/test /tmp/output /tmp/output/results.csv
$ bash run.sh -i samples/test -o /tmp/output -c /tmp/output/results.csv
```

## Training the LP detector
Expand All @@ -40,10 +40,12 @@ To train the LP detector network from scratch, or fine-tuning it for new samples
The following command can be used to train the network from scratch considering the data inside the train-detector folder:

```shellscript
$ python train-detector.py --name new-network --outdir /tmp/ --input-dir samples/train-detector
$ mkdir models
$ python create-model.py eccv models/eccv-model-scracth
$ python train-detector.py --model models/eccv-model-scracth --name my-trained-model --train-dir samples/train-detector --output-dir models/my-trained-model/ -op Adam -lr .001 -its 300000 -bs 64
```

For fine-tunning, add "-m data/lp-detector/wpod-net" to the command line above.
For fine-tunning, use your model with --model option.

## A word on GPU and CPU

Expand Down
64 changes: 59 additions & 5 deletions src/detector_network.py → create-model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@

import sys
import keras

from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Add, Activation, Concatenate, Input
from keras.models import Model
from keras.applications.mobilenet import MobileNet

from src.keras_utils import save_model


def res_block(x,sz,filter_sz=3,in_conv_size=1):
Expand All @@ -21,7 +27,13 @@ def conv_batch(_input,fsz,csz,activation='relu',padding='same',strides=(1,1)):
output = Activation(activation)(output)
return output

def create_model():
def end_block(x):
xprobs = Conv2D(2, 3, activation='softmax', padding='same')(x)
xbbox = Conv2D(6, 3, activation='linear' , padding='same')(x)
return Concatenate(3)([xprobs,xbbox])


def create_model_eccv():

input_layer = Input(shape=(None,None,3),name='input')

Expand All @@ -45,9 +57,51 @@ def create_model():
x = res_block(x,128)
x = res_block(x,128)

xprobs = Conv2D(2, 3, activation='softmax', padding='same')(x)
xbbox = Conv2D(6, 3, activation='linear' , padding='same')(x)
x = end_block(x)

return Model(inputs=input_layer,outputs=x)


# Model not converging...
def create_model_mobnet():

input_layer = Input(shape=(None,None,3),name='input')
x = input_layer

mbnet = MobileNet(input_shape=(224,224,3),include_top=True)

backbone = keras.models.clone_model(mbnet)
for i,bblayer in enumerate(backbone.layers[1:74]):
layer = bblayer.__class__.from_config(bblayer.get_config())
layer.name = 'backbone_' + layer.name
x = layer(x)

x = end_block(x)

model = Model(inputs=input_layer,outputs=x)

backbone_layers = {'backbone_' + layer.name: layer for layer in backbone.layers}
for layer in model.layers:
if layer.name in backbone_layers:
print 'setting ' + layer.name
layer.set_weights(backbone_layers[layer.name].get_weights())

return model


if __name__ == '__main__':

modules = [func.replace('create_model_','') for func in dir(sys.modules[__name__]) if 'create_model_' in func]

assert sys.argv[1] in modules, \
'Model name must be on of the following: %s' % ', '.join(modules)

modelf = getattr(sys.modules[__name__],'create_model_' + sys.argv[1])

print 'Creating model %s' % sys.argv[1]
model = modelf()
print 'Finished'

x = Concatenate(3)([xprobs,xbbox])
print 'Saving at %s' % sys.argv[2]
save_model(model,sys.argv[2])

return Model(input=input_layer,output=x)
58 changes: 34 additions & 24 deletions license-plate-detection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys, os
import keras
import cv2
import traceback

from src.keras_utils import load_model
from glob import glob
Expand All @@ -15,40 +16,49 @@ def adjust_pts(pts,lroi):


if __name__ == '__main__':

input_dir = sys.argv[1]
output_dir = input_dir

lp_threshold = .5
try:

input_dir = sys.argv[1]
output_dir = input_dir

wpod_net_path = sys.argv[2]
wpod_net = load_model(wpod_net_path)
lp_threshold = .5

imgs_paths = glob('%s/*car.png' % input_dir)
wpod_net_path = sys.argv[2]
wpod_net = load_model(wpod_net_path)

print 'Searching for license plates using WPOD-NET'
imgs_paths = glob('%s/*car.png' % input_dir)

for i,img_path in enumerate(imgs_paths):
print 'Searching for license plates using WPOD-NET'

print '\t Processing %s' % img_path
for i,img_path in enumerate(imgs_paths):

bname = splitext(basename(img_path))[0]
Ivehicle = cv2.imread(img_path)
print '\t Processing %s' % img_path

ratio = float(max(Ivehicle.shape[:2]))/min(Ivehicle.shape[:2])
side = int(ratio*288.)
bound_dim = min(side + (side%(2**4)),608)
print "\t\tBound dim: %d, ratio: %f" % (bound_dim,ratio)
bname = splitext(basename(img_path))[0]
Ivehicle = cv2.imread(img_path)

Llp,LlpImgs,_ = detect_lp(wpod_net,im2single(Ivehicle),bound_dim,2**4,(240,80),lp_threshold)
ratio = float(max(Ivehicle.shape[:2]))/min(Ivehicle.shape[:2])
side = int(ratio*288.)
bound_dim = min(side + (side%(2**4)),608)
print "\t\tBound dim: %d, ratio: %f" % (bound_dim,ratio)

if len(LlpImgs):
Ilp = LlpImgs[0]
Ilp = cv2.cvtColor(Ilp, cv2.COLOR_BGR2GRAY)
Ilp = cv2.cvtColor(Ilp, cv2.COLOR_GRAY2BGR)
Llp,LlpImgs,_ = detect_lp(wpod_net,im2single(Ivehicle),bound_dim,2**4,(240,80),lp_threshold)

s = Shape(Llp[0].pts)
if len(LlpImgs):
Ilp = LlpImgs[0]
Ilp = cv2.cvtColor(Ilp, cv2.COLOR_BGR2GRAY)
Ilp = cv2.cvtColor(Ilp, cv2.COLOR_GRAY2BGR)

s = Shape(Llp[0].pts)

cv2.imwrite('%s/%s_lp.png' % (output_dir,bname),Ilp*255.)
writeShapes('%s/%s_lp.txt' % (output_dir,bname),[s])

except:
traceback.print_exc()
sys.exit(1)

sys.exit(0)

cv2.imwrite('%s/%s_lp.png' % (output_dir,bname),Ilp*255.)
writeShapes('%s/%s_lp.txt' % (output_dir,bname),[s])

57 changes: 33 additions & 24 deletions license-plate-ocr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import cv2
import numpy as np
import traceback

import darknet.python.darknet as dn

Expand All @@ -12,44 +13,52 @@


if __name__ == '__main__':

try:

input_dir = sys.argv[1]
output_dir = input_dir
input_dir = sys.argv[1]
output_dir = input_dir

ocr_threshold = .4

ocr_weights = 'data/ocr/ocr-net.weights'
ocr_netcfg = 'data/ocr/ocr-net.cfg'
ocr_dataset = 'data/ocr/ocr-net.data'

ocr_threshold = .4
ocr_net = dn.load_net(ocr_netcfg, ocr_weights, 0)
ocr_meta = dn.load_meta(ocr_dataset)

ocr_weights = 'data/ocr/ocr-net.weights'
ocr_netcfg = 'data/ocr/ocr-net.cfg'
ocr_dataset = 'data/ocr/ocr-net.data'
imgs_paths = sorted(glob('%s/*lp.png' % output_dir))

ocr_net = dn.load_net(ocr_netcfg, ocr_weights, 0)
ocr_meta = dn.load_meta(ocr_dataset)
print 'Performing OCR...'

imgs_paths = sorted(glob('%s/*lp.png' % output_dir))
for i,img_path in enumerate(imgs_paths):

print 'Performing OCR...'
print '\tScanning %s' % img_path

for i,img_path in enumerate(imgs_paths):
bname = basename(splitext(img_path)[0])

print '\tScanning %s' % img_path
R,(width,height) = detect(ocr_net, ocr_meta, img_path ,thresh=ocr_threshold, nms=None)

bname = basename(splitext(img_path)[0])
if len(R):

R,(width,height) = detect(ocr_net, ocr_meta, img_path ,thresh=ocr_threshold, nms=None)
L = dknet_label_conversion(R,width,height)
L = nms(L,.45)

if len(R):
L.sort(key=lambda x: x.tl()[0])
lp_str = ''.join([chr(l.cl()) for l in L])

L = dknet_label_conversion(R,width,height)
L = nms(L,.45)
with open('%s/%s_str.txt' % (output_dir,bname),'w') as f:
f.write(lp_str + '\n')

L.sort(key=lambda x: x.tl()[0])
lp_str = ''.join([chr(l.cl()) for l in L])
print '\t\tLP: %s' % lp_str

with open('%s/%s_str.txt' % (output_dir,bname),'w') as f:
f.write(lp_str + '\n')
else:

print '\t\tLP: %s' % lp_str
print 'No characters found'

else:
except:
traceback.print_exc()
sys.exit(1)

print 'No characters found'
sys.exit(0)
73 changes: 46 additions & 27 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,61 +27,80 @@ retval=$?
if [ $retval -eq 0 ]
then
echo "Darknet is not compiled! Go to 'darknet' directory and 'make'!"
exit 0
exit 1
fi

lp_model="data/lp-detector/wpod-net.h5"
input_dir=''
output_dir=''
csv_file=''


# Check # of arguments
if [ ! $# -eq 3 ]
then
usage() {
echo ""
echo " Usage:"
echo ""
echo " Required arguments:"
echo " bash $0 -i input/dir -o output/dir -c csv_file.csv [-h] [-l path/to/model]:"
echo ""
echo " 1. Input dir path (containing JPG or PNG images)"
echo " 2. Output dir path"
echo " 3. Output CSV file path"
echo " -i Input dir path (containing JPG or PNG images)"
echo " -o Output dir path"
echo " -c Output CSV file path"
echo " -l Path to Keras LP detector model (default = $lp_model)"
echo " -h Print this help information"
echo ""
exit 1
fi

read -n1 -r -p "Press any key to continue..." key

}

# Download all networks
bash get-networks.sh
while getopts 'i:o:c:l:h' OPTION; do
case $OPTION in
i) input_dir=$OPTARG;;
o) output_dir=$OPTARG;;
c) csv_file=$OPTARG;;
l) lp_model=$OPTARG;;
h) usage;;
esac
done

if [ -z "$input_dir" ]; then echo "Input dir not set."; usage; exit 1; fi
if [ -z "$output_dir" ]; then echo "Ouput dir not set."; usage; exit 1; fi
if [ -z "$csv_file" ]; then echo "CSV file not set." ; usage; exit 1; fi

# Check if input dir exists
check_dir $1
check_dir $input_dir
retval=$?
if [ $retval -eq 0 ]
then
echo "Input directory ($1) does not exist"
exit 0
echo "Input directory ($input_dir) does not exist"
exit 1
fi

# Check if output dir exists, if not, create it
check_dir $2
check_dir $output_dir
retval=$?
if [ $retval -eq 0 ]
then
mkdir -p $2
mkdir -p $output_dir
fi

# End if any error occur
set -e

# Detect vehicles
python vehicle-detection.py $1 $2
python vehicle-detection.py $input_dir $output_dir

# Detect license plates
python license-plate-detection.py $2 data/lp-detector/wpod-net.h5
python license-plate-detection.py $output_dir $lp_model

# OCR
python license-plate-ocr.py $2
python license-plate-ocr.py $output_dir

# Draw output and generate list
python gen-outputs.py $1 $2 > $3
python gen-outputs.py $input_dir $output_dir > $csv_file

# Clean files and draw output
rm $2/*_lp.png
rm $2/*car.png
rm $2/*_cars.txt
rm $2/*_lp.txt
rm $2/*_str.txt
rm $output_dir/*_lp.png
rm $output_dir/*car.png
rm $output_dir/*_cars.txt
rm $output_dir/*_lp.txt
rm $output_dir/*_str.txt
Loading

0 comments on commit d6eff71

Please sign in to comment.