[Feature] Restructure torchrl/objectives (#580)
* restructure torchrl/objectives

* restructure torchrl/objectives

* lint

* fix formatting of file tree in demo notebook

Co-authored-by: Grigory Sizov <[email protected]>
sgrigory committed Oct 18, 2022
1 parent 685c08e commit 002e1e9
Showing 28 changed files with 2,101 additions and 2,120 deletions.
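For orientation (not part of the commit content itself), a minimal sketch of how imports change under the new layout: the `costs` sub-package is flattened into `torchrl.objectives` and `returns` becomes `value`. Every path below appears in the diffs on this page; the sketch assumes a TorchRL build that includes this commit.

```python
# Old layout                            ->  New layout (this commit)
# torchrl.objectives.costs.<module>     ->  torchrl.objectives.<module>
# torchrl.objectives.returns.<module>   ->  torchrl.objectives.value.<module>
from torchrl.objectives import DQNLoss, SACLoss, SoftUpdate, HardUpdate  # re-exported at the top level
from torchrl.objectives.common import LossModule                         # base class for loss modules
from torchrl.objectives.utils import distance_loss, next_state_value
from torchrl.objectives.value.advantages import GAE
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
```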
4 changes: 2 additions & 2 deletions README.md
@@ -317,15 +317,15 @@ algorithms. For instance, here's how to code a rollout in TorchRL:

### Loss modules
```python
from torchrl.objectives.costs import DQNLoss
from torchrl.objectives import DQNLoss
loss_module = DQNLoss(value_network=value_network, gamma=0.99)
tensordict = replay_buffer.sample(batch_size)
loss = loss_module(tensordict)
```

### Advantage computation
```python
from torchrl.objectives.returns.functional import vec_td_lambda_return_estimate
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done)
```
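For context (not part of this diff), the lambda-return that `vec_td_lambda_return_estimate` computes in vectorized form can be written recursively as below, where `r_t` is the reward, `gamma` the discount, `lmbda` the mixing parameter and `d_t` the done flag; TorchRL's exact termination-masking convention is not spelled out here.

```latex
G_t^{\lambda} = r_t + \gamma (1 - d_t)\Big[(1 - \lambda)\, V(s_{t+1}) + \lambda\, G_{t+1}^{\lambda}\Big]
```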

2 changes: 1 addition & 1 deletion test/smoke_test.py
@@ -8,4 +8,4 @@ def test_imports():
from torchrl.envs import Transform, TransformedEnv # noqa: F401
from torchrl.envs.gym_like import GymLikeEnv # noqa: F401
from torchrl.modules import TensorDictModule # noqa: F401
from torchrl.objectives.costs.common import LossModule # noqa: F401
from torchrl.objectives.common import LossModule # noqa: F401
16 changes: 8 additions & 8 deletions test/test_cost.py
@@ -43,22 +43,22 @@
PPOLoss,
SACLoss,
)
from torchrl.objectives.costs.common import LossModule
from torchrl.objectives.costs.deprecated import (
from torchrl.objectives.common import LossModule
from torchrl.objectives.deprecated import (
DoubleREDQLoss_deprecated,
REDQLoss_deprecated,
)
from torchrl.objectives.costs.redq import REDQLoss
from torchrl.objectives.costs.reinforce import ReinforceLoss
from torchrl.objectives.costs.utils import HardUpdate, hold_out_net, SoftUpdate
from torchrl.objectives.returns.advantages import GAE, TDEstimate, TDLambdaEstimate
from torchrl.objectives.returns.functional import (
from torchrl.objectives.redq import REDQLoss
from torchrl.objectives.reinforce import ReinforceLoss
from torchrl.objectives.utils import HardUpdate, hold_out_net, SoftUpdate
from torchrl.objectives.value.advantages import GAE, TDEstimate, TDLambdaEstimate
from torchrl.objectives.value.functional import (
generalized_advantage_estimate,
td_lambda_advantage_estimate,
vec_generalized_advantage_estimate,
vec_td_lambda_advantage_estimate,
)
from torchrl.objectives.returns.utils import _custom_conv1d, _make_gammas_tensor
from torchrl.objectives.value.utils import _custom_conv1d, _make_gammas_tensor


@pytest.fixture
16 changes: 14 additions & 2 deletions torchrl/objectives/__init__.py
@@ -3,5 +3,17 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .costs import *
from .returns import *
from .common import LossModule
from .ddpg import DDPGLoss
from .dqn import DQNLoss, DistributionalDQNLoss
from .ppo import PPOLoss, ClipPPOLoss, KLPENPPOLoss
from .redq import REDQLoss
from .sac import SACLoss
from .utils import (
SoftUpdate,
HardUpdate,
distance_loss,
hold_out_params,
next_state_value,
)
from .value import bellman_max, c_val, dv_val, vtrace, GAE, TDLambdaEstimate, TDEstimate
torchrl/objectives/costs/common.py → torchrl/objectives/common.py
@@ -5,8 +5,6 @@

from __future__ import annotations

__all__ = ["LossModule"]

from typing import Iterator, Optional, Tuple, List, Union

import functorch
12 changes: 0 additions & 12 deletions torchrl/objectives/costs/__init__.py

This file was deleted.

torchrl/objectives/costs/ddpg.py → torchrl/objectives/ddpg.py
@@ -12,12 +12,12 @@
from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict
from torchrl.modules import TensorDictModule
from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
from torchrl.objectives.costs.utils import (
from torchrl.objectives.utils import (
distance_loss,
hold_out_params,
next_state_value,
)
from ...envs.utils import set_exploration_mode
from ..envs.utils import set_exploration_mode
from .common import LossModule


torchrl/objectives/costs/deprecated.py → torchrl/objectives/deprecated.py
@@ -15,7 +15,7 @@
next_state_value as get_next_state_value,
distance_loss,
)
from torchrl.objectives.costs.common import LossModule
from torchrl.objectives.common import LossModule


class REDQLoss_deprecated(LossModule):
torchrl/objectives/costs/dqn.py → torchrl/objectives/dqn.py
@@ -11,15 +11,10 @@
DistributionalQValueActor,
QValueActor,
)
from ...data.tensordict.tensordict import TensorDictBase
from ..data.tensordict.tensordict import TensorDictBase
from .common import LossModule
from .utils import distance_loss, next_state_value

__all__ = [
"DQNLoss",
"DistributionalDQNLoss",
]


class DQNLoss(LossModule):
"""
File renamed without changes.
7 changes: 2 additions & 5 deletions torchrl/objectives/costs/ppo.py → torchrl/objectives/ppo.py
@@ -11,11 +11,8 @@

from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict
from torchrl.modules import TensorDictModule
from ...modules.tensordict_module import ProbabilisticTensorDictModule

__all__ = ["PPOLoss", "ClipPPOLoss", "KLPENPPOLoss"]

from torchrl.objectives.costs.utils import distance_loss
from torchrl.objectives.utils import distance_loss
from ..modules.tensordict_module import ProbabilisticTensorDictModule
from .common import LossModule


torchrl/objectives/costs/redq.py → torchrl/objectives/redq.py
@@ -14,15 +14,13 @@
from torchrl.data.tensordict.tensordict import TensorDictBase, TensorDict
from torchrl.envs.utils import set_exploration_mode, step_mdp
from torchrl.modules import TensorDictModule
from torchrl.objectives.costs.common import LossModule
from torchrl.objectives.costs.utils import (
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
distance_loss,
hold_out_params,
next_state_value as get_next_state_value,
)

__all__ = ["REDQLoss"]


class REDQLoss(LossModule):
"""
torchrl/objectives/costs/reinforce.py → torchrl/objectives/reinforce.py
@@ -6,7 +6,7 @@
from torchrl.envs.utils import step_mdp
from torchrl.modules import TensorDictModule, ProbabilisticTensorDictModule
from torchrl.objectives import distance_loss
from torchrl.objectives.costs.common import LossModule
from torchrl.objectives.common import LossModule


class ReinforceLoss(LossModule):
7 changes: 2 additions & 5 deletions torchrl/objectives/costs/sac.py → torchrl/objectives/sac.py
@@ -17,13 +17,10 @@
from torchrl.modules.tensordict_module.actors import (
ActorCriticWrapper,
)
from torchrl.objectives.costs.utils import distance_loss, next_state_value
from torchrl.objectives.utils import distance_loss, next_state_value
from ..envs.utils import set_exploration_mode
from .common import LossModule

__all__ = ["SACLoss"]

from ...envs.utils import set_exploration_mode


class SACLoss(LossModule):
"""
torchrl/objectives/costs/utils.py → torchrl/objectives/utils.py
@@ -15,8 +15,6 @@
from torchrl.envs.utils import step_mdp
from torchrl.modules import TensorDictModule

__all__ = ["SoftUpdate", "HardUpdate", "distance_loss", "hold_out_params"]


class _context_manager:
def __init__(self, value=True):
torchrl/objectives/returns/__init__.py → torchrl/objectives/value/__init__.py
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .pg import *
from .returns import *
from .vtrace import *
from .advantages import *
from .advantages import GAE, TDLambdaEstimate, TDEstimate
from .returns import bellman_max
from .vtrace import c_val, dv_val, vtrace
torchrl/objectives/returns/advantages.py → torchrl/objectives/value/advantages.py
@@ -20,17 +20,14 @@
from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs.utils import step_mdp
from torchrl.modules import TensorDictModule
from torchrl.objectives.returns.functional import (
from torchrl.objectives.value.functional import (
vec_generalized_advantage_estimate,
td_lambda_advantage_estimate,
vec_td_lambda_advantage_estimate,
)
from ..utils import hold_out_net
from .functional import td_advantage_estimate

__all__ = ["GAE", "TDLambdaEstimate", "TDEstimate"]

from ..costs.utils import hold_out_net


class TDEstimate(nn.Module):
"""Temporal Difference estimate of advantage function.
torchrl/objectives/returns/functional.py → torchrl/objectives/value/functional.py
@@ -17,7 +17,7 @@
"td_advantage_estimate",
]

from torchrl.objectives.returns.utils import _custom_conv1d, _make_gammas_tensor
from torchrl.objectives.value.utils import _custom_conv1d, _make_gammas_tensor


def generalized_advantage_estimate(
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
10 changes: 5 additions & 5 deletions torchrl/trainers/helpers/losses.py
@@ -28,13 +28,13 @@
SACLoss,
SoftUpdate,
)
from torchrl.objectives.costs.common import LossModule
from torchrl.objectives.costs.deprecated import REDQLoss_deprecated
from torchrl.objectives.common import LossModule
from torchrl.objectives.deprecated import REDQLoss_deprecated

# from torchrl.objectives.costs.redq import REDQLoss
# from torchrl.objectives.redq import REDQLoss

from torchrl.objectives.costs.utils import TargetNetUpdater
from torchrl.objectives.returns.advantages import GAE
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.objectives.value.advantages import GAE


def make_target_updater(
8 changes: 4 additions & 4 deletions torchrl/trainers/helpers/trainers.py
@@ -15,8 +15,8 @@
from torchrl.data import ReplayBuffer
from torchrl.envs.common import EnvBase
from torchrl.modules import TensorDictModule, TensorDictModuleWrapper, reset_noise
from torchrl.objectives.costs.common import LossModule
from torchrl.objectives.costs.utils import TargetNetUpdater
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.trainers.loggers import Logger
from torchrl.trainers.trainers import (
Trainer,
@@ -116,8 +116,8 @@ def make_trainer(
>>> from torchrl.data import TensorDictReplayBuffer
>>> from torchrl.envs.libs.gym import GymEnv
>>> from torchrl.modules import TensorDictModuleWrapper, TensorDictModule, ValueOperator, EGreedyWrapper
>>> from torchrl.objectives.costs.common import LossModule
>>> from torchrl.objectives.costs.utils import TargetNetUpdater
>>> from torchrl.objectives.common import LossModule
>>> from torchrl.objectives.utils import TargetNetUpdater
>>> from torchrl.objectives import DDPGLoss
>>> env_maker = EnvCreator(lambda: GymEnv("Pendulum-v0"))
>>> env_proof = env_maker()
2 changes: 1 addition & 1 deletion torchrl/trainers/trainers.py
@@ -35,7 +35,7 @@
from torchrl.envs.common import EnvBase
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import TensorDictModule
from torchrl.objectives.costs.common import LossModule
from torchrl.objectives.common import LossModule
from torchrl.trainers.loggers import Logger

REPLAY_BUFFER_CLASS = {
6 changes: 3 additions & 3 deletions tutorials/coding_ddpg.ipynb
@@ -100,7 +100,7 @@
" ValueOperator,\n",
")\n",
"from torchrl.modules.distributions.continuous import TanhDelta\n",
"from torchrl.objectives.costs.utils import hold_out_net\n",
"from torchrl.objectives.utils import hold_out_net\n",
"from torchrl.trainers import Recorder\n",
"from torchrl.trainers.helpers.envs import (\n",
" get_stats_random_rollout,\n",
@@ -895,7 +895,7 @@
"- The value network is designed using the `ValueOperator` TensorDictModule subclass. This class will write a `\"state_action_value\"` if one of its `in_keys` is named \"action\", otherwise it will assume that only the state-value is returned and the output key will simply be `\"state_value\"`. In the case of DDPG, the value if of the state-action pair, hence the first name is used.\n",
"- The `step_mdp` helper function returns a new TensorDict that essentially does the `obs = next_obs`. In other words, it will return a new tensordict where the values that are related to the next state (next observations of various type) are selected and written as if they were current. This makes it possible to pass this new tensordict to the policy or value network (which expects an `\"observation_vector\"` key, not `\"next_observation_vector\"`.\n",
"- When using prioritized replay buffer, a priority key is added to the sampled tensordict (named `\"td_error\"` by default). Then, this TensorDict will be fed back to the replay buffer using the `update_priority` method. Under the hood, this method will read the index present in the TensorDict as well as the priority value, and update its list of priorities at these indices.\n",
"- TorchRL provides optimized versions of the loss functions (such as this one) where one only needs to pass a sampled tensordict and obtains a dictionary of losses and metadata in return (see `torchrl.objectives.costs` for more context). Here we write the full loss function in the optimization loop for transparency. Similarly, the target network updates are written explicitely but TorchRL provides a couple of dedicated classes for this (see `torchrl.objectives.SoftUpdate` and `torchrl.objectives.HardUpdate`).\n",
"- TorchRL provides optimized versions of the loss functions (such as this one) where one only needs to pass a sampled tensordict and obtains a dictionary of losses and metadata in return (see `torchrl.objectives` for more context). Here we write the full loss function in the optimization loop for transparency. Similarly, the target network updates are written explicitely but TorchRL provides a couple of dedicated classes for this (see `torchrl.objectives.SoftUpdate` and `torchrl.objectives.HardUpdate`).\n",
"- After each collection of data, we call `collector.update_policy_weights_()`, which will update the policy network weights on the data collector. If the code is executed on cpu or with a single cuda device, this part can be ommited. If the collector is executed on another device, then its weights must be synced with those on the main, training process and this method should be incorporated in the training loop (ideally early in the loop in async settings, and at the end of it in sync settings)."
]
},
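Aside (not part of this commit): the cell above points to `torchrl.objectives.SoftUpdate` / `torchrl.objectives.HardUpdate` for target-network updates. Below is a minimal sketch using the post-restructure import path; the `DQNLoss` call mirrors the README snippet in this commit, while the `eps` argument and the `init_()` / `step()` calls are assumptions about the updater API rather than something shown in this diff.

```python
# Hedged sketch: soft (Polyak) target-network updates with the new import layout.
# Everything beyond the import paths and the DQNLoss signature is an assumption.
from torchrl.objectives import DQNLoss, SoftUpdate

loss_module = DQNLoss(value_network=value_network, gamma=0.99)  # value_network assumed to exist
target_updater = SoftUpdate(loss_module, eps=0.995)             # eps: Polyak coefficient (assumed name)
target_updater.init_()  # assumed: one-time copy of the online weights into the target networks

# ... sample a batch, compute the loss, take an optimizer step ...
target_updater.step()   # assumed: nudge the target parameters toward the online parameters
```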
@@ -1312,7 +1312,7 @@
},
"outputs": [],
"source": [
"from torchrl.objectives.returns.functional import vec_td_lambda_advantage_estimate\n",
"from torchrl.objectives.value.functional import vec_td_lambda_advantage_estimate\n",
"lmbda = 0.95"
]
},
2 changes: 1 addition & 1 deletion tutorials/coding_dqn.ipynb
@@ -896,7 +896,7 @@
"outputs": [],
"source": [
"from torchrl.data.tensordict.tensordict import pad\n",
"from torchrl.objectives.returns.functional import vec_td_lambda_advantage_estimate"
"from torchrl.objectives.value.functional import vec_td_lambda_advantage_estimate"
]
},
{

