Skip to content

Commit

Permalink
Add in-memory caching to datasets.
Browse files Browse the repository at this point in the history
In-memory caching allows us to iterate repeatedly over a small- or
medium-sized dataset while reading it from storage and/or performing basic
processing only once.

PiperOrigin-RevId: 158585973
  • Loading branch information
saeta authored and tensorflower-gardener committed Jun 10, 2017
1 parent 19b4ccd commit 61a46ce
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class CacheDatasetTest(test.TestCase):
class FilesystemCacheDatasetTest(test.TestCase):

def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
Expand Down Expand Up @@ -197,5 +198,102 @@ def testConcurrentReaders(self):
self.assertAllEqual(elements, elements_itr2)


class MemoryCacheDatasetTest(test.TestCase):
  """Covers `Dataset.cache()` when no filename is given (in-memory cache)."""

  def testCacheDatasetPassthrough(self):
    # The repeat count lives in a variable so that the upstream dataset can
    # be emptied out after the cache has been populated.
    reps = variables.Variable(constant_op.constant(10, dtypes.int64))
    base = dataset_ops.Dataset.range(3).flat_map(
        lambda x: dataset_ops.Dataset.from_tensors(x).repeat(reps))

    with_cache = base.cache().repeat(2)
    without_cache = base.repeat(2)

    # Initializable iterators are required in order to capture the variable.
    cached_it = with_cache.make_initializable_iterator()
    cached_next = cached_it.get_next()
    uncached_it = without_cache.make_initializable_iterator()
    uncached_next = uncached_it.get_next()

    with self.test_session() as sess:

      sess.run(reps.initializer)
      sess.run(cached_it.initializer)
      sess.run(uncached_it.initializer)

      # First epoch: both pipelines produce each of 0, 1, 2 ten times.
      for value in range(3):
        for _ in range(10):
          self.assertEqual(sess.run(cached_next), value)
          self.assertEqual(sess.run(uncached_next), value)

      sess.run(reps.assign(0))

      # With zero repeats upstream, the uncached pipeline is now empty.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(uncached_next)

      # The cached pipeline replays the first epoch from memory.
      for value in range(3):
        for _ in range(10):
          self.assertEqual(sess.run(cached_next), value)

      # After the replayed epoch, the cached pipeline is exhausted too.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(cached_next)

  def testEmptyCacheReading(self):
    components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
                  np.array([9.0, 10.0, 11.0, 12.0]))
    count = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])

    repeated = (dataset_ops.Dataset.from_tensor_slices(components)
                .repeat(count))

    cached = repeated.cache()

    # Build the initialization op for the caching iterator.
    iterator = cached.make_initializable_iterator()
    init_op = iterator.initializer

    next_element = iterator.get_next()

    with self.test_session() as sess:
      # Initialize with an empty upstream dataset; the very first get_next()
      # should raise errors.OutOfRangeError immediately.
      sess.run(init_op, feed_dict={count: 0})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testConcurrentReaders(self):
    count = array_ops.placeholder_with_default(
        constant_op.constant(5, dtypes.int64), shape=[])
    cached = dataset_ops.Dataset.range(count).cache()
    plus_one = cached.map(lambda x: x + 1)
    plus_six = cached.map(lambda x: x + 6)

    it_a = plus_one.make_initializable_iterator()
    it_b = plus_six.make_initializable_iterator()

    with self.test_session() as sess:
      sess.run(it_a.initializer)

      self.assertEqual(1, sess.run(it_a.get_next()))
      self.assertEqual(2, sess.run(it_a.get_next()))
      self.assertEqual(3, sess.run(it_a.get_next()))

      # The second reader is fed a smaller count but shares the same cache.
      sess.run(it_b.initializer, feed_dict={count: 3})

      self.assertEqual(6, sess.run(it_b.get_next()))
      self.assertEqual(7, sess.run(it_b.get_next()))
      self.assertEqual(4, sess.run(it_a.get_next()))  # interleave execution
      self.assertEqual([8, 5], sess.run([it_b.get_next(), it_a.get_next()]))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(it_a.get_next())
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(it_b.get_next())


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()
3 changes: 2 additions & 1 deletion tensorflow/contrib/data/python/ops/dataset_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,12 +657,13 @@ def shuffle(self, buffer_size, seed=None):
"""
return ShuffleDataset(self, buffer_size, seed)

def cache(self, filename):
def cache(self, filename=""):
"""Caches the elements in this dataset.
Args:
filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
directory on the filesystem to use for caching tensors in this Dataset.
If a filename is not provided, the dataset will be cached in memory.
Returns:
A `Dataset`.
Expand Down
188 changes: 152 additions & 36 deletions tensorflow/core/kernels/cache_dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,29 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level description of
// the following op.

class CacheDatasetOp : public OpKernel {
class CacheDatasetOp : public UnaryDatasetOpKernel {
public:
explicit CacheDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void Compute(OpKernelContext* ctx) override {
DatasetBase* input;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &input));
core::ScopedUnref unref_input(input);
explicit CacheDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}

