Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CrossQ #28

Merged
merged 41 commits into from
Apr 3, 2024
Merged

Add CrossQ #28

merged 41 commits into from
Apr 3, 2024

Conversation

araffin
Copy link
Owner

@araffin araffin commented Feb 8, 2024

Description

Implementing https://openreview.net/forum?id=PczQtTsTIX
on top of #21

Discussion in #36

perf report:
https://wandb.ai/openrlbenchmark/sbx/reports/CrossQ-SBX-Perf-Report--Vmlldzo3MzQxOTAw

Motivation and Context

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

Note: You can run most of the checks using make commit-checks.

Note: we are using a maximum length of 127 characters per line

@araffin araffin mentioned this pull request Mar 29, 2024
14 tasks
@araffin araffin changed the title Feat/crossq Add CrossQ Mar 29, 2024
@araffin araffin marked this pull request as ready for review March 29, 2024 15:35
@araffin
Copy link
Owner Author

araffin commented Mar 29, 2024

@danielpalen after reading the paper, I'm wondering if you have the learning curves for relu6?
or is it similar to SAC - TN + tanh?

sbx/common/jax_layers.py Outdated Show resolved Hide resolved
Comment on lines 155 to 161
if optimizer_kwargs is None:
# Note: the default value for b1 is 0.9 in Adam.
# b1=0.5 is used in the original CrossQ implementation and is found
# but shows only little overall improvement.
optimizer_kwargs = {}
if optimizer_class in [optax.adam, optax.adamw]:
optimizer_kwargs["b1"] = 0.5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the default value of b1 is set to 0.5 only if no other arguments (or even an empty dict) are passed to the optimizer. It would be cleaner to set the default value to 0.5 regardless of the other optimizer parameters.

Suggested change
if optimizer_kwargs is None:
# Note: the default value for b1 is 0.9 in Adam.
# b1=0.5 is used in the original CrossQ implementation and is found
# but shows only little overall improvement.
optimizer_kwargs = {}
if optimizer_class in [optax.adam, optax.adamw]:
optimizer_kwargs["b1"] = 0.5
if optimizer_kwargs is None:
optimizer_kwargs = {}
if optimizer_class in [optax.adam, optax.adamw] and "b1" not in optimizer_kwargs:
# Note: the default value for b1 is 0.9 in Adam.
# b1=0.5 is used in the original CrossQ implementation but shows only little overall improvement.
optimizer_kwargs["b1"] = 0.5

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep it as is to be consistent with what is done in the rest of SB3.

Comment on lines +9 to +13
PRNGKey = Any
Array = Any
Shape = Tuple[int, ...]
Dtype = Any # this could be a real type?
Axes = Union[int, Sequence[int]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flax v0.8.1 introduced flax.typing, which we could use here for more descriptive type hints, similar to the current version of flax.linen.normalization. However, we should probably wait a bit here since this would require a relatively recent flax version.

@araffin
Copy link
Owner Author

araffin commented Apr 1, 2024

@danielpalen Some early results of DroQ + CrossQ (only 2 random seeds on 3 pybullet envs, need more runs): https://wandb.ai/openrlbenchmark/sbx/reports/DroQ-CrossQ-SBX-Perf-Report--Vmlldzo3MzcxNDUy

I also quickly checked the warmup steps and could see an impact on AntBulletEnv-v0 only when it was too small.

@danielpalen
Copy link
Contributor

@danielpalen after reading the paper, I'm wondering if you have the learning curves for relu6? or is it similar to SAC - TN + tanh?

I quickly checked and it looked pretty similar.

@danielpalen
Copy link
Contributor

@danielpalen Some early results of DroQ + CrossQ (only 2 random seeds on 3 pybullet envs, need more runs): https://wandb.ai/openrlbenchmark/sbx/reports/DroQ-CrossQ-SBX-Perf-Report--Vmlldzo3MzcxNDUy

I have also played around with REDQ/DroQ + CrossQ on MuJoCo but from what I remember, the results were not really consistent, sometimes better, sometimes worse.

I also quickly checked the warmup steps and could see an impact on AntBulletEnv-v0 only when it was too small.

That makes sense. If you go to low you don't have a good estimate for the running statistics yet, so you need to give them enough time to warm up. But the exact time will be environment specific I guess

@araffin
Copy link
Owner Author

araffin commented Apr 2, 2024

I have also played around with REDQ/DroQ + CrossQ on MuJoCo but from what I remember, the results were not really consistent, sometimes better, sometimes worse.

So far, it always improved the results in my case (need more seeds to confirm, I have tried on different pybullet and mujoco envs), or at least to quickly get "good enough" solution (using up to 2x less samples than CrossQ).

One last point in case you missed it (because from #36 (comment)):
@danielpalen would you be interested in providing a PyTorch implementation for SB3 contrib? (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

@danielpalen
Copy link
Contributor

One last point in case you missed it (because from #36 (comment)): @danielpalen would you be interested in providing a PyTorch implementation for SB3 contrib? (https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

Yes, absolutely :) I put it on my todo. But I think I won't be able to get on that right away at the moment.

@araffin araffin merged commit c8db73f into master Apr 3, 2024
@araffin araffin deleted the feat/crossq branch April 3, 2024 10:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants