Skip to content

Commit

Permalink
[scripts] Enhancements & minor bugfix to segmentation postprocessing (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stanleyguan authored and danpovey committed Oct 11, 2018
1 parent 43ec82e commit 535bb2c
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 25 deletions.
11 changes: 10 additions & 1 deletion egs/wsj/s5/steps/segmentation/detect_speech_activity.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ acwt=0.3
# e.g. --speech-in-sil-weight=0.0 --garbage-in-sil-weight=0.0 --sil-in-speech-weight=0.0 --garbage-in-speech-weight=0.3
transform_probs_opts=""

# Postprocessing options
segment_padding=0.2 # Duration (in seconds) of padding added to segments
min_segment_dur=0 # Minimum duration (in seconds) required for a segment to be included
# This is before any padding. Segments shorter than this duration will be removed.
# This is an alternative to --min-speech-duration above.
merge_consecutive_max_dur=0 # Merge consecutive segments as long as the merged segment is no longer than this many
# seconds. The segments are only merged if their boundaries are touching.
# This is after padding by --segment-padding seconds.
# 0 means do not merge. Use 'inf' to not limit the duration.

echo $*

Expand Down Expand Up @@ -225,7 +233,8 @@ fi

if [ $stage -le 7 ]; then
steps/segmentation/post_process_sad_to_segments.sh \
--segment-padding $segment_padding \
--segment-padding $segment_padding --min-segment-dur $min_segment_dur \
--merge-consecutive-max-dur $merge_consecutive_max_dur \
--cmd "$cmd" --frame-shift $(perl -e "print $frame_subsampling_factor * $frame_shift") \
${test_data_dir} ${seg_dir} ${seg_dir}
fi
Expand Down
126 changes: 102 additions & 24 deletions egs/wsj/s5/steps/segmentation/internal/sad_to_segments.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python

# Copyright 2017 Vimal Manohar
# 2018 Capital One (Author: Zhiyuan Guan)
# Apache 2.0

"""
Expand Down Expand Up @@ -29,6 +30,7 @@

global_verbose = 0


def get_args():
parser = argparse.ArgumentParser(
description="""
Expand All @@ -44,18 +46,31 @@ def get_args():

parser.add_argument("--utt2dur", type=str,
help="File containing durations of utterances.")

parser.add_argument("--frame-shift", type=float, default=0.01,
help="Frame shift to convert frame indexes to time")

parser.add_argument("--segment-padding", type=float, default=0.2,
help="Additional padding on speech segments. But we "
"ensure that the padding does not go beyond the "
"adjacent segment.")
"ensure that the padding does not go beyond the "
"adjacent segment.")

parser.add_argument("--min-segment-dur", type=float, default=0,
help="Minimum duration (in seconds) required for a segment "
"to be included. This is before any padding. Segments "
"shorter than this duration will be removed.")

parser.add_argument("--merge-consecutive-max-dur", type=float, default=0,
help="Merge consecutive segments as long as the merged "
"segment is no longer than this many seconds. The segments "
"are only merged if their boundaries are touching. "
"This is after padding by --segment-padding seconds."
"0 means do not merge. Use 'inf' to not limit the duration.")

parser.add_argument("in_sad", type=str,
help="Input file containing alignments in "
"text archive format")
"text archive format")

parser.add_argument("out_segments", type=str,
help="Output kaldi segments file")

Expand All @@ -80,28 +95,45 @@ def to_str(segment):

class SegmenterStats(object):
"""Stores stats about the post-process stages"""

def __init__(self):
self.num_segments = 0
self.num_segments_initial = 0
self.num_short_segments_filtered = 0
self.num_merges = 0
self.num_segments_final = 0
self.initial_duration = 0.0
self.padding_duration = 0.0
self.filter_short_duration = 0.0
self.final_duration = 0.0

def add(self, other):
"""Adds stats from another object"""
self.num_segments += other.num_segments
self.num_segments_initial += other.num_segments_initial
self.num_short_segments_filtered += other.num_short_segments_filtered
self.num_merges += other.num_merges
self.num_segments_final += other.num_segments_final
self.initial_duration += other.initial_duration
self.padding_duration = other.padding_duration
self.final_duration = other.final_duration
self.filter_short_duration += other.filter_short_duration
self.padding_duration += other.padding_duration
self.final_duration += other.final_duration

