Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] make tune.with_parameters() work with the class API #14532

Merged
merged 2 commits into from
Mar 9, 2021

Conversation

krfricke
Copy link
Contributor

@krfricke krfricke commented Mar 8, 2021

Why are these changes needed?

tune.with_parameters() currently only works with function trainables, but it is actually a nice way to pass data to class trainables, too.

This questions came up several times in the past, e.g. https://discuss.ray.io/t/improper-run-not-string-nor-trainable/1135/8

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Comment on lines +300 to +312
if inspect.isclass(trainable):
# Class trainable
keys = list(kwargs.keys())

class _Inner(trainable):
def setup(self, config):
setup_kwargs = {}
for k in keys:
setup_kwargs[k] = parameter_registry.get(prefix + k)
super(_Inner, self).setup(config, **setup_kwargs)

_Inner.__name__ = trainable_name
return _Inner
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change here

Comment on lines +313 to +345
else:
# Function trainable
use_checkpoint = detect_checkpoint_function(trainable, partial=True)
keys = list(kwargs.keys())

def inner(config, checkpoint_dir=None):
fn_kwargs = {}
if use_checkpoint:
default = checkpoint_dir
sig = inspect.signature(trainable)
if "checkpoint_dir" in sig.parameters:
default = sig.parameters["checkpoint_dir"].default \
or default
fn_kwargs["checkpoint_dir"] = default

for k in keys:
fn_kwargs[k] = parameter_registry.get(prefix + k)
trainable(config, **fn_kwargs)

inner.__name__ = trainable_name

# Use correct function signature if no `checkpoint_dir` parameter
# is set
if not use_checkpoint:

def _inner(config):
inner(config, checkpoint_dir=None)

_inner.__name__ = trainable_name

if hasattr(trainable, "__mixins__"):
_inner.__mixins__ = trainable.__mixins__
return _inner
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just copy/pasted from the old file

@krfricke krfricke merged commit 43e0984 into ray-project:master Mar 9, 2021
@krfricke krfricke deleted the class-api-parameters branch March 9, 2021 08:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants