Skip to content

Commit

Permalink
Added parameter search for kenlm decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
louiskirsch committed Apr 19, 2017
1 parent 14b1b5c commit 63a5c6a
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 14 deletions.
21 changes: 17 additions & 4 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lazy import lazy

from evaluation import Evaluation
from parameter_search import LanguageModelParameterSearch
from recording import Recording
from training import Training

Expand All @@ -33,7 +34,7 @@ def __init__(self):
self._add_training_parser()
self._add_evaluation_parser()
self._add_recording_parser()
self._add_local_search_parser()
self._add_parameter_search_parser()

def _create_base_parser(self):
base_parser = argparse.ArgumentParser(add_help=False)
Expand Down Expand Up @@ -102,8 +103,17 @@ def _add_recording_parser(self):
help='The input size of each sample, depending on what preprocessing was used')
self._add_language_model_argument(recording_parser)

def _add_local_search_parser(self):
pass
def _add_parameter_search_parser(self):
parameter_search_parser = self.subparsers.add_parser('search', help='Search for language model hyper parameters'
'using local search.',
parents=[self.base_parser])
parameter_search_parser.add_argument('--population-size', dest='population_size', type=int, default=10,
help='The size of the population for the local search.')
parameter_search_parser.add_argument('--noise-std', dest='noise_std', type=float, default=0.5,
help='The standard deviation of the normal noise for mutation.')
parameter_search_parser.add_argument('--ui', dest='use_ui', action='store_true',
help='Whether to use an UI to print results.')
self._add_language_model_argument(parameter_search_parser)

@lazy
def parsed(self):
Expand All @@ -115,6 +125,8 @@ def parsed(self):
parsed.run_type = parsed.dataset
elif parsed.command == 'record':
parsed.run_type = 'record'
else:
parsed.run_type = 'other'

parsed.run_train_dir = parsed.train_dir + '/' + parsed.run_name

Expand All @@ -125,7 +137,8 @@ def command_executor(self):
return {
'train': Training,
'evaluate': Evaluation,
'record': Recording
'record': Recording,
'search': LanguageModelParameterSearch
}[self.parsed.command](self.parsed)

def run(self):
Expand Down
20 changes: 12 additions & 8 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Dict

import editdistance
import numpy as np
Expand Down Expand Up @@ -95,7 +96,8 @@ def run(self):
if coord.should_stop():
break

self.run_epoch(epoch, model, sess, stats)
should_save = self.flags.should_save and epoch == 0
self.run_epoch(should_save, model, sess, stats)

self.print_global_statistics(stats)

Expand All @@ -114,16 +116,18 @@ def print_global_statistics(stats):
stats.global_word_edit_distance,
stats.global_word_error_rate))

def run_epoch(self, epoch: int, model: SpeechModel, sess: tf.Session, stats: EvalStatistics, verbose=True):
def run_epoch(self, model: SpeechModel, sess: tf.Session, stats: EvalStatistics,
save: bool, verbose=True, feed_dict: Dict=None):
global_step = model.global_step.eval()

# Validate on development set and write summary
if not self.flags.should_save or epoch > 0:
avg_loss, decoded, label = model.step(sess, update=False, decode=True, return_label=True)
else:
avg_loss, decoded, label, summary = model.step(sess, update=False, decode=True,
return_label=True, summary=True)
# Validate on data set and write summary
if save:
avg_loss, decoded, label, summary = model.step(sess, update=False, decode=True, return_label=True,
summary=True, feed_dict=feed_dict)
model.summary_writer.add_summary(summary, global_step)
else:
avg_loss, decoded, label = model.step(sess, update=False, decode=True,
return_label=True, feed_dict=feed_dict)

if verbose:
perplexity = np.exp(float(avg_loss)) if avg_loss < 300 else float("inf")
Expand Down
129 changes: 129 additions & 0 deletions parameter_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import random
import bisect
from curses import wrapper
from typing import List, Iterable

from evaluation import Evaluation, EvalStatistics
import tensorflow as tf
import numpy as np

from speech_model import SpeechModel


class Candidate:

def __init__(self, lm_weight: float, word_count_weight: float):
self.score = None
self.stats = None
self.lm_weight = lm_weight
self.word_count_weight = word_count_weight

def __gt__(self, other):
return self.score > other.score

def __lt__(self, other):
return self.score < other.score

def __str__(self):
return ('{:.2f} Candidate (lm_weight={:.2f}, wc_weight={:.2f}) '
'has LER: {:.2f} WER: {:.2f}').format(self.score,
self.lm_weight,
self.word_count_weight,
self.stats.global_letter_error_rate,
self.stats.global_word_error_rate)

def update_score(self, score: float, stats: EvalStatistics):
self.score = score
self.stats = stats

