Skip to content

Commit

Permalink
[RLlib] Remove native Keras Models. (ray-project#30986)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturNiederfahrenhorst committed Dec 16, 2022
1 parent b96baeb commit acf4e49
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 981 deletions.
14 changes: 14 additions & 0 deletions rllib/algorithms/algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,20 @@ def training(
error=True,
)
self.model.update(model)
if (
model.get("_use_default_native_models", DEPRECATED_VALUE)
!= DEPRECATED_VALUE
):
deprecation_warning(
old="AlgorithmConfig.training(_use_default_native_models=True)",
help="_use_default_native_models is not supported "
"anymore. To get rid of this error, set `experimental("
"_enable_rl_module_api` to True. Native models will "
"be better supported by the upcoming RLModule API.",
# Error out if user tries to enable this
error=model["_use_default_native_models"],
)

if optimizer is not NotProvided:
self.optimizer = merge_dicts(self.optimizer, optimizer)
if max_requests_in_flight_per_sampler_worker is not NotProvided:
Expand Down
3 changes: 0 additions & 3 deletions rllib/examples/attention_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,6 @@ def get_cli_args():
num_sgd_iter=10,
vf_loss_coeff=1e-5,
model={
# Attention net wrapping (for tf) can already use the native keras
# model versions. For torch, this will have no effect.
"_use_default_native_models": True,
"use_attention": not args.no_attention,
"max_seq_len": 10,
"attention_num_transformer_units": 1,
Expand Down
79 changes: 18 additions & 61 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,6 @@
# fmt: off
# __sphinx_doc_begin__
MODEL_DEFAULTS: ModelConfigDict = {
# Experimental flag.
# If True, try to use a native (tf.keras.Model or torch.Module) default
# model instead of our built-in ModelV2 defaults.
# If False (default), use "classic" ModelV2 default models.
# Note that this currently only works for:
# 1) framework != torch AND
# 2) fully connected and CNN default networks as well as
# auto-wrapped LSTM- and attention nets.
"_use_default_native_models": False,
# Experimental flag.
# If True, user specified no preprocessor to be created
# (via config._disable_preprocessor_api=True). If True, observations
Expand Down Expand Up @@ -186,6 +177,9 @@
# Deprecated keys:
# Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
"lstm_use_prev_action_reward": DEPRECATED_VALUE,
# Deprecated in anticipation of RLModules API
"_use_default_native_models": DEPRECATED_VALUE,

}
# __sphinx_doc_end__
# fmt: on
Expand Down Expand Up @@ -488,34 +482,20 @@ def get_model_v2(
if model_config.get("use_lstm") or model_config.get("use_attention"):
from ray.rllib.models.tf.attention_net import (
AttentionWrapper,
Keras_AttentionWrapper,
)
from ray.rllib.models.tf.recurrent_net import (
LSTMWrapper,
Keras_LSTMWrapper,
)

wrapped_cls = model_cls
# Wrapped (custom) model is itself a keras Model ->
# wrap with keras LSTM/GTrXL (attention) wrappers.
if issubclass(wrapped_cls, tf.keras.Model):
model_cls = (
Keras_LSTMWrapper
if model_config.get("use_lstm")
else Keras_AttentionWrapper
)
model_config["wrapped_cls"] = wrapped_cls
# Wrapped (custom) model is ModelV2 ->
# wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
else:
forward = wrapped_cls.forward
model_cls = ModelCatalog._wrap_if_needed(
wrapped_cls,
LSTMWrapper
if model_config.get("use_lstm")
else AttentionWrapper,
)
model_cls._wrapped_forward = forward
forward = wrapped_cls.forward
model_cls = ModelCatalog._wrap_if_needed(
wrapped_cls,
LSTMWrapper
if model_config.get("use_lstm")
else AttentionWrapper,
)
model_cls._wrapped_forward = forward

# Obsolete: Track and warn if vars were created but not
# registered. Only still do this, if users do register their
Expand Down Expand Up @@ -666,32 +646,20 @@ def track_var_creation(next_creator, **kw):

from ray.rllib.models.tf.attention_net import (
AttentionWrapper,
Keras_AttentionWrapper,
)
from ray.rllib.models.tf.recurrent_net import (
LSTMWrapper,
Keras_LSTMWrapper,
)

wrapped_cls = v2_class
if model_config.get("use_lstm"):
if issubclass(wrapped_cls, tf.keras.Model):
v2_class = Keras_LSTMWrapper
model_config["wrapped_cls"] = wrapped_cls
else:
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, LSTMWrapper
)
v2_class._wrapped_forward = wrapped_cls.forward
v2_class = ModelCatalog._wrap_if_needed(wrapped_cls, LSTMWrapper)
v2_class._wrapped_forward = wrapped_cls.forward
else:
if issubclass(wrapped_cls, tf.keras.Model):
v2_class = Keras_AttentionWrapper
model_config["wrapped_cls"] = wrapped_cls
else:
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, AttentionWrapper
)
v2_class._wrapped_forward = wrapped_cls.forward
v2_class = ModelCatalog._wrap_if_needed(
wrapped_cls, AttentionWrapper
)
v2_class._wrapped_forward = wrapped_cls.forward

# Wrap in the requested interface.
wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
Expand Down Expand Up @@ -893,17 +861,13 @@ def _get_v2_model_class(

VisionNet = None
ComplexNet = None
Keras_FCNet = None
Keras_VisionNet = None

if framework in ["tf2", "tf"]:
from ray.rllib.models.tf.fcnet import (
FullyConnectedNetwork as FCNet,
Keras_FullyConnectedNetwork as Keras_FCNet,
)
from ray.rllib.models.tf.visionnet import (
VisionNetwork as VisionNet,
Keras_VisionNetwork as Keras_VisionNet,
)
from ray.rllib.models.tf.complex_input_net import (
ComplexInputNetwork as ComplexNet,
Expand Down Expand Up @@ -932,8 +896,6 @@ def _get_v2_model_class(
if isinstance(input_space, Box) and len(input_space.shape) == 3:
if framework == "jax":
raise NotImplementedError("No non-FC default net for JAX yet!")
elif model_config.get("_use_default_native_models") and Keras_VisionNet:
return Keras_VisionNet
return VisionNet
# `input_space` is 1D Box -> FCNet.
elif (
Expand All @@ -947,12 +909,7 @@ def _get_v2_model_class(
)
)
):
# Keras native requested AND no auto-rnn-wrapping.
if model_config.get("_use_default_native_models") and Keras_FCNet:
return Keras_FCNet
# Classic ModelV2 FCNet.
else:
return FCNet
return FCNet
# Complex (Dict, Tuple, 2D Box (flatten), Discrete, MultiDiscrete).
else:
if framework == "jax":
Expand Down
Loading

0 comments on commit acf4e49

Please sign in to comment.