[RLlib] Clean up some deprecation messages (they shouldn't be there) and make others `error=True` (from `error=False`) (ray-project#38555)
sven1977 committed Aug 28, 2023
1 parent 09c2cba commit 0f62ccc
Showing 29 changed files with 68 additions and 366 deletions.
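Most of these changes revolve around RLlib's `@Deprecated` decorator (from `ray.rllib.utils.deprecation`): `error=False` soft-deprecates (log a warning, keep working), while `error=True` hard-deprecates (raise immediately, never run the body). A minimal self-contained sketch of those semantics — illustrative only, not RLlib's actual implementation:

```python
import functools
import logging

logger = logging.getLogger(__name__)


def deprecated(old: str, new: str = None, error: bool = False):
    """Illustrative stand-in for RLlib's `@Deprecated` decorator."""
    msg = f"`{old}` has been deprecated." + (f" Use `{new}` instead." if new else "")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if error:
                # Hard deprecation: fail loudly and never run the old code.
                raise ValueError(msg)
            # Soft deprecation: warn, then fall through to the old code.
            logger.warning(msg + " This will raise an error in the future!")
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

This is also why the bodies of hard-deprecated members below shrink to `pass`: with `error=True`, the wrapper raises before the body could ever execute, so a working fallback there is dead code.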
29 changes: 12 additions & 17 deletions rllib/algorithms/algorithm_config.py
@@ -2688,7 +2688,7 @@ def is_atari(self) -> bool:
# We do NOT attempt to auto-detect Atari env for other specified types like
# a callable, to avoid running heavy logics in validate().
# For these cases, users can explicitly set `environment(atari=True)`.
- if not type(self.env) == str:
+ if type(self.env) is not str:
return False
try:
env = gym.make(self.env)
@@ -3039,13 +3039,15 @@ def get_multi_agent_setup(

# If container given, construct a simple default callable returning True
# if the PolicyID is found in the list/set of IDs.
+ is_policy_to_train = self.policies_to_train
if self.policies_to_train is not None and not callable(self.policies_to_train):
pols = set(self.policies_to_train)

def is_policy_to_train(pid, batch=None):
return pid in pols

else:
is_policy_to_train = self.policies_to_train
+
return policies, is_policy_to_train

# TODO: Move this to those algorithms that really need this, which is currently
@@ -3363,7 +3365,6 @@ def get_marl_module_spec(
return marl_module_spec

def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig:
-
if not self._is_frozen:
raise ValueError(
"Cannot call `get_learner_group_config()` on an unfrozen "
@@ -3697,21 +3698,15 @@ def _resolve_tf_settings(self, _tf1, _tfv):
)

@property
- @Deprecated(error=False)
+ @Deprecated(
+     old="AlgorithmConfig.multiagent['[some key]']",
+     new="AlgorithmConfig.[some key]",
+     error=True,
+ )
def multiagent(self):
-     """Shim method to help pretend we are a dict with 'multiagent' key."""
-     return {
-         "policies": self.policies,
-         "policy_mapping_fn": self.policy_mapping_fn,
-         "policies_to_train": self.policies_to_train,
-         "policy_map_capacity": self.policy_map_capacity,
-         "policy_map_cache": self.policy_map_cache,
-         "count_steps_by": self.count_steps_by,
-         "observation_fn": self.observation_fn,
-     }
+     pass

@property
- @Deprecated(new="AlgorithmConfig.rollouts(num_rollout_workers=..)", error=False)
+ @Deprecated(new="AlgorithmConfig.rollouts(num_rollout_workers=..)", error=True)
def num_workers(self):
-     """For backward-compatibility purposes only."""
-     return self.num_rollout_workers
+     pass
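For context, the `get_multi_agent_setup` hunk above turns `policies_to_train` — `None`, a container of policy IDs, or a callable — into a single callable (or `None`). A rough standalone sketch of that normalization, with a simplified `PolicyID` alias assumed here:

```python
from typing import Callable, Collection, Optional, Union

PolicyID = str  # simplified stand-in for RLlib's PolicyID type


def make_is_policy_to_train(
    policies_to_train: Optional[Union[Collection[PolicyID], Callable]],
):
    # Default: pass the value through (None, or an already-callable filter).
    is_policy_to_train = policies_to_train
    # A container becomes a simple membership check over a set of IDs.
    if policies_to_train is not None and not callable(policies_to_train):
        pols = set(policies_to_train)

        def is_policy_to_train(pid, batch=None):
            return pid in pols

    return is_policy_to_train


# E.g.: make_is_policy_to_train(["pol1", "pol2"])("pol1") returns True.
```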
4 changes: 2 additions & 2 deletions rllib/algorithms/alpha_star/alpha_star.py
@@ -302,7 +302,7 @@ def default_resource_request(
# RolloutWorkers (no GPUs).
"CPU": cf.num_cpus_per_worker,
}
- for _ in range(cf.num_workers)
+ for _ in range(cf.num_rollout_workers)
]
+ [
{
@@ -453,7 +453,7 @@ def training_step(self) -> ResultDict:
sample_results = self._sampling_actor_manager.get_ready()
# Update sample counters.
for sample_result in sample_results.values():
- for (env_steps, agent_steps) in sample_result:
+ for env_steps, agent_steps in sample_result:
self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

2 changes: 0 additions & 2 deletions rllib/algorithms/callbacks.py
@@ -655,7 +655,6 @@ def on_learn_on_batch(

@override(DefaultCallbacks)
def on_train_result(self, *, algorithm=None, result: dict, **kwargs) -> None:
-
for callback in self._callback_list:
callback.on_train_result(algorithm=algorithm, result=result, **kwargs)

@@ -664,7 +663,6 @@ def on_train_result(self, *, algorithm=None, result: dict, **kwargs) -> None:

# This Callback is used by the RE3 exploration strategy.
# See rllib/examples/re3_exploration.py for details.
- @Deprecated(error=False)
class RE3UpdateCallbacks(DefaultCallbacks):
"""Update input callbacks to mutate batch with states entropy rewards."""

4 changes: 2 additions & 2 deletions rllib/env/wrappers/atari_wrappers.py
@@ -4,7 +4,7 @@
import numpy as np
from typing import Union

- from ray.rllib.utils.annotations import Deprecated, PublicAPI
+ from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.images import rgb2gray, resize


@@ -257,7 +257,7 @@ def observation(self, frame):
return frame[:, :, None]


- @Deprecated(error=False)
+ @PublicAPI
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Stack k last frames."""
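With the swap above, `FrameStack` is plain public API again instead of a soft-deprecated wrapper. Usage follows the standard Gym wrapper pattern; a hedged example, assuming a Gymnasium Atari environment (`ale-py`) is installed:

```python
import gymnasium as gym

from ray.rllib.env.wrappers.atari_wrappers import FrameStack

# Wrap an Atari env so each observation stacks the last 4 frames.
env = FrameStack(gym.make("ALE/Breakout-v5"), 4)
obs, info = env.reset()
```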
4 changes: 2 additions & 2 deletions rllib/execution/buffers/mixin_replay_buffer.py
@@ -179,6 +179,6 @@ def get_host(self) -> str:
"""
return platform.node()

- @Deprecated(new="MixInMultiAgentReplayBuffer.add()", error=False)
+ @Deprecated(new="MixInMultiAgentReplayBuffer.add()", error=True)
def add_batch(self, *args, **kwargs):
-     return self.add(*args, **kwargs)
+     pass
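`add_batch` above shows the end state of the rename shim pattern: the old method name survives only as a stub whose decorator now raises and points callers at `MixInMultiAgentReplayBuffer.add()`. A toy sketch of the same pattern (not RLlib code):

```python
class Buffer:
    """Toy buffer with a hard-deprecated alias for a renamed method."""

    def __init__(self):
        self._items = []

    def add(self, batch):
        # The supported API.
        self._items.append(batch)

    def add_batch(self, batch):
        # Old name: kept only so stale callers fail with a helpful message.
        raise ValueError("`add_batch` was deprecated. Use `Buffer.add()` instead.")


buf = Buffer()
buf.add({"obs": [1, 2, 3]})  # ok
# buf.add_batch({"obs": []})  # would raise ValueError
```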
3 changes: 1 addition & 2 deletions rllib/models/base_model.py
@@ -23,7 +23,7 @@
override,
ExperimentalAPI,
)
- from ray.rllib.utils.deprecation import deprecation_warning, Deprecated
+ from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once


@@ -207,7 +207,6 @@ def _update_outputs_and_next_state(
return outputs, next_state


- @Deprecated(error=False)
class Model(RecurrentModel):
"""A RecurrentModel made non-recurrent by ignoring
the input/output states.
6 changes: 1 addition & 5 deletions rllib/models/catalog.py
@@ -34,7 +34,6 @@
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,
-     Deprecated,
)
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_torch
@@ -202,7 +201,7 @@
# fmt: on


- @Deprecated(old="rllib.models.catalog.ModelCatalog", error=False)
+ @DeveloperAPI
class ModelCatalog:
"""Registry of models, preprocessors, and action distributions for envs.
@@ -655,7 +654,6 @@ def track_var_creation(next_creator, **kw):
raise ValueError("ModelV2 class could not be determined!")

if model_config.get("use_lstm") or model_config.get("use_attention"):
-
from ray.rllib.models.tf.attention_net import (
AttentionWrapper,
)
@@ -702,7 +700,6 @@ def track_var_creation(next_creator, **kw):
raise ValueError("ModelV2 class could not be determined!")

if model_config.get("use_lstm") or model_config.get("use_attention"):
-
from ray.rllib.models.torch.attention_net import AttentionWrapper
from ray.rllib.models.torch.recurrent_net import LSTMWrapper

@@ -848,7 +845,6 @@ class wrapper(model_interface, model_cls):
def _get_v2_model_class(
input_space: gym.Space, model_config: ModelConfigDict, framework: str = "tf"
) -> Type[ModelV2]:
-
VisionNet = None
ComplexNet = None

2 changes: 0 additions & 2 deletions rllib/models/tf/noop.py
@@ -2,12 +2,10 @@
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.deprecation import Deprecated

_, tf, _ = try_import_tf()


- @Deprecated(error=False)
class NoopModel(TFModelV2):
"""Trivial model that just returns the obs flattened.
4 changes: 2 additions & 2 deletions rllib/models/tf/primitives.py
@@ -8,7 +8,7 @@
# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done.


- @Deprecated(error=False)
+ @Deprecated(error=True)
class FCNet(tf.keras.Model):
"""A simple fully connected network.
@@ -49,7 +49,7 @@ def call(self, inputs, training=None, mask=None):
return self.network(inputs)


- @Deprecated(error=False)
+ @Deprecated(error=True)
class IdentityNetwork(tf.keras.Model):
"""A network that returns the input as the output."""

50 changes: 0 additions & 50 deletions rllib/models/tf/tf_action_dist.py
@@ -12,8 +12,6 @@
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorType, List, Union, Tuple, ModelConfigDict
- from ray.rllib.utils.deprecation import deprecation_warning
- from ray.util import log_once

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
@@ -25,11 +23,6 @@ class TFActionDistribution(ActionDistribution):

@override(ActionDistribution)
def __init__(self, inputs: List[TensorType], model: ModelV2):
-     if log_once("tf_action_dist_deprecation"):
-         deprecation_warning(
-             old="ray.rllib.models.tf.tf_action_dist.TFActionDistribution",
-             new="ray.rllib.models.tf.tf_distributions.TfDistribution",
-         )
super().__init__(inputs, model)
self.sample_op = self._build_sample_op()
self.sampled_action_logp_op = self.logp(self.sample_op)
@@ -60,11 +53,6 @@ class Categorical(TFActionDistribution):
def __init__(
self, inputs: List[TensorType], model: ModelV2 = None, temperature: float = 1.0
):
-     if log_once("tf_action_dist_categorical_deprecation"):
-         deprecation_warning(
-             old="ray.rllib.models.tf.tf_action_dist.Categorical",
-             new="ray.rllib.models.tf.tf_distributions.Categorical",
-         )
assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
# Allow softmax formula w/ temperature != 1.0:
# Divide inputs by temperature.
@@ -112,14 +100,6 @@ def required_model_output_shape(action_space, model_config):
@DeveloperAPI
def get_categorical_class_with_temperature(t: float):
"""Categorical distribution class that has customized default temperature."""
-     if log_once("tf_action_dist_categorical_w_temp_deprecation"):
-         deprecation_warning(
-             old=(
-                 "ray.rllib.models.tf.tf_action_dist.get_categorical_class_with"
-                 "_temperature"
-             ),
-             new="ray.rllib.models.tf.tf_distributions.Categorical",
-         )

class CategoricalWithTemperature(Categorical):
def __init__(self, inputs, model=None, temperature=t):
@@ -139,11 +119,6 @@ def __init__(
input_lens: Union[List[int], np.ndarray, Tuple[int, ...]],
action_space=None,
):
-     if log_once("tf_action_dist_multicat_deprecation"):
-         deprecation_warning(
-             old="ray.rllib.models.tf.tf_action_dist.MultiCategorical",
-             new="ray.rllib.models.tf.tf_distributions.TfMultiCategorical",
-         )
# skip TFActionDistribution init
ActionDistribution.__init__(self, inputs, model)
self.cats = [
@@ -247,10 +222,6 @@ def __init__(
action_space: Optional[gym.spaces.MultiDiscrete] = None,
all_slates=None,
):
-     if log_once("tf_action_dist_slate_multi_categorical_deprecation"):
-         deprecation_warning(
-             old="ray.rllib.models.tf.tf_action_dist.SlateMultiCategorical"
-         )
assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
# Allow softmax formula w/ temperature != 1.0:
# Divide inputs by temperature.
@@ -305,8 +276,6 @@ def __init__(
For high temperatures, the expected value approaches a uniform
distribution.
"""
-     if log_once("tf_action_dist_gumbel_softmax_deprecation"):
-         deprecation_warning(old="ray.rllib.models.tf.tf_action_dist.GumbelSoftmax")
assert temperature >= 0.0
self.dist = tfp.distributions.RelaxedOneHotCategorical(
temperature=temperature, logits=inputs
@@ -365,11 +334,6 @@ def __init__(
*,
action_space: Optional[gym.spaces.Space] = None
):
-     if log_once("tf_action_dist_diag_gaussian_deprecation"):
-         deprecation_warning(
-             old="ray.rllib.models.tf.tf_action_dist.DiagGaussian",
-             new="ray.rllib.models.tf.tf_distributions.TfDiagGaussian",
-         )
mean, log_std = tf.split(inputs, 2, axis=1)
self.mean = mean
self.log_std = log_std
@@ -450,10 +414,6 @@ def __init__(
high: The highest possible sampling value
(excluding this value).
"""
-     if log_once("tf_action_dist_squashed_gaussian_deprecation"):
-         deprecation_warning(
-             old="ray.rllib.models.tf.tf_action_dist.SquashedGaussian"
-         )
assert tfp is not None
mean, log_std = tf.split(inputs, 2, axis=-1)
# Clip `scale` values (coming from NN) to reasonable values.
@@ -548,8 +508,6 @@ def __init__(
low: float = 0.0,
high: float = 1.0,
):
-     if log_once("tf_action_dist_beta_deprecation"):
-         deprecation_warning(old="ray.rllib.models.tf.tf_action_dist.Beta")
# Stabilize input parameters (possibly coming from a linear layer).
inputs = tf.clip_by_value(inputs, log(SMALL_NUMBER), -log(SMALL_NUMBER))
inputs = tf.math.log(tf.math.exp(inputs) + 1.0) + 1.0
@@ -627,12 +585,6 @@ class MultiActionDistribution(TFActionDistribution):
def __init__(
self, inputs, model, *, child_distributions, input_lens, action_space, **kwargs
):
-     if log_once("tf_action_dist_multi_action_deprecation"):
-         deprecation_warning(
-             old="ray.rllib.models.tf.tf_action_dist.MultiActionDistribution",
-             new="ray.rllib.models.tf.tf_distributions.TfMultiDistribution",
-         )
-
ActionDistribution.__init__(self, inputs, model)

self.action_space_struct = get_base_struct_from_space(action_space)
@@ -741,8 +693,6 @@ def __init__(self, inputs: List[TensorType], model: ModelV2):
See issue #4440 for more details.
"""
-     if log_once("tf_action_dist_dirichlet_deprecation"):
-         deprecation_warning(old="ray.rllib.models.tf.tf_action_dist.Dirichlet")
self.epsilon = 1e-7
concentration = tf.exp(inputs) + self.epsilon
self.dist = tf1.distributions.Dirichlet(
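Every block deleted above followed the same soft-deprecation idiom: `ray.util.log_once` gates `deprecation_warning` so each message fires once per process rather than on every constructor call. Roughly, the removed pattern looked like this (a sketch; `MyOldDist`, `MyNewDist`, and the key string are placeholders):

```python
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once


class MyOldDist:  # placeholder for Categorical, DiagGaussian, etc.
    def __init__(self, inputs):
        # log_once() returns True only the first time it sees this key,
        # so repeated construction does not spam the logs.
        if log_once("my_old_dist_deprecation"):
            deprecation_warning(
                old="ray.rllib.models.tf.tf_action_dist.MyOldDist",
                new="ray.rllib.models.tf.tf_distributions.MyNewDist",
            )
        self.inputs = inputs
```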
4 changes: 0 additions & 4 deletions rllib/models/torch/model.py
@@ -8,10 +8,8 @@
)
from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig
from ray.rllib.models.base_model import RecurrentModel, Model, ModelIO
- from ray.rllib.utils.deprecation import Deprecated


- @Deprecated(error=False)
class TorchModelIO(ModelIO):
"""Save/Load mixin for torch models
@@ -42,7 +40,6 @@ def load(self, path: str) -> RecurrentModel:
self.load_state_dict(torch.load(path))


- @Deprecated(error=False)
class TorchRecurrentModel(RecurrentModel, nn.Module, TorchModelIO):
"""The base class for recurrent pytorch models.
@@ -154,7 +151,6 @@ def _initial_state(self) -> TensorDict:
)


- @Deprecated(error=False)
class TorchModel(Model, nn.Module, TorchModelIO):
"""The base class for non-recurrent pytorch models.
2 changes: 0 additions & 2 deletions rllib/models/torch/noop.py
@@ -1,10 +1,8 @@
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
- from ray.rllib.utils.deprecation import Deprecated


- @Deprecated(error=False)
class TorchNoopModel(TorchModelV2):
"""Trivial model that just returns the obs flattened.