void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
// Parse out the filenames tensor.
const Tensor* filename_tensor;
OP_REQUIRES_OK(ctx, ctx->input("filename", &filename_tensor));
OP_REQUIRES(ctx, filename_tensor->dims() == 0,
errors::InvalidArgument("`filename` must be a scalar."));
string filename = filename_tensor->flat<string>()(0);

DatasetBase* dataset = new Dataset(input, filename, ctx->env());
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
ResourceHandle handle = MakeResourceHandle<DatasetBase>(
ctx, ctx->step_container()->name(), name());
OP_REQUIRES_OK(ctx, CreateResource(ctx, handle, dataset));
output->flat<ResourceHandle>()(0) = handle;
string filename;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<string>(ctx, "filename", &filename));

if (filename.empty()) {
*output = new MemoryDataset(input);
} else {
*output = new FileDataset(input, filename, ctx->env());
}
}

private:
class Dataset : public DatasetBase {
class FileDataset : public DatasetBase {
public:
explicit Dataset(const DatasetBase* input, string filename, Env* env)
explicit FileDataset(const DatasetBase* input, string filename, Env* env)
: input_(input),
filename_(std::move(filename)),
env_(env),
Expand All @@ -69,13 +63,13 @@ class CacheDatasetOp : public OpKernel {
DCHECK_EQ(item_index_padding_size_, 7);
}

~Dataset() override { input_->Unref(); }
~FileDataset() override { input_->Unref(); }

std::unique_ptr<IteratorBase> MakeIterator() const override {
if (env_->FileExists(strings::StrCat(filename_, ".index")).ok()) {
return std::unique_ptr<IteratorBase>(new ReaderIterator(this));
return std::unique_ptr<IteratorBase>(new FileReaderIterator(this));
} else {
return std::unique_ptr<IteratorBase>(new WriterIterator(this));
return std::unique_ptr<IteratorBase>(new FileWriterIterator(this));
}
}

Expand All @@ -87,7 +81,7 @@ class CacheDatasetOp : public OpKernel {
return input_->output_shapes();
}

string DebugString() override { return "CacheDatasetOp::Dataset"; }
string DebugString() override { return "CacheDatasetOp::FileDataset"; }

private:
static size_t StringPaddingSize(size_t num_tensors) {
Expand All @@ -99,15 +93,16 @@ class CacheDatasetOp : public OpKernel {
tensor_index);
}

// WriterIterator passes through and caches items from the input dataset.
// FileWriterIterator passes through and caches items from the input
// FileDataset.
//
// This iterator is used when the cache directory is not found on disk. It
// creates the cache directory, and passes on the underlying iterator's
// elements.
class WriterIterator : public DatasetIterator<Dataset> {
class FileWriterIterator : public DatasetIterator<FileDataset> {
public:
explicit WriterIterator(const Dataset* dataset)
: DatasetIterator<Dataset>(dataset),
explicit FileWriterIterator(const FileDataset* dataset)
: DatasetIterator<FileDataset>(dataset),
cur_index_(0),
input_impl_(dataset->input_->MakeIterator()),
writer_(dataset->env_, dataset->filename_),
Expand Down Expand Up @@ -207,12 +202,12 @@ class CacheDatasetOp : public OpKernel {
const string lockfile_;
bool lockfile_created_ GUARDED_BY(mu_);
bool iteration_completed_ GUARDED_BY(mu_);
}; // WriterIterator
}; // FileWriterIterator

class ReaderIterator : public DatasetIterator<Dataset> {
class FileReaderIterator : public DatasetIterator<FileDataset> {
public:
explicit ReaderIterator(const Dataset* dataset)
: DatasetIterator<Dataset>(dataset),
explicit FileReaderIterator(const FileDataset* dataset)
: DatasetIterator<FileDataset>(dataset),
cur_index_(0),
reader_(dataset->env_, dataset->filename_) {}

Expand Down Expand Up @@ -249,7 +244,7 @@ class CacheDatasetOp : public OpKernel {
mutex mu_;
size_t cur_index_ GUARDED_BY(mu_);
BundleReader reader_ GUARDED_BY(mu_);
}; // ReaderIterator
}; // FileReaderIterator

const DatasetBase* const input_;
const string filename_;
Expand All @@ -259,7 +254,128 @@ class CacheDatasetOp : public OpKernel {
static const size_t kMaxItems = 10000000; // 10 million
const size_t item_index_padding_size_;
const string tensor_format_string_;
}; // Dataset
}; // FileDataset

// MemoryDataset caches the elements of its input dataset in memory.
// The first iterator created is a writer that passes elements through while
// recording them; once the input is exhausted, the recorded elements are
// published to `cache_`, and every iterator created afterwards replays them
// from memory.
class MemoryDataset : public DatasetBase {
 public:
  explicit MemoryDataset(const DatasetBase* input) : input_(input) {
    input->Ref();
  }

