Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MADCAT Arabic scripts: cleaning and adding a higher-order language model #2718

Merged
merged 9 commits into from
Sep 19, 2018
1 change: 1 addition & 0 deletions egs/madcat_ar/v1/local/chain/run_cnn.sh
1 change: 1 addition & 0 deletions egs/madcat_ar/v1/local/chain/run_cnn_chainali.sh
1 change: 1 addition & 0 deletions egs/madcat_ar/v1/local/chain/run_cnn_e2eali.sh
1 change: 1 addition & 0 deletions egs/madcat_ar/v1/local/chain/run_e2e_cnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ if [ $stage -le 5 ]; then
--trainer.srand=$srand \
--trainer.max-param-change=2.0 \
--trainer.num-epochs=4 \
--trainer.frames-per-iter=1000000 \
--trainer.frames-per-iter=2000000 \
--trainer.optimization.num-jobs-initial=3 \
--trainer.optimization.num-jobs-final=16 \
--trainer.optimization.initial-effective-lrate=0.001 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

# local/chain/compare_wer.sh exp/chain/e2e_cnn_1a
# System e2e_cnn_1a
# WER 10.71
# CER 2.85
# Final train prob -0.0859
# Final valid prob -0.1266
# WER 7.81
# CER 2.05
# Final train prob -0.0812
# Final valid prob -0.0708
# Final train prob (xent)
# Final valid prob (xent)
# Parameters 2.94M

# steps/info/chain_dir_info.pl exp/chain/e2e_cnn_1a/
# exp/chain/e2e_cnn_1a/: num-iters=195 nj=6..16 num-params=2.9M dim=40->324 combine=-0.065->-0.064 (over 5) logprob:train/valid[129,194,final]=(-0.078,-0.077,-0.086/-0.129,-0.126,-0.127)
# exp/chain/e2e_cnn_1a/: num-iters=98 nj=6..16 num-params=2.9M dim=40->330 combine=-0.073->-0.073 (over 2) logprob:train/valid[64,97,final]=(-0.084,-0.080,-0.081/-0.073,-0.070,-0.071)

set -e

Expand All @@ -33,7 +33,7 @@ num_jobs_final=16
minibatch_size=150=128,64/300=128,64/600=64,32/1200=32,16
common_egs_dir=
l2_regularize=0.00005
frames_per_iter=1000000
frames_per_iter=2000000
cmvn_opts="--norm-means=true --norm-vars=true"
train_set=train
lang_test=lang_test
Expand Down Expand Up @@ -125,6 +125,7 @@ if [ $stage -le 3 ]; then
--egs.opts "--num_egs_diagnostic 100 --num_utts_subset 400" \
--chain.frame-subsampling-factor 4 \
--chain.alignment-subsampling-factor 4 \
--chain.lm-opts="--ngram-order=2 --no-prune-ngram-order=1 --num-extra-lm-states=1000" \
--trainer.add-option="--optimization.memory-compression-level=2" \
--trainer.num-chunk-per-minibatch $minibatch_size \
--trainer.frames-per-iter $frames_per_iter \
Expand Down
64 changes: 9 additions & 55 deletions egs/madcat_ar/v1/local/create_line_image_from_page_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
help='Path to the downloaded (and extracted) writing conditions file 3')
parser.add_argument('--padding', type=int, default=400,
help='padding across horizontal/verticle direction')
parser.add_argument("--subset", type=lambda x: (str(x).lower()=='true'), default=False,
help="only processes subset of data based on writing condition")
args = parser.parse_args()

"""
Expand Down Expand Up @@ -209,50 +211,6 @@ def get_orientation(origin, p1, p2):
return difference


def compute_hull(points):
    """
    Given input list of points, return a list of points that
    made up the convex hull.
    Returns
    -------
    [(float, float)]: convexhull points
    """
    # Gift-wrapping (Jarvis march): anchor on the left-most point and wrap
    # around the cloud until the walk returns to the anchor. On ties in x,
    # min() keeps the earliest element, matching a strict '<' scan.
    anchor = min(points, key=lambda pt: pt[0])

    hull_points = [anchor]
    current = anchor

    far_point = None
    # NOTE(review): comparisons below use identity (`is`), so `points` is
    # assumed to hold distinct objects; confirm against the caller.
    while far_point is not anchor:
        # Seed the candidate with the first element that is not `current`.
        seed = None
        for pt in points:
            if pt is current:
                continue
            seed = pt
            break
        far_point = seed

        for candidate in points:
            if candidate is current or candidate is far_point:
                continue
            # get_orientation(...) > 0 is taken to mean `candidate` lies
            # further around the wrap than `far_point` -- presumably a
            # cross-product sign test; verify against its definition.
            if get_orientation(current, far_point, candidate) > 0:
                far_point = candidate

        # The final pass appends the anchor again, closing the polygon.
        hull_points.append(far_point)
        current = far_point
    return hull_points


def minimum_bounding_box(points):
""" Given a list of 2D points, it returns the minimum area rectangle bounding all
the points in the point cloud.
Expand All @@ -272,7 +230,6 @@ def minimum_bounding_box(points):

hull_ordered = [points[index] for index in ConvexHull(points).vertices]
hull_ordered.append(hull_ordered[0])
#hull_ordered = compute_hull(points)
hull_ordered = tuple(hull_ordered)

min_rectangle = bounding_area(0, hull_ordered)
Expand Down Expand Up @@ -535,16 +492,14 @@ def check_writing_condition(wc_dict, base_name):
Returns
(bool): True if writing condition matches.
"""
return True
writing_condition = wc_dict[base_name].strip()
if writing_condition != 'IUC':
return False

return True

if args.subset:
writing_condition = wc_dict[base_name].strip()
if writing_condition != 'IUC':
return False
else:
return True

### main ###

def main():

wc_dict1 = parse_writing_conditions(args.writing_condition1)
Expand All @@ -564,8 +519,7 @@ def main():
madcat_file_path, image_file_path, wc_dict = check_file_location(base_name, wc_dict1, wc_dict2, wc_dict3)
if wc_dict is None or not check_writing_condition(wc_dict, base_name):
continue
if madcat_file_path is not None:
get_line_images_from_page_image(image_file_path, madcat_file_path, image_fh)
get_line_images_from_page_image(image_file_path, madcat_file_path, image_fh)


if __name__ == '__main__':
Expand Down
4 changes: 4 additions & 0 deletions egs/madcat_ar/v1/local/extract_features.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#!/bin/bash

# Copyright 2017 Yiwen Shao
# 2018 Ashish Arora

# Apache 2.0
# This script runs the make features script in parallel.

nj=4
cmd=run.pl
feat_dim=40
Expand Down
72 changes: 29 additions & 43 deletions egs/madcat_ar/v1/local/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
help='Path to the downloaded (and extracted) writing conditions file 2')
parser.add_argument('writing_condition3', type=str,
help='Path to the downloaded (and extracted) writing conditions file 3')
parser.add_argument("--subset", type=lambda x: (str(x).lower()=='true'), default=False,
help="only processes subset of data based on writing condition")
args = parser.parse_args()


Expand Down Expand Up @@ -97,50 +99,40 @@ def check_writing_condition(wc_dict):
Returns:
(bool): True if writing condition matches.
"""
return True
writing_condition = wc_dict[base_name].strip()
if writing_condition != 'IUC':
return False
if args.subset:
writing_condition = wc_dict[base_name].strip()
if writing_condition != 'IUC':
return False
else:
return True

return True


def get_word_line_mapping(madcat_file_path):
def read_text(madcat_file_path):
""" Maps every word in the page image to a corresponding line.
Args:
madcat_file_path (string): complete path and name of the madcat xml file
madcat_file_path (string): complete path and name of the madcat xml file
corresponding to the page image.
Returns:
dict: Mapping every word in the page image to a corresponding line.
"""

word_line_dict = dict()
doc = minidom.parse(madcat_file_path)
zone = doc.getElementsByTagName('zone')
for node in zone:
line_id = node.getAttribute('id')
line_word_dict[line_id] = list()
word_image = node.getElementsByTagName('token-image')
for tnode in word_image:
word_id = tnode.getAttribute('id')
line_word_dict[line_id].append(word_id)
word_line_dict[word_id] = line_id