def mutate(self, std: float):
return Candidate(lm_weight=self.lm_weight + np.random.normal(loc=0, scale=std),
word_count_weight=self.word_count_weight + np.random.normal(loc=0, scale=std))


class LanguageModelParameterSearch(Evaluation):

def __init__(self, flags):
super().__init__(flags)
self.candidates = []
self.num_iterations = 0

def create_sample_generator(self, limit_count: int):
return self.reader.load_samples('dev',
loop_infinitely=True,
limit_count=limit_count,
feature_type=self.flags.feature_type)

def _update_score_for_candidate(self, model: SpeechModel, sess: tf.Session, candidate: Candidate):
stats = EvalStatistics()
feed_dict = {
model.lm_weight: candidate.lm_weight,
model.word_count_weight: candidate.word_count_weight
}
self.run_epoch(model, sess, stats, save=False, verbose=False, feed_dict=feed_dict)
score = -(stats.global_letter_error_rate + stats.global_word_error_rate)
candidate.update_score(score, stats)

def get_loader_limit_count(self):
return 0

def get_max_epochs(self):
return None

def run(self):

with tf.Session() as sess:

model = self.create_model(sess)
coord = self.start_pipeline(sess)

def run_search(stdscr=None):
if stdscr:
stdscr.clear()
stdscr.addstr(0, 0, 'Loading...')
stdscr.refresh()

new_candidate = Candidate(1.0, 1.0)
self._update_score_for_candidate(model, sess, new_candidate)
self.candidates.append(new_candidate)

if stdscr:
self.print_population(stdscr)
else:
print(new_candidate)

while True:
if coord.should_stop():
break

random_candidate = random.choice(self.candidates)
new_candidate = random_candidate.mutate(self.flags.noise_std)
self._update_score_for_candidate(model, sess, new_candidate)

# Note: We're dealing with tiny populations, so O(n) is not an issue
bisect.insort(self.candidates, new_candidate)

if len(self.candidates) > self.flags.population_size:
del self.candidates[0]

self.num_iterations += 1

if stdscr:
self.print_population(stdscr)
else:
print(new_candidate)

coord.request_stop()
coord.join()

if self.flags.use_ui:
wrapper(run_search)
else:
run_search()

def print_population(self, stdscr):
stdscr.clear()
stdscr.addstr(0, 0, 'Current population after {} iterations'.format(self.num_iterations))
for idx, candidate in enumerate(reversed(self.candidates)):
stdscr.addstr(idx + 2, 0, str(candidate))
stdscr.refresh()
14 changes: 12 additions & 2 deletions speech_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,16 @@ def add_decoding_ops(self, language_model: str = None):
language_model: the file path to the language model to use for beam search decoding or None
"""
with tf.name_scope('decoding'):
self.lm_weight = tf.placeholder_with_default(1.0, shape=(), name='language_model_weight')
self.word_count_weight = tf.placeholder_with_default(1.0, shape=(), name='word_count_weight')

if language_model:
self.softmaxed = tf.log(tf.nn.softmax(self.logits, name='softmax') + 1e-8) / math.log(10)
self.decoded, self.log_probabilities = tf.nn.ctc_beam_search_decoder(self.softmaxed,
self.sequence_lengths // 2,
kenlm_directory_path=language_model,
kenlm_weight=self.lm_weight,
word_count_weight=self.word_count_weight,
beam_width=100,
merge_repeated=False,
top_paths=1)
Expand Down Expand Up @@ -182,7 +187,7 @@ def init_session(self, sess, init_variables=True):

self.summary_writer.add_graph(sess.graph)

def step(self, sess, loss=True, update=True, decode=False, return_label=False, summary=False):
def step(self, sess, loss=True, update=True, decode=False, return_label=False, summary=False, feed_dict=None):
"""
Evaluate the graph, you may update weights, decode audio or generate a summary
Expand All @@ -193,6 +198,7 @@ def step(self, sess, loss=True, update=True, decode=False, return_label=False, s
decode: should the decoding be performed and returned
return_label: should the label be returned
summary: should the summary be generated
feed_dict: additional tensors that should be fed
Returns: avg_loss (optional), decoded (optional), label (optional), update (optional), summary (optional)
Expand All @@ -215,7 +221,11 @@ def step(self, sess, loss=True, update=True, decode=False, return_label=False, s
if summary:
output_feed.append(self.merged_summaries)

return sess.run(output_feed, feed_dict=self.input_loader.get_feed_dict())
input_feed_dict = self.input_loader.get_feed_dict() or {}
if feed_dict is not None:
input_feed_dict.update(feed_dict)

return sess.run(output_feed, feed_dict=input_feed_dict)

@abc.abstractclassmethod
def _create_network(self, num_classes):
Expand Down

0 comments on commit 63a5c6a

Please sign in to comment.