  ~MemoryDataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIterator() const override {
    mutex_lock l(mu_);
    // Three cases:
    //  1. The cache has been fully populated: hand out a reader.
    //  2. No writer exists yet: hand out the (single) writer.
    //  3. A writer exists but has not finished: hand out an iterator that
    //     always errors, since concurrent caching passes are not supported.
    if (cache_) {
      return std::unique_ptr<IteratorBase>(
          new MemoryReaderIterator(this, cache_.get()));
    }
    if (!writer_iterator_created_) {
      writer_iterator_created_ = true;
      return std::unique_ptr<IteratorBase>(new MemoryWriterIterator(this));
    }
    return std::unique_ptr<IteratorBase>(new DuplicateWriterIterator(this));
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() override { return "CacheDatasetOp::MemoryDataset"; }

 private:
  // MemoryWriterIterator passes through and appends items from the input
  // dataset to its vector.
  //
  // This iterator is used when dataset->cache_ is null. After buffering
  // the tensors in memory, upon exhausting the underlying iterator, they are
  // updated into the parent dataset's cache_ pointer.
  //
  // NOTE(review): `writer_iterator_created_` is set when this iterator is
  // handed out and is never reset, and `cache_` is only published on
  // end-of-sequence; if this iterator is destroyed before exhausting the
  // input, later MakeIterator() calls can only return
  // DuplicateWriterIterator.
  class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
   public:
    explicit MemoryWriterIterator(const MemoryDataset* dataset)
        : DatasetIterator<MemoryDataset>(dataset),
          input_impl_(dataset->input_->MakeIterator()),
          cache_(new std::vector<std::vector<Tensor>>) {}

    Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(
          input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
      if (*end_of_sequence) {
        // Guard on cache_ to not crash if GetNext is called a second time
        // after *end_of_sequence == true.
        if (cache_) {
          // Lock order: iterator mu_ is held, then the dataset's mu_ is
          // acquired to publish the completed cache.
          mutex_lock l2(dataset()->mu_);
          DCHECK(dataset()->writer_iterator_created_);
          DCHECK(!dataset()->cache_);
          // Hand ownership of the buffered elements to the dataset; the
          // local cache_ becomes null, marking this writer as done.
          cache_.swap(dataset()->cache_);
        }
        return Status::OK();
      }
      // Record a copy of the produced element before passing it through.
      cache_->emplace_back(*out_tensors);
      return Status::OK();
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    std::unique_ptr<std::vector<std::vector<Tensor>>> cache_ GUARDED_BY(mu_);
  };  // MemoryWriterIterator

  // MemoryReaderIterator replays previously cached elements in order.
  // `cache` is non-owning; it points at the parent dataset's cache_, which
  // outlives this iterator.
  class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
   public:
    explicit MemoryReaderIterator(
        const MemoryDataset* dataset,
        const std::vector<std::vector<Tensor>>* cache)
        : DatasetIterator<MemoryDataset>(dataset), cache_(cache), index_(0) {
      CHECK(cache);
    }

    Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) override {
      mutex_lock l(mu_);
      if (index_ < cache_->size()) {
        const std::vector<Tensor>& cache_tensors = (*cache_)[index_];
        out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
                            cache_tensors.end());
        index_++;
        *end_of_sequence = false;
        return Status::OK();
      } else {
        *end_of_sequence = true;
        return Status::OK();
      }
    }

   private:
    mutex mu_;
    const std::vector<std::vector<Tensor>>* const cache_;
    size_t index_ GUARDED_BY(mu_);
  };  // MemoryReaderIterator

  // DuplicateWriterIterator is returned when a second caching pass is
  // requested while the first writer has not yet completed; every GetNext
  // call fails with AlreadyExists.
  class DuplicateWriterIterator : public DatasetIterator<MemoryDataset> {
   public:
    explicit DuplicateWriterIterator(const MemoryDataset* dataset)
        : DatasetIterator<MemoryDataset>(dataset) {}

    Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) override {
      return errors::AlreadyExists(
          "There appears to be a concurrent caching iterator running.");
    }
  };  // DuplicateWriterIterator

  const DatasetBase* const input_;
  // mu_ guards the iterator-creation state below; it is mutable because
  // MakeIterator() is const.
  mutable mutex mu_;
  // Null until the writer iterator exhausts the input, then holds every
  // cached element.
  mutable std::unique_ptr<std::vector<std::vector<Tensor>>> cache_
      GUARDED_BY(mu_);
  // True once a writer iterator has been handed out (one-shot; never reset).
  mutable bool writer_iterator_created_ GUARDED_BY(mu_) = false;
};  // MemoryDataset
};  // CacheDatasetOp
}; // CacheDatasetOp

REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
Expand Down

0 comments on commit 61a46ce

Please sign in to comment.