[RLlib] Unify TensorSpecs to a single framework-agnostic class #34493
Conversation
…ntly expanding tests to jax Signed-off-by: Artur Niederfahrenhorst <[email protected]>
_, tf, _ = try_import_tf()
self._expected_type = tf.Tensor

def _full(cls, shape, fill_value=0):
This doesn't need to be an inner function. Let's change that.
Yep, thanks for catching it!
Maybe make it a non-member function.
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
    return jax.numpy.ndarray
elif self._framework is None:
    # Don't restrict the type of the tensor if no framework is specified.
    return object
Before this PR, we'd extend this base class, which made adding new frameworks very easy.
Now the pattern for adding a framework here (which I hope we will more or less never have to do) is to change this class.
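The tradeoff being discussed can be sketched with a toy example (illustrative names only, not RLlib's actual code): the framework becomes a constructor argument and the expected tensor type is resolved by branching, so adding a framework means adding a branch instead of writing a subclass.

```python
# Toy sketch of the single-class pattern discussed above. The stub tensor
# types and the class name are hypothetical, not RLlib's implementation.


class TorchTensorStub:
    """Stand-in for torch.Tensor in this sketch."""


class TfTensorStub:
    """Stand-in for tf.Tensor in this sketch."""


class TensorSpecSketch:
    def __init__(self, framework=None):
        self._framework = framework

    def get_expected_type(self):
        # Adding a new framework now means adding a branch here,
        # rather than adding a new subclass of the spec.
        if self._framework == "torch":
            return TorchTensorStub
        elif self._framework == "tf2":
            return TfTensorStub
        elif self._framework is None:
            # Don't restrict the type if no framework is specified.
            return object
        raise ValueError(f"Unknown framework: {self._framework}")


assert TensorSpecSketch("torch").get_expected_type() is TorchTensorStub
assert TensorSpecSketch().get_expected_type() is object
```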
shape = self.get_shape(tensor)
shape = tuple(tensor.shape)
This was inlined because it is the same for all frameworks we have.
if len(shape) != len(self._expected_shape):
    raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))

for expected_d, actual_d in zip(self._expected_shape, shape):
    if isinstance(expected_d, int) and expected_d != actual_d:
        raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))

dtype = self.get_dtype(tensor)
dtype = tensor.dtype
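The validation logic quoted above can be illustrated with a self-contained sketch (simplified: no class context, and `_INVALID_SHAPE` is redefined here; any non-int entry in the expected shape, such as a named dimension like "b", matches any size):

```python
_INVALID_SHAPE = "Expected shape {} but found {}"


def validate_shape(expected_shape, shape):
    """Check a concrete tensor shape against an expected spec shape.

    Entries in ``expected_shape`` that are ints must match exactly;
    non-int entries (e.g. named dims like "b") match any size.
    """
    if len(shape) != len(expected_shape):
        raise ValueError(_INVALID_SHAPE.format(expected_shape, shape))
    for expected_d, actual_d in zip(expected_shape, shape):
        if isinstance(expected_d, int) and expected_d != actual_d:
            raise ValueError(_INVALID_SHAPE.format(expected_shape, shape))


validate_shape(("b", 2, 3), (4, 2, 3))  # ok: named dim "b" matches any size
try:
    validate_shape(("b", 2, 3), (4, 2, 5))
except ValueError as e:
    print(e)  # Expected shape ('b', 2, 3) but found (4, 2, 5)
```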
Also inlined because it is the same for all frameworks.
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
Just a couple of questions:
if len(shape) != len(self._expected_shape):
    raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))

for expected_d, actual_d in zip(self._expected_shape, shape):
    if isinstance(expected_d, int) and expected_d != actual_d:
        raise ValueError(_INVALID_SHAPE.format(self._expected_shape, shape))

dtype = self.get_dtype(tensor)
Can we still keep these abstractions? We can implement them by default, but the base class would still be extendable, so we get the best of both worlds.
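The suggestion here, default implementations on a still-extendable base class, could look roughly like this (hypothetical class names, not RLlib code):

```python
class TensorSpecBase:
    """Sketch: hooks implemented by default, but still overridable."""

    def get_shape(self, tensor):
        # Default that works for numpy, torch, and tf2 tensors alike.
        return tuple(tensor.shape)

    def get_dtype(self, tensor):
        return tensor.dtype


class ExoticFrameworkSpec(TensorSpecBase):
    """A hypothetical future framework can still override a hook."""

    def get_shape(self, tensor):
        # E.g. coerce framework-specific dimension objects to plain ints.
        return tuple(int(d) for d in tensor.shape)


class FakeTensor:  # minimal stand-in so the sketch runs
    shape = (2, 3)
    dtype = "float32"


assert TensorSpecBase().get_shape(FakeTensor()) == (2, 3)
assert ExoticFrameworkSpec().get_shape(FakeTensor()) == (2, 3)
```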
spec = SpecDict(
    {
        SampleBatch.ACTION_DIST: Distribution,
        SampleBatch.ACTION_LOGP: TensorSpec("b"),
Should we not set the framework here?
x = spec_class("b,h1,h2,h3", h1=2, h2=3, h3=3, dtype=double_type).fill(2)
x = TensorSpec(
    "b,h1,h2,h3", h1=2, h2=3, h3=3, framework=fw, dtype=double_type
).fill(2)
self.assertEqual(x.shape, (1, 2, 3, 3))
self.assertEqual(x.dtype, double_type)

# def test_validation(self):
Wait, should we remove these or enable them? I don't remember why these are commented out.
The test was broken: tf2 returns Dimension objects when you call tensor.shape, which broke the test. I fixed it and re-enabled it.
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
#34493 unified the tensor specifications in RLlib, but it missed a doc test, which has been failing since. This PR updates that doc test to use the new API. Signed-off-by: Kai Fricke <[email protected]>
Why are these changes needed?
This PR makes it possible to use a single TensorSpec class for all spec checks; callers simply pass the desired framework if the check should also validate the framework-specific tensor type.
Downside: this makes it less easy to extend TensorSpec to a new framework.
Upside: we almost never have to do that, and we can delete lots of code that currently separates specs by framework - and even more of it in the future.
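As a rough illustration of the idea (a toy class, not RLlib's actual TensorSpec): one spec class parses a dimension string, binds named sizes passed as keyword arguments, and optionally records a framework to check against.

```python
class MiniTensorSpec:
    """Toy unified spec: one class, framework passed as an argument."""

    def __init__(self, dims, framework=None, **dim_sizes):
        # "b,h" with h=4 -> expected shape ("b", 4); named dims without
        # a bound size (like "b") match any value.
        self.expected_shape = tuple(
            dim_sizes.get(d.strip(), d.strip()) for d in dims.split(",")
        )
        self.framework = framework

    def validate(self, shape):
        if len(shape) != len(self.expected_shape):
            raise ValueError(f"Expected {self.expected_shape}, got {shape}")
        for exp, act in zip(self.expected_shape, shape):
            if isinstance(exp, int) and exp != act:
                raise ValueError(f"Expected {self.expected_shape}, got {shape}")


spec = MiniTensorSpec("b,h", h=4)  # framework-agnostic by default
spec.validate((32, 4))             # ok: "b" matches any batch size
```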