This repository has been archived by the owner on Dec 29, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
train.py
executable file
·277 lines (235 loc) · 10.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
#! /usr/bin/env python
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main script to run training and evaluation of models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import tempfile
import yaml
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow import gfile
from seq2seq import models
from seq2seq.contrib.experiment import Experiment as PatchedExperiment
from seq2seq.configurable import _maybe_load_yaml, _create_from_dict
from seq2seq.configurable import _deep_merge_dict
from seq2seq.data import input_pipeline
from seq2seq.metrics import metric_specs
from seq2seq.training import hooks
from seq2seq.training import utils as training_utils
# Command-line flags for this script. Any of them may also be set from the
# YAML files named in --config_paths, which main() merges into FLAGS.

# Comma-separated YAML files; main() skips blank entries, loads each file and
# merges the results before applying them to FLAGS.
tf.flags.DEFINE_string("config_paths", "",
"""Path to a YAML configuration files defining FLAG
values. Multiple files can be separated by commas.
Files are merged recursively. Setting a key in these
files is equivalent to setting the FLAG value with
the same name.""")
# YAML string; main() parses it with _maybe_load_yaml, then
# create_experiment() builds one training hook per entry.
tf.flags.DEFINE_string("hooks", "[]",
"""YAML configuration string for the
training hooks to use.""")
# YAML string; main() parses it with _maybe_load_yaml, then
# create_experiment() builds one metric spec per entry, keyed by its name.
tf.flags.DEFINE_string("metrics", "[]",
"""YAML configuration string for the
training metrics to use.""")
# Plain class name (not YAML); resolved by _create_from_dict against
# seq2seq.models inside create_experiment()'s model_fn.
tf.flags.DEFINE_string("model", "",
"""Name of the model class.
Can be either a fully-qualified name, or the name
of a class defined in `seq2seq.models`.""")
# YAML string; parsed in main(). It is passed to TrainOptions and also used as
# the Estimator's `params`.
tf.flags.DEFINE_string("model_params", "{}",
"""YAML configuration string for the model
parameters.""")
# YAML string; parsed in main(). main() raises ValueError if it is still
# empty after parsing and config merging.
tf.flags.DEFINE_string("input_pipeline_train", "{}",
"""YAML configuration string for the training
data input pipeline.""")
# YAML string; parsed in main(). main() raises ValueError if it is still
# empty after parsing and config merging.
tf.flags.DEFINE_string("input_pipeline_dev", "{}",
"""YAML configuration string for the development
data input pipeline.""")
# Converted to a list of ints in create_experiment(); only the training input
# function receives the boundaries.
# NOTE(review): the help text reads "None disabled bucketing" -- presumably
# meant "disables"; left as-is because it is user-visible runtime text.
tf.flags.DEFINE_string("buckets", None,
"""Buckets input sequences according to these length.
A comma-separated list of sequence length buckets, e.g.
"10,20,30" would result in 4 buckets:
<10, 10-20, 20-30, >30. None disabled bucketing. """)
# Used for both the training and the evaluation input functions.
tf.flags.DEFINE_integer("batch_size", 16,
"""Batch size used for training and evaluation.""")
# main() substitutes tempfile.mkdtemp() when this is empty.
tf.flags.DEFINE_string("output_dir", None,
"""The directory to write model checkpoints and summaries
to. If None, a local temporary directory is created.""")
# Training parameters
# Passed to learn_runner.run as the schedule to execute.
tf.flags.DEFINE_string("schedule", "continuous_train_and_eval",
"""Estimator function to call, defaults to
continuous_train_and_eval for local run""")
# Passed to the Experiment as `train_steps`.
tf.flags.DEFINE_integer("train_steps", None,
"""Maximum number of training steps to run.
If None, train forever.""")
# Passed to the Experiment as `min_eval_frequency`.
tf.flags.DEFINE_integer("eval_every_n_steps", 1000,
"Run evaluation on validation data every N steps.")
# RunConfig Flags
# The flags below are forwarded to run_config.RunConfig (or to its session
# config) in create_experiment().
tf.flags.DEFINE_integer("tf_random_seed", None,
"""Random seed for TensorFlow initializers. Setting
this value allows consistency between reruns.""")
# If this and save_checkpoints_steps are both unset, main() sets this to 600.
tf.flags.DEFINE_integer("save_checkpoints_secs", None,
"""Save checkpoints every this many seconds.
Can not be specified with save_checkpoints_steps.""")
tf.flags.DEFINE_integer("save_checkpoints_steps", None,
"""Save checkpoints every this many steps.
Can not be specified with save_checkpoints_secs.""")
tf.flags.DEFINE_integer("keep_checkpoint_max", 5,
"""Maximum number of recent checkpoint files to keep.
As new files are created, older files are deleted.
If None or 0, all checkpoint files are kept.""")
tf.flags.DEFINE_integer("keep_checkpoint_every_n_hours", 4,
"""In addition to keeping the most recent checkpoint
files, keep one checkpoint file for every N hours of
training.""")
tf.flags.DEFINE_float("gpu_memory_fraction", 1.0,
"""Fraction of GPU memory used by the process on
each GPU uniformly on the same machine.""")
# Written to config.tf_config.gpu_options.allow_growth after the RunConfig is
# built.
tf.flags.DEFINE_boolean("gpu_allow_growth", False,
"""Allow GPU memory allocation to grow
dynamically.""")
# Written to config.tf_config.log_device_placement after the RunConfig is
# built.
tf.flags.DEFINE_boolean("log_device_placement", False,
"""Log the op placement to devices""")
# Module-level handle to the flag values; main() reassigns several attributes
# (YAML parsing, config-file merging, defaults) before the experiment runs.
FLAGS = tf.flags.FLAGS
def _create_run_config():
  """Builds the RunConfig from the RunConfig-related FLAGS.

  Returns:
    A `run_config.RunConfig` whose session config (`tf_config`) also carries
    the `gpu_allow_growth` and `log_device_placement` flag values.
  """
  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  # These two options are set directly on the session config held by the
  # RunConfig rather than passed to its constructor.
  config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  config.tf_config.log_device_placement = FLAGS.log_device_placement
  return config


