[CI] Fix Brax 0.9.0 #1011

Merged: 7 commits, Apr 6, 2023
Changes from 1 commit
Commit "tmp" (53fdbfe581cf218244a5b83d9f24e703fff9070d), vmoens committed Apr 3, 2023
41 changes: 22 additions & 19 deletions torchrl/envs/libs/brax.py
@@ -342,44 +342,47 @@ def __repr__(self) -> str:
 
 class _BraxEnvStep(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, env: BraxWrapper, state, action, *qp_values):
+    def forward(ctx, env: BraxWrapper, state_td, action_tensor, *qp_values):
 
         # convert tensors to ndarrays
-        state = _tensordict_to_object(state, env._state_example)
-        action = _tensor_to_ndarray(action)
+        print(state_td)
+        state_obj = _tensordict_to_object(state_td, env._state_example)
+        action_nd = _tensor_to_ndarray(action_tensor)
 
         # flatten batch size
-        state = _tree_flatten(state, env.batch_size)
-        action = _tree_flatten(action, env.batch_size)
+        state = _tree_flatten(state_obj, env.batch_size)
+        action = _tree_flatten(action_nd, env.batch_size)
 
         # call vjp with jit and vmap
         next_state, vjp_fn = jax.vjp(env._vmap_jit_env_step, state, action)
 
         # reshape batch size
-        next_state = _tree_reshape(next_state, env.batch_size)
+        next_state_reshape = _tree_reshape(next_state, env.batch_size)
 
         # convert ndarrays to tensors
-        next_state = _object_to_tensordict(
-            next_state, device=env.device, batch_size=env.batch_size
+        next_state_tensor = _object_to_tensordict(
+            next_state_reshape, device=env.device, batch_size=env.batch_size
         )
 
         # save context
         ctx.vjp_fn = vjp_fn
-        ctx.next_state = next_state
+        ctx.next_state = next_state_tensor
         ctx.env = env
 
         return (
-            next_state,  # no gradient
-            next_state["obs"],
-            next_state["reward"],
-            *next_state["pipeline_state"].values(),
+            next_state_tensor,  # no gradient
+            next_state_tensor["obs"],
+            next_state_tensor["reward"],
+            *next_state_tensor["pipeline_state"].values(),
         )
 
     @staticmethod
     def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
 
         # build gradient tensordict with zeros in fields with no grad
-        grad_next_state = TensorDict(
+        if grad_next_reward is None:
+            grad_next_reward = torch.zeros((*ctx.env.batch_size, 1), device=ctx.env.device)
+        grad_next_state_td = TensorDict(
             source={
                 "pipeline_state": dict(
                     zip(ctx.next_state["pipeline_state"].keys(), grad_next_qp_values)
@@ -396,17 +399,17 @@ def backward(ctx, _, grad_next_obs, grad_next_reward, *grad_next_qp_values):
             },
             device=ctx.env.device,
             batch_size=ctx.env.batch_size,
-            _run_checks=False,
+            # _run_checks=False,
         )
-
+        print(grad_next_state_td)
         # convert tensors to ndarrays
-        grad_next_state = _tensordict_to_object(grad_next_state, ctx.env._state_example)
+        grad_next_state_obj = _tensordict_to_object(grad_next_state_td, ctx.env._state_example)
 
         # flatten batch size
-        grad_next_state = _tree_flatten(grad_next_state, ctx.env.batch_size)
+        grad_next_state_flat = _tree_flatten(grad_next_state_obj, ctx.env.batch_size)
 
         # call vjp to get gradients
-        grad_state, grad_action = ctx.vjp_fn(grad_next_state)
+        grad_state, grad_action = ctx.vjp_fn(grad_next_state_flat)
 
         # reshape batch size
         grad_state = _tree_reshape(grad_state, ctx.env.batch_size)
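For context on the pattern the diff preserves: `_BraxEnvStep` bridges PyTorch autograd and JAX by calling `jax.vjp` in `forward()`, stashing the returned vjp closure on `ctx`, and replaying it in `backward()`. The new `grad_next_reward is None` guard handles outputs for which PyTorch passes no gradient, substituting zeros so the gradient TensorDict can still be built. Below is a minimal, self-contained sketch of that round trip. It is not TorchRL code: `JaxSquare`, `_to_jax`, and `_to_torch` are hypothetical stand-ins for the wrapper's `_tensor_to_ndarray`/`_object_to_tensordict` helpers, and the NumPy round trip is simply the easiest CPU-only way to move data between the two frameworks.

```python
# Minimal sketch of routing PyTorch gradients through jax.vjp,
# the same forward/backward structure as _BraxEnvStep above.
import jax
import numpy as np
import torch


def _to_jax(t: torch.Tensor):
    # Hypothetical helper: torch -> JAX via NumPy (CPU only).
    return jax.numpy.asarray(t.detach().cpu().numpy())


def _to_torch(a) -> torch.Tensor:
    # Hypothetical helper: JAX -> torch via NumPy.
    return torch.as_tensor(np.asarray(a))


class JaxSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # jax.vjp returns the primal output plus a closure that computes
        # vector-Jacobian products; keep the closure for backward().
        y, vjp_fn = jax.vjp(lambda a: a ** 2, _to_jax(x))
        ctx.vjp_fn = vjp_fn
        return _to_torch(y)

    @staticmethod
    def backward(ctx, grad_y: torch.Tensor) -> torch.Tensor:
        # Replay the saved vjp with the incoming torch gradient.
        (grad_x,) = ctx.vjp_fn(_to_jax(grad_y))
        return _to_torch(grad_x)


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
JaxSquare.apply(x).sum().backward()
print(x.grad)  # tensor([2., 4., 6.])
```

The same structure scales to the Brax wrapper, where the objects shuttled across the boundary are whole pytrees/TensorDicts (state, action, pipeline state) rather than single tensors.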