def __str__(self):
return ("num-segments={num_segments}, "
return ("num-segments-initial={num_segments_initial}, "
"num-short-segments-filtered={num_short_segments_filtered}, "
"num-merges={num_merges}, "
"num-segments-final={num_segments_final}, "
"initial-duration={initial_duration}, "
"filter-short-duration={filter_short_duration}, "
"padding-duration={padding_duration}, "
"final-duration={final_duration}".format(
num_segments=self.num_segments,
initial_duration=self.initial_duration,
padding_duration=self.padding_duration,
final_duration=self.final_duration))
num_segments_initial=self.num_segments_initial,
num_short_segments_filtered=self.num_short_segments_filtered,
num_merges=self.num_merges,
num_segments_final=self.num_segments_final,
initial_duration=self.initial_duration,
filter_short_duration=self.filter_short_duration,
padding_duration=self.padding_duration,
final_duration=self.final_duration))


def process_label(text_label):
Expand All @@ -114,13 +146,14 @@ def process_label(text_label):
prev_label = int(text_label)
if prev_label not in [1, 2]:
raise ValueError("Expecting label to 1 (non-speech) or 2 (speech); "
"got {0}".format(prev_label))
"got {}".format(prev_label))

return prev_label


class Segmentation(object):
"""Stores segmentation for an utterance"""

def __init__(self):
self.segments = None
self.stats = SegmenterStats()
Expand All @@ -143,8 +176,8 @@ def initialize_segments(self, alignment, frame_shift=0.01):
float(i) * frame_shift, prev_label])

prev_label = process_label(text_label)
prev_length = 0
self.stats.initial_duration += (prev_length * frame_shift)
prev_length = 0
elif prev_label is None:
prev_label = process_label(text_label)

Expand All @@ -156,7 +189,27 @@ def initialize_segments(self, alignment, frame_shift=0.01):
float(len(alignment)) * frame_shift, prev_label])
self.stats.initial_duration += (prev_length * frame_shift)

self.stats.num_segments = len(self.segments)
self.stats.num_segments_initial = len(self.segments)
self.stats.num_segments_final = len(self.segments)
self.stats.final_duration = self.stats.initial_duration

def filter_short_segments(self, min_dur):
    """Remove segments whose duration is shorter than 'min_dur'.

    Args:
        min_dur: minimum duration (in seconds) a segment must have to be
            kept. A segment whose duration equals 'min_dur' is kept.
            A value <= 0 disables filtering and leaves the segments and
            stats untouched.

    Updates self.segments and, in self.stats, the filtered-segment count,
    the filtered duration, the final segment count and the final duration.
    """
    if min_dur <= 0:
        return

    segments_kept = []
    # Duration removed by *this* call only.  Subtracting the cumulative
    # self.stats.filter_short_duration from final_duration would
    # double-count anything removed by an earlier call.
    removed_duration = 0.0
    for segment in self.segments:
        assert segment[2] == 2, segment
        dur = segment[1] - segment[0]
        if dur < min_dur:
            removed_duration += dur
            self.stats.num_short_segments_filtered += 1
        else:
            segments_kept.append(segment)

    self.segments = segments_kept
    self.stats.filter_short_duration += removed_duration
    self.stats.num_segments_final = len(self.segments)
    self.stats.final_duration -= removed_duration

def pad_speech_segments(self, segment_padding, max_duration=float("inf")):
"""Pads segments by duration 'segment_padding' on either sides, but
Expand All @@ -166,19 +219,19 @@ def pad_speech_segments(self, segment_padding, max_duration=float("inf")):
max_duration = float("inf")
for i, segment in enumerate(self.segments):
assert segment[2] == 2, segment
segment[0] -= segment_padding # try adding padding on the left side
segment[0] -= segment_padding # try adding padding on the left side
self.stats.padding_duration += segment_padding
if segment[0] < 0.0:
# Padding takes the segment start to before the beginning of the utterance.
# Reduce padding.
self.stats.padding_duration += segment[0]
segment[0] = 0.0
if i >= 1 and self.segments[i-1][1] > segment[0]:
if i >= 1 and self.segments[i - 1][1] > segment[0]:
# Padding takes the segment start to before the end the previous segment.
# Reduce padding.
self.stats.padding_duration -= (
self.segments[i-1][1] - segment[0])
segment[0] = self.segments[i-1][1]
self.segments[i - 1][1] - segment[0])
segment[0] = self.segments[i - 1][1]

segment[1] += segment_padding
self.stats.padding_duration += segment_padding
Expand All @@ -188,12 +241,35 @@ def pad_speech_segments(self, segment_padding, max_duration=float("inf")):
self.stats.padding_duration -= (segment[1] - max_duration)
segment[1] = max_duration
if (i + 1 < len(self.segments)
and segment[1] > self.segments[i+1][0]):
and segment[1] > self.segments[i + 1][0]):
# Padding takes the segment end beyond the start of the next segment.
# Reduce padding.
self.stats.padding_duration -= (
segment[1] - self.segments[i+1][0])
segment[1] = self.segments[i+1][0]
segment[1] - self.segments[i + 1][0])
segment[1] = self.segments[i + 1][0]
self.stats.final_duration += self.stats.padding_duration

def merge_consecutive_segments(self, max_dur):
    """Merge touching consecutive segments (happens after padding).

    A segment is folded into the previously emitted segment only when it
    starts exactly where that segment ends and the resulting merged
    segment is no longer than 'max_dur' seconds.

    Args:
        max_dur: maximum duration (in seconds) allowed for a merged
            segment. A value <= 0 disables merging; float("inf") removes
            the limit.

    Replaces self.segments with the merged list and updates
    self.stats.num_merges and self.stats.num_segments_final.
    """
    if max_dur <= 0 or not self.segments:
        return

    merged_segments = [self.segments[0]]
    for segment in self.segments[1:]:
        assert segment[2] == 2, segment
        last = merged_segments[-1]
        # Only segments whose boundaries touch exactly are candidates.
        # The duration is measured from the start of the (possibly
        # already merged) previous segment, so that a chain of merges
        # can never grow beyond 'max_dur'.
        if segment[0] == last[1] and segment[1] - last[0] <= max_dur:
            # Extend the previous segment to cover this one.
            last[1] = segment[1]
            self.stats.num_merges += 1
        else:
            merged_segments.append(segment)

    self.segments = merged_segments
    self.stats.num_segments_final = len(self.segments)

def write(self, key, file_handle):
"""Write segments to file"""
Expand All @@ -203,9 +279,9 @@ def write(self, key, file_handle):
for segment in self.segments:
seg_id = "{key}-{st:07d}-{end:07d}".format(
key=key, st=int(segment[0] * 100), end=int(segment[1] * 100))
print ("{seg_id} {key} {st:.2f} {end:.2f}".format(
print("{seg_id} {key} {st:.2f} {end:.2f}".format(
seg_id=seg_id, key=key, st=segment[0], end=segment[1]),
file=file_handle)
file=file_handle)


def run(args):
Expand Down Expand Up @@ -235,9 +311,11 @@ def run(args):
segmentation = Segmentation()
segmentation.initialize_segments(
parts[1:], args.frame_shift)
segmentation.filter_short_segments(args.min_segment_dur)
segmentation.pad_speech_segments(args.segment_padding,
None if args.utt2dur is None
else utt2dur[utt_id])
segmentation.merge_consecutive_segments(args.merge_consecutive_max_dur)
segmentation.write(utt_id, out_segments_fh)
global_stats.add(segmentation.stats)
logger.info(global_stats)
Expand Down
3 changes: 3 additions & 0 deletions egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ nj=18
# The values below are in seconds
frame_shift=0.01
segment_padding=0.2
min_segment_dur=0
merge_consecutive_max_dur=0

. utils/parse_options.sh

Expand Down Expand Up @@ -53,6 +55,7 @@ if [ $stage -le 0 ]; then
copy-int-vector "ark:gunzip -c $vad_dir/ali.JOB.gz |" ark,t:- \| \
steps/segmentation/internal/sad_to_segments.py \
--frame-shift=$frame_shift --segment-padding=$segment_padding \
--min-segment-dur=$min_segment_dur --merge-consecutive-max-dur=$merge_consecutive_max_dur \
--utt2dur=$data_dir/utt2dur - $dir/segments.JOB
fi

Expand Down

0 comments on commit 535bb2c

Please sign in to comment.