
[RLlib] Unify TensorSpecs to a single framework-agnostic class #34493

Merged · 15 commits · Apr 24, 2023

Conversation

ArturNiederfahrenhorst (Contributor):
Why are these changes needed?

This PR lets us use a single TensorSpec class for all spec checks; the desired framework is simply passed in when the check should include a framework-specific type check.
Downside: it becomes harder to extend TensorSpec to a new framework.
Upside: we almost never need to do that, and we save a lot of code that currently separates between specs of different frameworks (and even more of it in the future).
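A minimal sketch of the idea (class and method names are hypothetical, not RLlib's actual implementation): one spec class serves every framework, and the optional `framework` argument decides whether a type check is applied at all.

```python
# Hypothetical sketch of a single framework-agnostic spec class.
# Not RLlib's actual TensorSpec; names and signatures are illustrative.
class SimpleTensorSpec:
    def __init__(self, shape, framework=None):
        self._expected_shape = shape
        self._framework = framework

    def _get_expected_type(self):
        # The real class would lazily import the framework here and return
        # tf.Tensor, torch.Tensor, or jax.numpy.ndarray. With no framework,
        # any type passes the isinstance check.
        if self._framework is None:
            return object
        raise ValueError(f"Unknown framework: {self._framework}")

    def validate(self, tensor):
        expected = self._get_expected_type()
        if not isinstance(tensor, expected):
            raise ValueError(f"Expected {expected}, got {type(tensor)}")

spec = SimpleTensorSpec(shape=(2, 3))  # framework=None: type check is a no-op
spec.validate([[1, 2, 3], [4, 5, 6]])  # passes; any object is accepted
```

The trade-off discussed in the PR description is visible here: supporting a new framework means editing `_get_expected_type` rather than writing a subclass.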

…ntly expanding tests to jax

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
_, tf, _ = try_import_tf()
self._expected_type = tf.Tensor

def _full(cls, shape, fill_value=0):
Member:

This doesn't need to be an inner function. Let's change that.

Contributor Author:

Yep, thanks for catching it!

Member:

Maybe make it a non-member function.
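The suggestion above can be sketched as a module-level helper (illustrative only; RLlib's real `_full` builds framework tensors, and its exact signature may differ):

```python
# Sketch of the reviewer's suggestion: a plain module-level helper instead
# of an inner/member function. Hypothetical stdlib-only stand-in.
def full(shape, fill_value=0):
    """Build a nested list of `fill_value` with the given shape."""
    if not shape:  # empty shape: return the scalar itself
        return fill_value
    return [full(shape[1:], fill_value) for _ in range(shape[0])]

print(full((2, 3), fill_value=1.0))  # [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```

A module-level function is easier to test in isolation and avoids capturing `self` or closure state it does not need.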

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
    return jax.numpy.ndarray
elif self._framework is None:
    # Don't restrict the type of the tensor if no framework is specified.
    return object
Contributor Author:

Before this PR, we'd extend this base class, which made adding new frameworks very easy.
Now the pattern for adding a framework (which I hope we will more or less never have to do) is to change this class directly.
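The "change this class" pattern described above can be sketched as follows (hypothetical names; the real class lazily imports tf/torch/jax inside each branch):

```python
# Hypothetical sketch of the closed-dispatch pattern: supported frameworks
# are enumerated in one class, and adding one means adding a branch here
# rather than writing a per-framework subclass as before this PR.
class TypeDispatcher:
    SUPPORTED = ("tf2", "torch", "jax", None)

    def __init__(self, framework=None):
        if framework not in self.SUPPORTED:
            raise ValueError(f"Unsupported framework: {framework}")
        self._framework = framework

    def expected_type(self):
        if self._framework is None:
            return object  # no framework: don't restrict the tensor type
        # The real implementation lazily imports the framework here:
        #   "tf2" -> tf.Tensor, "torch" -> torch.Tensor,
        #   "jax" -> jax.numpy.ndarray
        raise NotImplementedError(
            f"framework '{self._framework}' not available in this sketch"
        )
```

Unknown frameworks now fail loudly at construction time, which is the cost (and the safety benefit) of closing the class for extension.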


- shape = self.get_shape(tensor)
+ shape = tuple(tensor.shape)
Contributor Author:

This was inlined because it is the same for all frameworks we have.

if len(shape) != len(self._expected_shape):
    raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))

for expected_d, actual_d in zip(self._expected_shape, shape):
    if isinstance(expected_d, int) and expected_d != actual_d:
        raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))
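The dimension check above can be run standalone as a small sketch: concrete int dims must match exactly, while named (symbolic) dims such as "b" act as wildcards.

```python
# Standalone sketch of the shape check: int dims must match exactly,
# named dims (e.g. "b" for batch) match any size.
_INVALID_SHAPE = "Expected shape {} but found {}"

def check_shape(expected_shape, shape):
    if len(shape) != len(expected_shape):
        raise ValueError(_INVALID_SHAPE.format(expected_shape, shape))
    for expected_d, actual_d in zip(expected_shape, shape):
        # A string like "b" is a wildcard; only concrete ints must match.
        if isinstance(expected_d, int) and expected_d != actual_d:
            raise ValueError(_INVALID_SHAPE.format(expected_shape, shape))

check_shape(("b", 3), (32, 3))  # ok: "b" matches any batch size
```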

- dtype = self.get_dtype(tensor)
+ dtype = tensor.dtype
Contributor Author:

Also inlined because it is the same for all frameworks.

@kouroshHakha (Contributor) left a comment:

Just a couple of questions:

if len(shape) != len(self._expected_shape):
    raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))

for expected_d, actual_d in zip(self._expected_shape, shape):
    if isinstance(expected_d, int) and expected_d != actual_d:
        raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))

dtype = self.get_dtype(tensor)
Contributor:

Can we still keep these abstractions? We could implement them by default, but the base class would still be extensible. That way we get the best of both worlds.
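The "best of both worlds" proposal can be sketched like this (hypothetical class names): keep `get_shape`/`get_dtype` on the base class with working defaults, so common frameworks need no subclass, while an unusual framework can still override.

```python
import types

# Hypothetical sketch of the reviewer's proposal: default implementations
# on the base class, overridable for frameworks with unusual APIs.
class BaseTensorSpec:
    def get_shape(self, tensor):
        # Default that works for tf, torch, jax, and numpy tensors alike.
        return tuple(tensor.shape)

    def get_dtype(self, tensor):
        return tensor.dtype

class ExoticFrameworkSpec(BaseTensorSpec):
    # A made-up framework that exposes its shape under a different name.
    def get_shape(self, tensor):
        return tuple(tensor.dims)

fake = types.SimpleNamespace(shape=(2, 3), dtype="float32")
print(BaseTensorSpec().get_shape(fake))  # (2, 3)
```

The default path stays as cheap as the inlined version, and extension remains possible without touching the base class.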

spec = SpecDict(
    {
        SampleBatch.ACTION_DIST: Distribution,
        SampleBatch.ACTION_LOGP: TensorSpec("b"),
Contributor:

Should we not set the framework here?

- x = spec_class("b,h1,h2,h3", h1=2, h2=3, h3=3, dtype=double_type).fill(2)
+ x = TensorSpec(
+     "b,h1,h2,h3", h1=2, h2=3, h3=3, framework=fw, dtype=double_type
+ ).fill(2)
self.assertEqual(x.shape, (1, 2, 3, 3))
self.assertEqual(x.dtype, double_type)
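A self-contained sketch of the spec-string behavior the test above exercises (a hypothetical re-implementation, not RLlib's TensorSpec): named dims are bound via keyword arguments, and unbound names like the batch dim "b" default to 1, giving the `(1, 2, 3, 3)` shape asserted in the test.

```python
# Hypothetical stand-in for the "b,h1,h2,h3" spec-string + fill() behavior.
def parse_spec(spec, **dims):
    """Map a spec string like "b,h1,h2,h3" plus h1=2, ... to a concrete
    shape; unbound names (like the batch dim "b") default to 1."""
    return tuple(dims.get(name, 1) for name in spec.split(","))

def fill(shape, value):
    """Build a nested list of `value` with the given shape."""
    if not shape:
        return value
    return [fill(shape[1:], value) for _ in range(shape[0])]

shape = parse_spec("b,h1,h2,h3", h1=2, h2=3, h3=3)
print(shape)  # (1, 2, 3, 3)
x = fill(shape, 2)
```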

# def test_validation(self):
Contributor:

Wait, should we remove these or enable them? I don't remember why these are commented out.

Contributor Author:

The test was broken: tf2 returns Dimension objects when you call tensor.shape, and that broke the test. I fixed it and re-enabled it.
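A stand-in illustration of the kind of breakage described above (this `Dimension` class is a mock, not TensorFlow's): shape entries that wrap their value in an object break equality checks against plain ints unless the values are converted first.

```python
# Mock Dimension wrapper (NOT tf's class) to illustrate why wrapped shape
# entries can break equality checks against plain int tuples.
class Dimension:
    def __init__(self, value):
        self.value = value

raw_shape = (Dimension(2), Dimension(3))
# Without conversion, comparing against plain ints fails (default object
# equality compares identity, not the wrapped value):
print(raw_shape == (2, 3))  # False
as_ints = tuple(int(d.value) for d in raw_shape)
print(as_ints == (2, 3))  # True
```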

@sven1977 sven1977 merged commit 1531823 into ray-project:master Apr 24, 2023
krfricke added a commit that referenced this pull request Apr 27, 2023
#34493 unified the tensor specifications in RLlib, but it missed a doc test, which has been failing since. This PR updates that doc test to use the new API.

Signed-off-by: Kai Fricke <[email protected]>
ProjectsByJackHe pushed a commit to ProjectsByJackHe/ray that referenced this pull request May 4, 2023
ray-project#34493 unified the tensor specifications in RLlib, but it missed a doc test, which has been failing since. This PR updates that doc test to use the new API.

Signed-off-by: Kai Fricke <[email protected]>
Signed-off-by: Jack He <[email protected]>
architkulkarni pushed a commit to architkulkarni/ray that referenced this pull request May 16, 2023
ray-project#34493 unified the tensor specifications in RLlib, but it missed a doc test, which has been failing since. This PR updates that doc test to use the new API.

Signed-off-by: Kai Fricke <[email protected]>
4 participants