def read_text(madcat_file_path):
""" Maps every word in the page image to a corresponding line.
Args:
madcat_file_path (string): complete path and name of the madcat xml file
corresponding to the page image.
Returns:
dict: Mapping every word in the page image to a corresponding line.
"""
text_line_word_dict = dict()
doc = minidom.parse(madcat_file_path)
segment = doc.getElementsByTagName('segment')
for node in segment:
token = node.getElementsByTagName('token')
for tnode in token:
ref_word_id = tnode.getAttribute('ref_id')
word = tnode.getElementsByTagName('source')[0].firstChild.nodeValue
word = unicodedata.normalize('NFKC',word)
ref_line_id = word_line_dict[ref_word_id]
if ref_line_id not in text_line_word_dict:
text_line_word_dict[ref_line_id] = list()
Expand All @@ -160,7 +152,6 @@ def get_line_image_location():


### main ###

print("Processing '{}' data...".format(args.out_dir))

text_file = os.path.join(args.out_dir, 'text')
Expand Down Expand Up @@ -188,24 +179,19 @@ def get_line_image_location():
madcat_xml_path, image_file_path, wc_dict = check_file_location()
if wc_dict is None or not check_writing_condition(wc_dict):
continue
if madcat_xml_path is not None:
madcat_doc = minidom.parse(madcat_xml_path)
writer = madcat_doc.getElementsByTagName('writer')
writer_id = writer[0].getAttribute('id')
line_word_dict = dict()
word_line_dict = dict()
get_word_line_mapping(madcat_xml_path)
text_line_word_dict = read_text(madcat_xml_path)
base_name = os.path.basename(image_file_path)
base_name, b = base_name.split('.tif')
for lineID in sorted(text_line_word_dict):
updated_base_name = base_name + '_' + str(lineID).zfill(4) +'.png'
location = image_loc_dict[updated_base_name]
image_file_path = os.path.join(location, updated_base_name)
line = text_line_word_dict[lineID]
text = ' '.join(line)
utt_id = writer_id + '_' + str(image_num).zfill(6) + '_' + base_name + '_' + str(lineID).zfill(4)
text_fh.write(utt_id + ' ' + text + '\n')
utt2spk_fh.write(utt_id + ' ' + writer_id + '\n')
image_fh.write(utt_id + ' ' + image_file_path + '\n')
image_num += 1
madcat_doc = minidom.parse(madcat_xml_path)
writer = madcat_doc.getElementsByTagName('writer')
writer_id = writer[0].getAttribute('id')
text_line_word_dict = read_text(madcat_xml_path)
base_name = os.path.basename(image_file_path).split('.tif')[0]
for lineID in sorted(text_line_word_dict):
updated_base_name = base_name + '_' + str(lineID).zfill(4) +'.png'
location = image_loc_dict[updated_base_name]
image_file_path = os.path.join(location, updated_base_name)
line = text_line_word_dict[lineID]
text = ' '.join(line)
utt_id = writer_id + '_' + str(image_num).zfill(6) + '_' + base_name + '_' + str(lineID).zfill(4)
text_fh.write(utt_id + ' ' + text + '\n')
utt2spk_fh.write(utt_id + ' ' + writer_id + '\n')
image_fh.write(utt_id + ' ' + image_file_path + '\n')
image_num += 1
4 changes: 2 additions & 2 deletions egs/madcat_ar/v1/local/score.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash


steps/scoring/score_kaldi_wer.sh --word_ins_penalty 0.0,0.5,1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5,5.0,5.5,6.0,6.5,7.0 "$@"
steps/scoring/score_kaldi_cer.sh --stage 2 --word_ins_penalty 0.0,0.5,1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5,5.0,5.5,6.0,6.5,7.0 "$@"
steps/scoring/score_kaldi_wer.sh "$@"
steps/scoring/score_kaldi_cer.sh --stage 2 "$@"
Loading