
Merge pull request #34 from rsepassi/push
v1.0.7
lukaszkaiser committed Jun 24, 2017
2 parents 3410bea + d578f52 commit d029c45
Showing 13 changed files with 285 additions and 12 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -3,3 +3,7 @@

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info

# PyPI distribution artifacts
build/
dist/
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.0.6',
version='1.0.7',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/generator_utils.py
@@ -22,6 +22,7 @@
import io
import os
import tarfile
import urllib

# Dependency imports

7 changes: 1 addition & 6 deletions tensor2tensor/data_generators/image.py
@@ -29,10 +29,9 @@
# Dependency imports

import numpy as np
from six.moves import cPickle
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
from six.moves import cPickle

from tensor2tensor.data_generators import generator_utils

import tensorflow as tf
@@ -201,10 +200,6 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0):
])
labels = data["labels"]
all_labels.extend([labels[j] for j in xrange(num_images)])
# Shuffle the data to make sure classes are well distributed.
data = zip(all_images, all_labels)
random.shuffle(data)
all_images, all_labels = zip(*data)
return image_generator(all_images[start_from:start_from + how_many],
all_labels[start_from:start_from + how_many])

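Note: the shuffle removed in this hunk could not have run as written. With from six.moves import zip, zip is itertools.izip on Python 2 and the built-in zip on Python 3; both return lazy iterators, and random.shuffle needs an indexable sequence. A minimal illustration of the failure and the usual list(...) workaround follows; it is a sketch only and not part of the commit, which simply drops the shuffle.

import random
from six.moves import zip  # itertools.izip on PY2, built-in zip on PY3: both lazy

data = zip([1, 2, 3], ["a", "b", "c"])
# random.shuffle(data) would raise TypeError here: a zip object has no len().
data = list(data)  # materialize the pairs before shuffling
random.shuffle(data)
images, labels = zip(*data)
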
8 changes: 5 additions & 3 deletions tensor2tensor/data_generators/text_encoder.py
100755 → 100644
@@ -23,11 +23,12 @@
from __future__ import division
from __future__ import print_function

from collections import defaultdict

# Dependency imports

import six
from six.moves import xrange # pylint: disable=redefined-builtin
from collections import defaultdict
from tensor2tensor.data_generators import tokenizer

import tensorflow as tf
@@ -41,6 +42,7 @@
else:
RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]


class TextEncoder(object):
"""Base class for converting from ints to/from human readable strings."""

@@ -95,7 +97,7 @@ def encode(self, s):
if six.PY2:
return [ord(c) + numres for c in s]
# Python3: explicitly convert to UTF-8
return [c + numres for c in s.encode("utf-8")]
return [c + numres for c in s.encode('utf-8')]

def decode(self, ids):
numres = self._num_reserved_ids
@@ -109,7 +111,7 @@ def decode(self, ids):
if six.PY2:
return ''.join(decoded_ids)
# Python3: join byte arrays and then decode string
return b''.join(decoded_ids).decode("utf-8")
return b''.join(decoded_ids).decode('utf-8')

@property
def vocab_size(self):
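For context, the encode/decode pair changed in this file is the byte-level codec: each UTF-8 byte is mapped to an integer shifted past the reserved ids (PAD and EOS). Below is a standalone, Python 3-only sketch of that round trip; the constants are illustrative stand-ins, not the values defined in text_encoder.py.

NUM_RESERVED_IDS = 2                   # assumed: PAD and EOS
RESERVED_TOKENS = ["<pad>", "<EOS>"]   # assumed token strings


def byte_encode(s, numres=NUM_RESERVED_IDS):
  # Shift every UTF-8 byte past the reserved ids, mirroring encode() above.
  return [b + numres for b in s.encode("utf-8")]


def byte_decode(ids, numres=NUM_RESERVED_IDS):
  # Inverse mapping; reserved ids decode to their token strings.
  out = bytearray()
  for i in ids:
    if 0 <= i < numres:
      out.extend(RESERVED_TOKENS[i].encode("ascii"))
    else:
      out.append(i - numres)
  return out.decode("utf-8")


assert byte_decode(byte_encode("héllo")) == "héllo"
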
3 changes: 2 additions & 1 deletion tensor2tensor/data_generators/tokenizer.py
100755 → 100644
@@ -45,12 +45,13 @@
from __future__ import division
from __future__ import print_function

from collections import defaultdict
import string

# Dependency imports

from six.moves import xrange # pylint: disable=redefined-builtin
from collections import defaultdict


class Tokenizer(object):
"""Vocab for breaking words into wordpieces.
150 changes: 150 additions & 0 deletions tensor2tensor/models/bluenet.py
@@ -0,0 +1,150 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BlueNet: and out of the blue network to experiment with shake-shake."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange # pylint: disable=redefined-builtin

from tensor2tensor.models import common_hparams
from tensor2tensor.models import common_layers
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

import tensorflow as tf


def residual_module(x, hparams, train, n, sep):
"""A stack of convolution blocks with residual connection."""
k = (hparams.kernel_height, hparams.kernel_width)
dilations_and_kernels = [((1, 1), k) for _ in xrange(n)]
with tf.variable_scope("residual_module%d_sep%d" % (n, sep)):
y = common_layers.subseparable_conv_block(
x,
hparams.hidden_size,
dilations_and_kernels,
padding="SAME",
separability=sep,
name="block")
x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
return tf.nn.dropout(x, 1.0 - hparams.dropout * tf.to_float(train))


def residual_module1(x, hparams, train):
return residual_module(x, hparams, train, 1, 1)


def residual_module1_sep(x, hparams, train):
return residual_module(x, hparams, train, 1, 0)


def residual_module2(x, hparams, train):
return residual_module(x, hparams, train, 2, 1)


def residual_module2_sep(x, hparams, train):
return residual_module(x, hparams, train, 2, 0)


def residual_module3(x, hparams, train):
return residual_module(x, hparams, train, 3, 1)


def residual_module3_sep(x, hparams, train):
return residual_module(x, hparams, train, 3, 0)


def norm_module(x, hparams, train):
del train # Unused.
return common_layers.layer_norm(x, hparams.hidden_size, name="norm_module")


def identity_module(x, hparams, train):
del hparams, train # Unused.
return x


def run_modules(blocks, cur, hparams, train, dp):
"""Run blocks in parallel using dp as data_parallelism."""
assert len(blocks) % dp.n == 0
res = []
for i in xrange(len(blocks) // dp.n):
res.extend(dp(blocks[i * dp.n:(i + 1) * dp.n], cur, hparams, train))
return res


@registry.register_model
class BlueNet(t2t_model.T2TModel):

def model_fn_body_sharded(self, sharded_features, train):
dp = self._data_parallelism
dp._reuse = False # pylint:disable=protected-access
hparams = self._hparams
blocks = [identity_module, norm_module,
residual_module1, residual_module1_sep,
residual_module2, residual_module2_sep,
residual_module3, residual_module3_sep]
inputs = sharded_features["inputs"]

cur = tf.concat(inputs, axis=0)
cur_shape = cur.get_shape()
for i in xrange(hparams.num_hidden_layers):
with tf.variable_scope("layer_%d" % i):
processed = run_modules(blocks, cur, hparams, train, dp)
cur = common_layers.shakeshake(processed)
cur.set_shape(cur_shape)

return list(tf.split(cur, len(inputs), axis=0)), 0.0


@registry.register_hparams
def bluenet_base():
"""Set of hyperparameters."""
hparams = common_hparams.basic_params1()
hparams.batch_size = 4096
hparams.hidden_size = 768
hparams.dropout = 0.2
hparams.symbol_dropout = 0.2
hparams.label_smoothing = 0.1
hparams.clip_grad_norm = 2.0
hparams.num_hidden_layers = 8
hparams.kernel_height = 3
hparams.kernel_width = 3
hparams.learning_rate_decay_scheme = "exp50k"
hparams.learning_rate = 0.05
hparams.learning_rate_warmup_steps = 3000
hparams.initializer_gain = 1.0
hparams.weight_decay = 3.0
hparams.num_sampled_classes = 0
hparams.sampling_method = "argmax"
hparams.optimizer_adam_epsilon = 1e-6
hparams.optimizer_adam_beta1 = 0.85
hparams.optimizer_adam_beta2 = 0.997
hparams.add_hparam("imagenet_use_2d", True)
return hparams


@registry.register_hparams
def bluenet_tiny():
hparams = bluenet_base()
hparams.batch_size = 1024
hparams.hidden_size = 128
hparams.num_hidden_layers = 4
hparams.learning_rate_decay_scheme = "none"
return hparams
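
In one sentence, each BlueNet layer applies all eight candidate modules (identity, layer norm, and residual stacks of one to three conv blocks at two separability settings) to the same input and blends the outputs with the shake-shake combinator from common_layers, so every step trains a random convex mixture of sub-architectures. A framework-free sketch of that per-layer idea follows; the helper names (shakeshake_np, bluenet_layer, toy_modules) are hypothetical and not part of the commit.

import numpy as np


def shakeshake_np(xs):
  # Recursive pairwise random convex combination, as in common_layers.shakeshake.
  if len(xs) == 1:
    return xs[0]
  div = (len(xs) + 1) // 2
  alpha = np.random.uniform()
  return alpha * shakeshake_np(xs[:div]) + (1.0 - alpha) * shakeshake_np(xs[div:])


def bluenet_layer(x, modules):
  # Run every candidate module on the same input, then blend the outputs.
  return shakeshake_np([m(x) for m in modules])


toy_modules = [lambda x: x,             # identity_module
               lambda x: np.tanh(x),    # stand-in for a residual module
               lambda x: x - x.mean()]  # stand-in for norm_module
out = bluenet_layer(np.random.rand(4, 8), toy_modules)
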
54 changes: 54 additions & 0 deletions tensor2tensor/models/bluenet_test.py
@@ -0,0 +1,54 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BlueNet tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np

from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import bluenet

import tensorflow as tf


class BlueNetTest(tf.test.TestCase):

def testBlueNet(self):
vocab_size = 9
x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1))
y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1))
hparams = bluenet.bluenet_tiny()
p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
vocab_size)
with self.test_session() as session:
features = {
"inputs": tf.constant(x, dtype=tf.int32),
"targets": tf.constant(y, dtype=tf.int32),
}
model = bluenet.BlueNet(hparams, p_hparams)
sharded_logits, _, _ = model.model_fn(features, True)
logits = tf.concat(sharded_logits, 0)
session.run(tf.global_variables_initializer())
res = session.run(logits)
self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))


if __name__ == "__main__":
tf.test.main()
46 changes: 46 additions & 0 deletions tensor2tensor/models/common_layers.py
@@ -58,6 +58,52 @@ def inverse_exp_decay(max_step, min_value=0.01):
return inv_base**tf.maximum(float(max_step) - step, 0.0)


def shakeshake2_py(x, y, equal=False):
"""The shake-shake sum of 2 tensors, python version."""
alpha = 0.5 if equal else tf.random_uniform([])
return alpha * x + (1.0 - alpha) * y


@function.Defun()
def shakeshake2_grad(x1, x2, dy):
"""Overriding gradient for shake-shake of 2 tensors."""
y = shakeshake2_py(x1, x2)
dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
return dx


@function.Defun()
def shakeshake2_equal_grad(x1, x2, dy):
"""Overriding gradient for shake-shake of 2 tensors."""
y = shakeshake2_py(x1, x2, equal=True)
dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
return dx


@function.Defun(grad_func=shakeshake2_grad)
def shakeshake2(x1, x2):
"""The shake-shake function with a different alpha for forward/backward."""
return shakeshake2_py(x1, x2)


@function.Defun(grad_func=shakeshake2_equal_grad)
def shakeshake2_eqgrad(x1, x2):
"""The shake-shake function with a different alpha for forward/backward."""
return shakeshake2_py(x1, x2)


def shakeshake(xs, equal_grad=False):
"""Multi-argument shake-shake, currently approximated by sums of 2."""
if len(xs) == 1:
return xs[0]
div = (len(xs) + 1) // 2
arg1 = shakeshake(xs[:div], equal_grad=equal_grad)
arg2 = shakeshake(xs[div:], equal_grad=equal_grad)
if equal_grad:
return shakeshake2_eqgrad(arg1, arg2)
return shakeshake2(arg1, arg2)


def standardize_images(x):
"""Image standardization on batches (tf.image.per_image_standardization)."""
with tf.name_scope("standardize_images", [x]):
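The Defun/grad_func pairs added here exist so that the backward pass draws its own alpha instead of reusing the forward one: shakeshake2_grad recomputes the blend with a fresh random alpha before differentiating, while shakeshake2_equal_grad fixes the backward alpha at 0.5. A toy NumPy sketch of that forward/backward decoupling follows (sketch only; the real gradient wiring is the Defun code above). With identical inputs every blend returns the input unchanged, which is what the new testShakeShake below checks.

import numpy as np


def shakeshake2_fwd_bwd(x1, x2, dy, equal_grad=False):
  # Forward: random convex combination of the two inputs.
  alpha_fwd = np.random.uniform()
  y = alpha_fwd * x1 + (1.0 - alpha_fwd) * x2
  # Backward: an independent alpha (or exactly 0.5 for the equal-grad variant),
  # mirroring shakeshake2_grad / shakeshake2_equal_grad.
  alpha_bwd = 0.5 if equal_grad else np.random.uniform()
  dx1 = alpha_bwd * dy
  dx2 = (1.0 - alpha_bwd) * dy
  return y, (dx1, dx2)
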
9 changes: 9 additions & 0 deletions tensor2tensor/models/common_layers_test.py
@@ -65,6 +65,15 @@ def testEmbedding(self):
res = session.run(y)
self.assertEqual(res.shape, (3, 5, 16))

def testShakeShake(self):
x = np.random.rand(5, 7)
with self.test_session() as session:
x = tf.constant(x, dtype=tf.float32)
y = common_layers.shakeshake([x, x, x, x, x])
session.run(tf.global_variables_initializer())
inp, res = session.run([x, y])
self.assertAllClose(res, inp)

def testConv(self):
x = np.random.rand(5, 7, 1, 11)
with self.test_session() as session:
1 change: 1 addition & 0 deletions tensor2tensor/models/models.py
@@ -24,6 +24,7 @@

from tensor2tensor.models import attention_lm
from tensor2tensor.models import attention_lm_moe
from tensor2tensor.models import bluenet
from tensor2tensor.models import bytenet
from tensor2tensor.models import lstm
from tensor2tensor.models import modalities
10 changes: 10 additions & 0 deletions tensor2tensor/models/xception.py
@@ -87,3 +87,13 @@ def xception_base():
hparams.optimizer_adam_beta2 = 0.997
hparams.add_hparam("imagenet_use_2d", True)
return hparams


@registry.register_hparams
def xception_tiny():
hparams = xception_base()
hparams.batch_size = 1024
hparams.hidden_size = 128
hparams.num_hidden_layers = 4
hparams.learning_rate_decay_scheme = "none"
return hparams