def _parse_bucket_boundaries(buckets):
  """Parses a comma-separated bucket specification into a list of ints.

  Args:
    buckets: A string such as "10,20,30", or None / "" to disable bucketing.

  Returns:
    A list of integer boundaries, or None when `buckets` is empty or holds
    no non-blank entries (i.e. bucketing is disabled).

  Raises:
    ValueError: If a non-blank entry is not an integer.
  """
  if not buckets:
    return None
  # Skip blank pieces so a trailing or doubled comma ("10,20,") does not make
  # int("") raise.
  boundaries = [int(piece) for piece in buckets.split(",") if piece.strip()]
  return boundaries or None


def create_experiment(output_dir):
  """Creates a new Experiment instance.

  Builds the RunConfig, the train/dev input functions, the Estimator, the
  training hooks and the eval metrics from FLAGS, and wraps them in a
  `PatchedExperiment`. main() passes this function to `learn_runner.run`
  after it has parsed the YAML-valued flags.

  Args:
    output_dir: Output directory for model checkpoints and summaries.

  Returns:
    A `PatchedExperiment` configured from FLAGS.
  """
  config = _create_run_config()

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)
  # Only the chief worker writes the training options to output_dir.
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = _parse_bucket_boundaries(FLAGS.buckets)

  # Training data input pipeline and its input function.
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      scope="train_input_fn")

  # Development data: one unshuffled epoch, allowing a smaller final batch.
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False, num_epochs=1)
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      scope="dev_input_fn")

  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    return model(features, labels, params)

  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # FLAGS.hooks / FLAGS.metrics are expected to be YAML-parsed lists of
  # definitions. A config file that sets either key with an empty value
  # leaves it as None; treat that as "none configured" instead of failing
  # with a TypeError while iterating.
  train_hooks = [
      _create_from_dict(
          dict_, hooks, model_dir=estimator.model_dir, run_config=config)
      for dict_ in FLAGS.hooks or []
  ]

  eval_metrics = {}
  for dict_ in FLAGS.metrics or []:
    metric = _create_from_dict(dict_, metric_specs)
    # A later metric with the same name replaces an earlier one.
    eval_metrics[metric.name] = metric

  return PatchedExperiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)
