Skip to content

Commit

Permalink
Merge pull request SeanNaren#81 from SiddGururani/master
Browse files Browse the repository at this point in the history
Added tensorboard logging
  • Loading branch information
Sean Naren committed Jun 11, 2017
2 parents ab99099 + 78ee28d commit e048430
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 7 deletions.
73 changes: 73 additions & 0 deletions logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x


class Logger(object):
    """Write scalar, image and histogram summaries to a TensorBoard log.

    Each ``*_summary`` method builds a ``tf.Summary`` proto by hand, adds it
    to the underlying ``tf.summary.FileWriter`` and flushes the writer, so
    every call is written out immediately.
    """

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable.

        Args:
            tag: Name the value is recorded under.
            value: The scalar, stored in the proto's ``simple_value`` field.
            step: Global step passed to ``add_summary``.
        """
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def image_summary(self, tag, images, step):
        """Log a list of images.

        Each image is PNG-encoded and recorded under the tag ``<tag>/<index>``.

        Args:
            tag: Prefix for the per-image tags.
            images: Iterable of arrays; ``shape[0]`` and ``shape[1]`` of each
                are recorded as the image height and width.
            step: Global step passed to ``add_summary``.
        """
        # PNG data is binary, so a bytes buffer is the right container on both
        # Python 2 and 3. Importing it here avoids depending on which of
        # StringIO/BytesIO the module-level conditional import happened to
        # define, and removes the bare ``except`` that used to hide errors.
        from io import BytesIO

        img_summaries = []
        for i, img in enumerate(images):
            # Encode the image as PNG into an in-memory buffer.
            buf = BytesIO()
            scipy.misc.toimage(img).save(buf, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=buf.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)
        self.writer.flush()

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values.

        Args:
            tag: Name the histogram is recorded under.
            values: Array-like of numbers; converted with ``np.asarray`` so
                plain lists and tuples are accepted as well as ndarrays.
            step: Global step passed to ``add_summary``.
            bins: Forwarded to ``np.histogram`` as its ``bins`` argument.
        """
        # Accept any array-like: the statistics below need ndarray semantics
        # (``.size`` and element-wise ``**``).
        values = np.asarray(values)

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(values.size)
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values ** 2))

        # Drop the start of the first bin so there is exactly one limit per
        # bucket; ``tolist()`` hands the proto native Python numbers.
        hist.bucket_limit.extend(bin_edges[1:].tolist())
        hist.bucket.extend(counts.tolist())

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
66 changes: 59 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@
help='Minimum noise level to sample from. (1.0 means all noise, not original signal)', type=float)
# Upper bound of the noise-level range (float option, default 0.5).
parser.add_argument('--noise_max', default=0.5,
help='Maximum noise levels to sample from. Maximum 1.0', type=float)
# Defaults for the boolean flags; the set_defaults call at the end of this
# group repeats these keys and adds tensorboard/log_params.
parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False)
# TensorBoard options: --tensorboard and --log_params are store_true flags,
# --log_dir is the directory the TensorBoard log is written to.
parser.add_argument('--tensorboard', dest='tensorboard', action='store_true', help='Turn on tensorboard graphing')
parser.add_argument('--log_dir', default='visualize/deepspeech_final', help='Location of tensorboard log')
parser.add_argument('--log_params', dest='log_params', action='store_true', help='Log parameter values and gradients')
# Every store_true flag defaults to False unless given on the command line.
parser.set_defaults(cuda=False, silent=False, checkpoint=False, visdom=False, augment=False, tensorboard=False, log_params=False)

def to_np(x):
    """Return the contents of ``x`` as a NumPy array.

    Reads ``x.data``, brings it to host memory with ``.cpu()`` and converts
    the result with ``.numpy()``.
    """
    host_tensor = x.data.cpu()
    return host_tensor.numpy()

