forked from google-deepmind/learning-to-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
base.py
322 lines (248 loc) · 11.2 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for TensorFlow nn.
This file contains the Abstract Base Class for defining Modules in TensorFlow.
A Module is an object which can be connected into the Graph multiple times
using the __call__ method, sharing variables automatically with no need to
explicitly use scopes or specify reuse=True.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from six import string_types
from six.moves import xrange
import tensorflow as tf
class Error(Exception):
"""Base class for all errors from nn.
This is thrown to indicate a Neural Network specific problem, e.g. wrong
module arity, module is not connected to the graph when it should be,
tried to wire together incompatible modules, etc.
"""
class NotConnectedError(Error):
"""Error raised when operating on a module that has not yet been connected.
Some module properties / methods are valid to access before the module has
been connected into the graph, but some are not. This Error is raised when
the user attempts to do anything not valid before connection.
"""
class ParentNotBuiltError(Error):
"""Error raised when the parent of a module has not been built yet.
For example, when making a transpose of modules which inherit from
`module.Transposable`, the parent has to be connected to the graph before the
child transpose to ensure that shape inference has already occurred.
"""
class IncompatibleShapeError(Error):
"""Error raised when the shape of the input at build time is incompatible."""
class UnderspecifiedError(Error):
"""Error raised when too little information is available.
This does not typically mean the user is trying to do something that doesn't
work (in which case `IncompatibleShapeError` should be used), just that
some more information needs to be provided in order to build the Graph.
"""
class NotSupportedError(Error):
"""Error raised when something that cannot be supported is requested.
For example a Dilated Convolution module cannot be transposed.
"""
@six.add_metaclass(abc.ABCMeta)
class AbstractModule(object):
"""Superclass for nn Modules.
This class defines the functionality that every module should implement,
principally the `build` method which is wrapped using `tf.make_template`
and called from `__call__`. Every time the module is called it will
be connected into the graph but using the same shared set of variables, thanks
to the template.
For this to work correctly, the `build` implementation in the derived class
must access all variables using `tf.get_variable`, not `tf.Variable`. The same
set of variables must be created each time, if this is not the case an Error
will be raised.
Every subclass must call this class' `__init__` at the start of their
`__init__`, passing the relevant name. If this step is omitted variable
sharing will not work.
"""
# Name of TensorFlow collection containing ops to update every step, such as
# moving average update ops.
UPDATE_OPS_COLLECTION = tf.GraphKeys.UPDATE_OPS
def __init__(self, name):
"""Performs the initialisation necessary for all AbstractModule instances.
Every subclass of AbstractModule must begin their constructor with a call to
this constructor, i.e. `super(MySubModule, self).__init__(name=name)`.
Avoid instantiating sub-modules in __init__ where possible, as they will not
be defined under the module's scope. Instead, instantiate sub-modules in
`build`.
Args:
name: Name of this module. Used to construct the Templated build function.
Raises:
ValueError: If name is not specified.
"""
if not isinstance(name, string_types):
raise ValueError("Name must be a string.")
self._is_connected = False
self._template = tf.make_template(name, self._build,
create_scope_now_=True)
# Update __call__ and the object docstrings to enable better introspection
self.__doc__ = self._build.__doc__
self.__call__.__func__.__doc__ = self._build.__doc__
@abc.abstractmethod
def _build(self, *args, **kwargs):
"""Add elements to the Graph, computing output Tensors from input Tensors.
Subclasses must implement this method, which will be wrapped in a Template.
Args:
*args: Input Tensors.
**kwargs: Additional Python flags controlling connection.
"""
pass
def __call__(self, *args, **kwargs):
out = self._template(*args, **kwargs)
# Connect the module only if self._template returns with no errors.
self._is_connected = True
return out
@property
def variable_scope(self):
"""Returns the variable_scope declared by the module.
It is valid for library users to access the internal templated
variable_scope, but only makes sense to do so after connection. Therefore
we raise an error here if the variable_scope is requested before connection.
The only case where it does make sense to access the variable_scope before
connection is to get the post-uniquification name, which we support using
the separate .name property.
Returns:
variable_scope: `tf.VariableScope` instance of the internal `tf.Template`.
Raises:
NotConnectedError: If the module is not connected to the Graph.
"""
self._ensure_is_connected()
return self._template.variable_scope
@property
def name(self):
"""Returns the name of the Module."""
return self._template.variable_scope.name
@property
def is_connected(self):
"""Returns true iff the Module been connected to the Graph at least once."""
return self._is_connected
@classmethod
def get_possible_initializer_keys(cls):
"""Returns the keys the dictionary of variable initializers may contain.
This provides the user with a way of knowing the initializer keys that are
available without having to instantiate a nn module. Subclasses may
override this class method if they need additional arguments to determine
what initializer keys may be provided.
Returns:
Set with strings corresponding to the strings that may be passed to the
constructor.
"""
return getattr(cls, "POSSIBLE_INITIALIZER_KEYS", set())
def _ensure_is_connected(self):
"""Raise an Error if the module has not been connected yet.
Until the module is connected into the Graph, any variables created do
not exist yet and cannot be created in advance due to not knowing the size
of the input Tensor(s). This assertion ensures that any variables contained
in this module must now exist.
Raises:
NotConnectedError: If the module is not connected to the Graph.
"""
if not self.is_connected:
raise NotConnectedError(
"Variables in {} not instantiated yet, __call__ the module "
"first.".format(self.name))
@six.add_metaclass(abc.ABCMeta)
class Transposable(object):
"""Transposable module interface.
The Transposable interface requires that transposable modules implement
a method called `transpose`, returning a module which is the transposed
version of the one the method is called on.
Calling the method twice should return a module with the same specifications
as the original module.
When implementing a transposable module, special care is required to make
sure that parameters needed to instantiate the module are provided as
functions whose invocation is deferred to graph construction time.
For example, in Linear we might want to call:
```python
linear = nn.Linear(name="linear", output_size=output_size)
linear_transpose = linear.transpose()
```
where the output_size for linear_transpose is not known yet, as linear is
not yet connected to the graph: output_size is passed to linear_transpose's
constructor as a lambda returning linear.input_size. The lambda will return
the correct value once linear is given an input.
Notice that linear_transpose's output_size value does not need to be defined
until the module is connected to the graph.
"""
@abc.abstractmethod
def transpose(self, name=None, **kwargs):
"""Builds and returns transposed version of module.
Args:
name: Name of the transposed module.
**kwargs: Additional Python flags controlling transposition.
Returns:
Transposed version of the module.
"""
pass
@abc.abstractmethod
def input_shape(self):
"""Returns shape of input `Tensor` passed at last call to `build`."""
pass
class Module(AbstractModule):
"""Module wrapping a function provided by the user."""
def __init__(self, build, name="module"):
"""Constructs a module with a given build function.
The Module class can be used to wrap a function assembling a network into a
module.
For example, the following code implements a simple one-hidden-layer MLP
model by defining a function called make_model and using a Module instance
to wrap it.
```python
def make_model(inputs):
lin1 = nn.Linear(name="lin1", output_size=10)(inputs)
relu1 = tf.nn.relu(lin1, name="relu1")
lin2 = nn.Linear(name="lin2", output_size=20)(relu1)
return lin2
model = nn.Module(name='simple_mlp', build=make_model)
outputs = model(inputs)
```
The `partial` package from `functools` can be used to bake configuration
parameters into the function at construction time, as shown in the following
example.
```python
from functools import partial
def make_model(inputs, output_sizes):
lin1 = nn.Linear(name="lin1", output_size=output_sizes[0])(inputs)
relu1 = tf.nn.relu(lin1, name="relu1")
lin2 = nn.Linear(name="lin2", output_size=output_sizes[1])(relu1)
return lin2
model = nn.Module(name='simple_mlp',
build=partial(make_model, output_size=[10, 20])
outputs = model(inputs)
```
Args:
build: Callable to be invoked when connecting the module to the graph.
The `build` function is invoked when the module is called, and its
role is to specify how to add elements to the Graph, and how to
compute output Tensors from input Tensors.
The `build` function signature can include the following parameters:
*args - Input Tensors.
**kwargs - Additional Python parameters controlling connection.
name: Module name.
Raises:
TypeError: If build is not callable.
"""
super(Module, self).__init__(name)
if not callable(build):
raise TypeError("Input 'build' must be callable.")
self._build = build
def _build(self, *args, **kwargs):
"""Forwards call to the passed-in build function."""
return self._build(*args, **kwargs)