def _load_config_files(config_paths):
  """Loads the comma-separated YAML config files and merges them in order.

  Each file's contents are folded into the accumulated result with
  `_deep_merge_dict(accumulated, file_contents)`.

  Args:
    config_paths: Comma-separated list of YAML file paths. Surrounding
      whitespace is stripped and blank entries are skipped.

  Returns:
    The merged configuration dict ({} if no file contributed anything).

  Raises:
    ValueError: If a config file holds something other than a YAML mapping.
  """
  final_config = {}
  for config_path in config_paths.split(","):
    config_path = config_path.strip()
    if not config_path:
      continue
    config_path = os.path.abspath(config_path)
    tf.logging.info("Loading config from %s", config_path)
    with gfile.GFile(config_path) as config_file:
      # safe_load: a bare yaml.load() with no Loader can construct arbitrary
      # Python objects, is deprecated since PyYAML 5.1 and is a TypeError
      # in PyYAML >= 6.
      config_flags = yaml.safe_load(config_file)
    if config_flags is None:
      # Empty file: nothing to merge.
      continue
    if not isinstance(config_flags, dict):
      raise ValueError(
          "Config file %s must contain a YAML mapping, got %s" %
          (config_path, type(config_flags).__name__))
    final_config = _deep_merge_dict(final_config, config_flags)
  tf.logging.info("Final Config:\n%s", yaml.dump(final_config))
  return final_config


def _merge_config_into_flags(final_config):
  """Applies config-file values to FLAGS.

  For a flag whose current value is a dict, the config value and the current
  value are combined with `_deep_merge_dict(config_value, current_value)`.
  Any other known flag is overwritten by the config value. Unknown keys are
  logged and ignored.

  Args:
    final_config: Mapping of flag name to value, as returned by
      `_load_config_files`.
  """
  for flag_key, flag_value in final_config.items():
    if not hasattr(FLAGS, flag_key):
      tf.logging.warning("Ignoring config flag: %s", flag_key)
      continue
    current_value = getattr(FLAGS, flag_key)
    if isinstance(current_value, dict):
      # A key written with an empty value in YAML loads as None; merge the
      # command-line dict into an empty one instead of crashing.
      if flag_value is None:
        flag_value = {}
      flag_value = _deep_merge_dict(flag_value, current_value)
    setattr(FLAGS, flag_key, flag_value)


def main(_argv):
  """The entrypoint for the script.

  Parses the YAML-valued flags and merges in values from --config_paths. It
  then fills in defaults for checkpointing and output_dir, checks that both
  input pipelines are configured, and runs the experiment via `learn_runner`.

  Args:
    _argv: Unused positional argument supplied by `tf.app.run`.

  Raises:
    ValueError: If input_pipeline_train or input_pipeline_dev is empty, or a
      config file does not contain a YAML mapping.
  """
  # Parse YAML FLAGS
  FLAGS.hooks = _maybe_load_yaml(FLAGS.hooks)
  FLAGS.metrics = _maybe_load_yaml(FLAGS.metrics)
  FLAGS.model_params = _maybe_load_yaml(FLAGS.model_params)
  FLAGS.input_pipeline_train = _maybe_load_yaml(FLAGS.input_pipeline_train)
  FLAGS.input_pipeline_dev = _maybe_load_yaml(FLAGS.input_pipeline_dev)

  # Load flags from config files and merge them into FLAGS.
  if FLAGS.config_paths:
    _merge_config_into_flags(_load_config_files(FLAGS.config_paths))

  # Checkpoint every 600 seconds unless the user chose either frequency.
  if FLAGS.save_checkpoints_secs is None \
    and FLAGS.save_checkpoints_steps is None:
    FLAGS.save_checkpoints_secs = 600
    tf.logging.info("Setting save_checkpoints_secs to %d",
                    FLAGS.save_checkpoints_secs)

  if not FLAGS.output_dir:
    FLAGS.output_dir = tempfile.mkdtemp()

  if not FLAGS.input_pipeline_train:
    raise ValueError("You must specify input_pipeline_train")

  if not FLAGS.input_pipeline_dev:
    raise ValueError("You must specify input_pipeline_dev")

  learn_runner.run(
      experiment_fn=create_experiment,
      output_dir=FLAGS.output_dir,
      schedule=FLAGS.schedule)
if __name__ == "__main__":
  # Surface the tf.logging.info messages emitted while configuring and
  # running the experiment.
  tf.logging.set_verbosity(tf.logging.INFO)
  # Hand control to tf.app.run, naming this module's entry point explicitly.
  tf.app.run(main=main)