Skip to content

Commit

Permalink
Support zero shapes for random_poisson. This matches random_uniform.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 159771215
  • Loading branch information
tensorflower-gardener committed Jun 22, 2017
1 parent 70cea91 commit 52581df
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 0 additions & 4 deletions tensorflow/core/kernels/random_poisson_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,6 @@ class RandomPoissonOp : public OpKernel {

const auto rate_flat = rate_t.flat<T>().data();
const int64 num_rate = rate_t.NumElements();
OP_REQUIRES(
ctx, num_rate > 0,
errors::InvalidArgument(
"Input rate should have non-zero element count, got: ", num_rate));
auto samples_flat = samples_t->flat<T>().data();
random::PhiloxRandom rng = generator_.ReserveRandomOutputs(
num_samples * num_rate, kReservedSamplesPerOutput);
Expand Down
8 changes: 7 additions & 1 deletion tensorflow/python/kernel_tests/random_poisson_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,14 @@ def testNoCSE(self):
# be at least 1 if they are different.
self.assertGreaterEqual(np.linalg.norm(diff.eval()), 1)

def testZeroShape(self):
with self.test_session():
rnd = random_ops.random_poisson([], [], seed=12345)
self.assertEqual([0], rnd.get_shape().as_list())
self.assertAllClose(np.array([], dtype=np.float32), rnd.eval())

def testShape(self):
# Fully known shape.
# Fully known shape
rnd = random_ops.random_poisson(2.0, [150], seed=12345)
self.assertEqual([150], rnd.get_shape().as_list())
rnd = random_ops.random_poisson(
Expand Down

0 comments on commit 52581df

Please sign in to comment.