class AverageMeter(object):
"""Computes and stores the average and current value"""
Expand Down Expand Up @@ -87,6 +92,25 @@ def main():
loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
args.epochs)
epochs = torch.arange(1, args.epochs + 1)
if args.tensorboard:
from logger import Logger
try:
os.makedirs(args.log_dir)
except OSError as e:
if e.errno == errno.EEXIST:
print('Directory already exists.')
for file in os.listdir(args.log_dir):
file_path = os.path.join(args.log_dir, file)
try:
if os.path.isfile(file_path):
os.unlink(file_path)
except Exception as e:
raise
else:
raise
loss_results, cer_results, wer_results = torch.Tensor(args.epochs), torch.Tensor(args.epochs), torch.Tensor(
args.epochs)
logger = Logger(args.log_dir)

try:
os.makedirs(save_folder)
Expand Down Expand Up @@ -146,7 +170,7 @@ def main():
if args.visdom and \
package['loss_results'] is not None and start_epoch > 0: # Add previous scores to visdom graph
epoch = start_epoch
loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], package[
loss_results[0:epoch], cer_results[0:epoch], wer_results[0:epoch] = package['loss_results'], package['cer_results'], package[
'wer_results']
x_axis = epochs[0:epoch]
y_axis = [loss_results[0:epoch], wer_results[0:epoch], cer_results[0:epoch]]
Expand All @@ -156,6 +180,18 @@ def main():
Y=y_axis[x],
opts=opts[x],
)
if args.tensorboard and package['loss_results'] is not None and start_epoch > 0: # Add previous scores to tensorboard logs
epoch = start_epoch
loss_results, cer_results, wer_results = package['loss_results'], package['cer_results'], package[
'wer_results']
for i in range(len(loss_results)):
info = {
'Avg Train Loss': loss_results[i],
'Avg WER': wer_results[i],
'Avg CER': cer_results[i]
}
for tag, val in info.items():
logger.scalar_summary(tag, val, i+1)
else:
avg_loss = 0
start_epoch = 0
Expand Down Expand Up @@ -238,7 +274,7 @@ def main():
epoch + 1, loss=avg_loss))

start_iter = 0 # Reset start iteration for next epoch
total_cer, total_wer = 0, 0
total_cer, total_wer= 0, 0
model.eval()
for i, (data) in enumerate(test_loader): # test
inputs, targets, input_percentages, target_sizes = data
Expand Down Expand Up @@ -271,7 +307,6 @@ def main():

if args.cuda:
torch.cuda.synchronize()

wer = total_wer / len(test_loader.dataset)
cer = total_cer / len(test_loader.dataset)
wer *= 100
Expand All @@ -286,9 +321,9 @@ def main():
loss_results[epoch] = avg_loss
wer_results[epoch] = wer
cer_results[epoch] = cer
epoch += 1
x_axis = epochs[0:epoch]
y_axis = [loss_results[0:epoch], wer_results[0:epoch], cer_results[0:epoch]]
# epoch += 1
x_axis = epochs[0:epoch+1]
y_axis = [loss_results[0:epoch+1], wer_results[0:epoch+1], cer_results[0:epoch+1]]
for x in range(len(viz_windows)):
if viz_windows[x] is None:
viz_windows[x] = viz.line(
Expand All @@ -303,6 +338,23 @@ def main():
win=viz_windows[x],
update='replace',
)
if args.tensorboard:
loss_results[epoch] = avg_loss
wer_results[epoch] = wer
cer_results[epoch] = cer
info = {
'Avg Train Loss': avg_loss,
'Avg WER': wer,
'Avg CER': cer
}
for tag, val in info.items():
logger.scalar_summary(tag, val, epoch+1)
if args.log_params:
for tag, value in model.named_parameters():
tag = tag.replace('.', '/')
logger.histo_summary(tag, to_np(value), epoch+1)
if value.grad is not None: # Condition inserted because batch_norm RNN_0 weights.grad and bias.grad are None. Check why
logger.histo_summary(tag+'/grad', to_np(value.grad), epoch+1)
if args.checkpoint:
file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
torch.save(DeepSpeech.serialize(model, optimizer=optimizer, epoch=epoch, loss_results=loss_results,
Expand Down

0 comments on commit e048430

Please sign in to comment.