From 1efd7f171ba30421b5d8369a93526395a721c0d9 Mon Sep 17 00:00:00 2001
From: Brennan Saeta
Date: Tue, 20 Jun 2017 14:10:40 -0700
Subject: [PATCH] Adds new Dataset APIs for filesystems

This change adds a new API to construct a Dataset:

 - Dataset.list_files("/path/to/files/*.data")

This API can be used in conjunction with the forthcoming .interleave()
API, which will allow data files to be loaded in parallel. As an added
bonus, input pipelines constructed in this manner should be
deterministic, allowing for reproducible results.
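
Basic usage looks like the following sketch (the file pattern and the
variable names are illustrative only; see the tests below for the exact
API exercised by this change):

    # Build a dataset of all filenames matching a glob pattern.
    filenames = dataset_ops.Dataset.list_files('/path/to/files/*.data')
    itr = filenames.make_one_shot_iterator()
    # Each sess.run() of get_next() yields one matched filename as a
    # scalar string tensor, until OutOfRangeError is raised.
    next_filename = itr.get_next()
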
PiperOrigin-RevId: 159611145
---
 .../contrib/data/python/kernel_tests/BUILD |  15 ++
 .../list_files_dataset_op_test.py          | 159 ++++++++++++++++++
 .../contrib/data/python/ops/dataset_ops.py |  24 +++
 3 files changed, 198 insertions(+)
 create mode 100644 tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py

diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ab4d80c3275d2c..9909ea41c93638 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -99,6 +99,21 @@ py_test(
     ],
 )
 
+py_test(
+    name = "list_files_dataset_op_test",
+    size = "small",
+    srcs = ["list_files_dataset_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/contrib/data",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 py_test(
     name = "map_dataset_op_test",
     size = "small",
diff --git a/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py
new file mode 100644
index 00000000000000..27298de65f90c6
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py
@@ -0,0 +1,159 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from os import path
+import shutil
+import tempfile
+
+from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class ListFilesDatasetOpTest(test.TestCase):
+
+  def setUp(self):
+    self.tmp_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmp_dir, ignore_errors=True)
+
+  def _touchTempFiles(self, filenames):
+    for filename in filenames:
+      open(path.join(self.tmp_dir, filename), 'a').close()
+
+  def testEmptyDirectory(self):
+    dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
+    with self.test_session() as sess:
+      itr = dataset.make_one_shot_iterator()
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+
+  def testSimpleDirectory(self):
+    filenames = ['a', 'b', 'c']
+    self._touchTempFiles(filenames)
+
+    dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
+    with self.test_session() as sess:
+      itr = dataset.make_one_shot_iterator()
+
+      full_filenames = []
+      produced_filenames = []
+      for filename in filenames:
+        full_filenames.append(
+            compat.as_bytes(path.join(self.tmp_dir, filename)))
+        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+      self.assertItemsEqual(full_filenames, produced_filenames)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+
+  def testEmptyDirectoryInitializer(self):
+    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+    dataset = dataset_ops.Dataset.list_files(filename_placeholder)
+
+    with self.test_session() as sess:
+      itr = dataset.make_initializable_iterator()
+      sess.run(
+          itr.initializer,
+          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+
+  def testSimpleDirectoryInitializer(self):
+    filenames = ['a', 'b', 'c']
+    self._touchTempFiles(filenames)
+
+    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+    dataset = dataset_ops.Dataset.list_files(filename_placeholder)
+
+    with self.test_session() as sess:
+      itr = dataset.make_initializable_iterator()
+      sess.run(
+          itr.initializer,
+          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
+
+      full_filenames = []
+      produced_filenames = []
+      for filename in filenames:
+        full_filenames.append(
+            compat.as_bytes(path.join(self.tmp_dir, filename)))
+        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+
+      self.assertItemsEqual(full_filenames, produced_filenames)
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+
+  def testFileSuffixes(self):
+    filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
+    self._touchTempFiles(filenames)
+
+    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+    dataset = dataset_ops.Dataset.list_files(filename_placeholder)
+
+    with self.test_session() as sess:
+      itr = dataset.make_initializable_iterator()
+      sess.run(
+          itr.initializer,
+          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')})
+
+      full_filenames = []
+      produced_filenames = []
+      for filename in filenames[1:-1]:
+        full_filenames.append(
+            compat.as_bytes(path.join(self.tmp_dir, filename)))
+        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+      self.assertItemsEqual(full_filenames, produced_filenames)
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+
+  def testFileMiddles(self):
+    filenames = ['a.txt', 'b.py', 'c.pyc']
+    self._touchTempFiles(filenames)
+
+    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+    dataset = dataset_ops.Dataset.list_files(filename_placeholder)
+
+    with self.test_session() as sess:
+      itr = dataset.make_initializable_iterator()
+      sess.run(
+          itr.initializer,
+          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')})
+
+      full_filenames = []
+      produced_filenames = []
+      for filename in filenames[1:]:
+        full_filenames.append(
+            compat.as_bytes(path.join(self.tmp_dir, filename)))
+        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+
+      self.assertItemsEqual(full_filenames, produced_filenames)
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index 89410bf84472d1..29f1209a58aa02 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -33,6 +33,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
@@ -600,6 +601,29 @@ def read_batch_features(file_pattern,
     dataset = dataset.batch(batch_size)
     return dataset
 
+  @staticmethod
+  def list_files(file_pattern):
+    """A dataset of all files matching a pattern.
+
+    Example:
+      If we had the following files on our filesystem:
+        - /path/to/dir/a.txt
+        - /path/to/dir/b.py
+        - /path/to/dir/c.py
+      If we pass "/path/to/dir/*.py" as the file_pattern, the dataset
+      would produce:
+        - /path/to/dir/b.py
+        - /path/to/dir/c.py
+
+    Args:
+      file_pattern: A string or scalar string `tf.Tensor`, representing
+        the filename pattern that will be matched.
+
+    Returns:
+      A `Dataset` of strings corresponding to file names.
+    """
+    return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))
+
   def repeat(self, count=None):
     """Repeats this dataset `count` times.
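
As a usage note, the matched filenames compose with the Dataset
transformations that already exist in this module. A minimal sketch,
assuming TextLineDataset and flat_map() as defined in dataset_ops.py
(the file pattern is illustrative; the forthcoming .interleave() would
take the place of flat_map() to read files in parallel):

    from tensorflow.contrib.data.python.ops import dataset_ops

    # A dataset of matched filenames; matching is deterministic, per
    # the commit message, so the pipeline is reproducible.
    filenames = dataset_ops.Dataset.list_files('/path/to/files/*.data')

    # Read the lines of each matched file, one file after another.
    # flat_map() is sequential; .interleave() (not part of this change)
    # is intended to parallelize this step.
    lines = filenames.flat_map(
        lambda filename: dataset_ops.TextLineDataset(filename))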