Weidong Huang
Jiaming Ji
Chunhe Xia
Borong Zhang
Yaodong Yang
Beihang University, Peking University
The deployment of Reinforcement Learning (RL) in real-world applications is constrained by its failure to satisfy safety criteria. Existing Safe Reinforcement Learning (SafeRL) methods, which rely on cost functions to enforce safety, often fail to achieve zero-cost performance in complex scenarios, especially vision-only tasks. These limitations are primarily due to model inaccuracies and inadequate sample efficiency. The integration of world models has proven effective in mitigating these shortcomings. In this work, we introduce SafeDreamer, a novel algorithm that incorporates Lagrangian-based methods into world model planning processes within the Dreamer framework. Our method achieves nearly zero-cost performance on various tasks with both low-dimensional and vision-only inputs in the Safety-Gymnasium benchmark, showcasing its efficacy in balancing performance and safety in RL tasks.
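As a generic illustration of the Lagrangian idea mentioned above (not SafeDreamer's exact update rule), a cost constraint `J_c <= d` can be folded into the objective as `J_r - lambda * (J_c - d)`, with the multiplier `lambda >= 0` raised while the constraint is violated and lowered otherwise. The class name, the `cost_limit` and `lr` parameters, and the plain dual-ascent step are assumptions for this sketch; the repository's own penalty update (configured via options such as `--pid.init_penalty`) may differ.

```python
from dataclasses import dataclass


@dataclass
class LagrangeMultiplier:
    """Hypothetical helper: plain dual ascent on a single cost constraint."""

    cost_limit: float   # d: allowed expected episode cost (assumed, set per task)
    lr: float = 0.01    # eta: step size for the multiplier
    value: float = 0.0  # lambda, always kept non-negative

    def update(self, episode_cost: float) -> float:
        # lambda <- max(0, lambda + eta * (J_c - d)):
        # grows while the constraint is violated, decays toward 0 otherwise.
        self.value = max(0.0, self.value + self.lr * (episode_cost - self.cost_limit))
        return self.value

    def penalized_objective(self, reward_return: float, cost_return: float) -> float:
        # Lagrangian that the policy would maximize: J_r - lambda * (J_c - d).
        return reward_return - self.value * (cost_return - self.cost_limit)


if __name__ == "__main__":
    lam = LagrangeMultiplier(cost_limit=2.0, lr=0.1)
    for cost in (7.0, 5.0, 1.0, 0.0):
        print(f"episode cost {cost}: lambda = {lam.update(cost):.2f}")
```

Clipping at zero keeps the penalty from turning into a bonus for incurring cost once the agent is already safe.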
We have also open-sourced more than 80 model checkpoints for 20 tasks. Our codebase supports both vector and vision observations. We hope this repository will become a valuable community resource for future research on model-based safe reinforcement learning.
- [2024-04] We have open-sourced the code and 80+ model checkpoints.
- [2024-01] SafeDreamer has been accepted at ICLR 2024.
If you find our work helpful, please cite:
@inproceedings{
safedreamer,
title={SafeDreamer: Safe Reinforcement Learning with World Models},
author={Weidong Huang and Jiaming Ji and Borong Zhang and Chunhe Xia and Yaodong Yang},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=tsE5HLYtYg}
}
git clone https://github.com/PKU-Alignment/SafeDreamer.git
cd SafeDreamer
Because JAX depends strongly on CUDA and cuDNN, the installed versions must be compatible for the code to run successfully. Before installing JAX, carefully check the CUDA and cuDNN versions installed on your machine. Here are some ways to check them:
- Checking the CUDA version:
  - Run `nvcc --version` in a terminal to see the installed CUDA version.
- Checking the cuDNN version:
  - Inspect the version header in the cuDNN installation directory: `cat /usr/local/cuda/include/cudnn_version.h | grep CUDNN_MAJOR -A 2`.
  - Alternatively, use PyTorch to print the cuDNN and CUDA versions it reports: `python3 -c 'import torch; print(f"cuDNN version: {torch.backends.cudnn.version()}"); print(torch.version.cuda)'`.
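The checks above can also be scripted. The sketch below parses the `#define CUDNN_MAJOR/MINOR/PATCHLEVEL` lines of `cudnn_version.h` and the `release X.Y` string printed by `nvcc --version`; the header path is the one used in the command above and may differ on your machine, and the function names are hypothetical.

```python
import re
import subprocess
from pathlib import Path

# Path taken from the cat command above; adjust it to your cuDNN installation.
CUDNN_HEADER = Path("/usr/local/cuda/include/cudnn_version.h")


def parse_cudnn_version(header_text: str) -> str:
    """Return 'MAJOR.MINOR.PATCHLEVEL' from the #define lines of cudnn_version.h."""
    parts = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(rf"#define\s+{name}\s+(\d+)", header_text)
        if match is None:
            raise ValueError(f"{name} not found in header")
        parts.append(match.group(1))
    return ".".join(parts)


def parse_nvcc_version(nvcc_output: str) -> str:
    """Return the 'X.Y' from the 'release X.Y' part of `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    if match is None:
        raise ValueError("CUDA release not found in nvcc output")
    return match.group(1)


if __name__ == "__main__":
    nvcc = subprocess.run(["nvcc", "--version"], capture_output=True, text=True, check=True)
    print("CUDA:", parse_nvcc_version(nvcc.stdout))
    print("cuDNN:", parse_cudnn_version(CUDNN_HEADER.read_text()))
```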
It is crucial to ensure that the installed CUDA and cuDNN versions are compatible with the specific version of JAX you intend to install.
Below are our suggested commands for installing JAX; for newer versions, refer to the official JAX installation documentation. We tested our code with JAX 0.3.25.
conda create -n example python=3.8
conda activate example
pip install --upgrade pip
pip install jax==0.3.25
pip install jax-jumpy==1.0.0
# for GPU (wheel built for CUDA 11 and cuDNN 8.2)
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# or, for CPU only
pip install jaxlib==0.3.25
pip install -r requirements.txt
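After installation, a quick way to confirm that `jax` and `jaxlib` match the tested 0.3.25 release is to read their installed package metadata. This is a minimal sketch with hypothetical function names; it assumes a GPU `jaxlib` wheel reports a local suffix such as `+cuda11.cudnn82` in its version, which is stripped before comparing.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

TESTED_VERSION = "0.3.25"  # the JAX version this README reports testing with


def installed(pkg: str) -> Optional[str]:
    """Return the installed distribution version, or None if it is missing."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def base_version(v: str) -> str:
    """Drop a local suffix such as '+cuda11.cudnn82' from a wheel version."""
    return v.split("+", 1)[0]


def check_jax_install() -> list:
    """Collect human-readable problems; an empty list means both packages match."""
    problems = []
    for pkg in ("jax", "jaxlib"):
        v = installed(pkg)
        if v is None:
            problems.append(f"{pkg} is not installed")
        elif base_version(v) != TESTED_VERSION:
            problems.append(f"{pkg} {v} differs from tested {TESTED_VERSION}")
    return problems


if __name__ == "__main__":
    for line in check_jax_install() or ["jax and jaxlib match the tested version"]:
        print(line)
```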
git clone https://github.com/PKU-Alignment/safety-gymnasium.git
cd safety-gymnasium
pip install -e .
cd ..
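To smoke-test the Safety-Gymnasium install, you can roll out one random episode and sum the reward and cost. Safety-Gymnasium's `step` returns the cost as a separate value, `(obs, reward, cost, terminated, truncated, info)`, which this sketch relies on. The `run_episode` helper is hypothetical, and we assume the environment ID is the part of this README's `--task` values after the `safetygym_` prefix.

```python
def run_episode(env, policy=None, seed=0, max_steps=1000):
    """Roll out one episode and return (total_reward, total_cost).

    Assumes the Safety-Gymnasium step signature:
    (obs, reward, cost, terminated, truncated, info).
    """
    obs, info = env.reset(seed=seed)
    total_reward, total_cost = 0.0, 0.0
    for _ in range(max_steps):
        # Random actions unless a policy callable is supplied.
        action = policy(obs) if policy is not None else env.action_space.sample()
        obs, reward, cost, terminated, truncated, info = env.step(action)
        total_reward += float(reward)
        total_cost += float(cost)
        if terminated or truncated:
            break
    return total_reward, total_cost


if __name__ == "__main__":
    import safety_gymnasium

    env = safety_gymnasium.make("SafetyPointGoal1-v0")
    print("reward, cost:", run_episode(env))
    env.close()
```

A total cost of zero is what "zero-cost performance" refers to at the episode level.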
You can download pretrained checkpoints from Hugging Face and evaluate them locally instead of training from scratch. If you only want to check that the code runs correctly, we recommend downloading the SafeDreamer(OSRP-Vector) checkpoint, since it is much smaller (26.6 MB):
Algorithm | Size | Checkpoint Link
---|---|---
SafeDreamer(BSRP-Lag) | 392 MB | Hugging Face
SafeDreamer(OSRP-Lag) | 392 MB | Hugging Face
SafeDreamer(OSRP) | 392 MB | Hugging Face
SafeDreamer(OSRP-Vector) | 26.6 MB | Hugging Face
Unsafe-DreamerV3 | 340 MB | Hugging Face
# Background Safety-Reward Planning with Lagrangian (BSRP-Lag):
python SafeDreamer/train.py --configs bsrp_lag --method bsrp_lag --run.script eval_only --run.from_checkpoint /xxx/checkpoint.ckpt --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0 --run.steps 10000
# Online Safety-Reward Planning with Lagrangian (OSRP-Lag):
python SafeDreamer/train.py --configs osrp_lag --method osrp_lag --run.script eval_only --run.from_checkpoint /xxx/checkpoint.ckpt --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0 --run.steps 10000 --pid.init_penalty 0.1
# Online Safety-Reward Planning (OSRP):
python SafeDreamer/train.py --configs osrp --method osrp --run.script eval_only --run.from_checkpoint /xxx/checkpoint.ckpt --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0 --run.steps 10000
# Online Safety-Reward Planning (OSRP) for low-dimensional input:
python SafeDreamer/train.py --configs osrp_vector --method osrp --run.script eval_only --run.from_checkpoint /xxx/checkpoint.ckpt --task safetygymcoor_SafetyPointGoal1-v0 --jax.logical_gpus 0 --run.steps 10000
Replace `/xxx/checkpoint.ckpt` with the path to your downloaded checkpoint. If you run on CPU, replace `--jax.logical_gpus 0` with `--jax.platform cpu`.
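The evaluation commands above differ only in a few arguments, so they can be generated programmatically. The sketch below is a hypothetical helper that reproduces those commands using only the flags shown in this README; the config/method pairs are copied from the evaluation examples (note that the training example further down passes `--method osrp_vector` for the vector variant, whereas the evaluation example passes `--method osrp`).

```python
import shlex

# variant -> (--configs, --method), copied from the evaluation commands above
METHODS = {
    "bsrp_lag": ("bsrp_lag", "bsrp_lag"),
    "osrp_lag": ("osrp_lag", "osrp_lag"),
    "osrp": ("osrp", "osrp"),
    "osrp_vector": ("osrp_vector", "osrp"),
}


def eval_command(variant, checkpoint, env_id="SafetyPointGoal1-v0", gpu=0, steps=10000):
    """Build an eval_only command; gpu=None switches to --jax.platform cpu."""
    configs, method = METHODS[variant]
    # The low-dimensional (vector) example uses the safetygymcoor_ task prefix.
    prefix = "safetygymcoor_" if variant == "osrp_vector" else "safetygym_"
    cmd = ["python", "SafeDreamer/train.py",
           "--configs", configs, "--method", method,
           "--run.script", "eval_only",
           "--run.from_checkpoint", checkpoint,
           "--task", prefix + env_id]
    cmd += ["--jax.platform", "cpu"] if gpu is None else ["--jax.logical_gpus", str(gpu)]
    cmd += ["--run.steps", str(steps)]
    if variant == "osrp_lag":
        cmd += ["--pid.init_penalty", "0.1"]  # only the OSRP-Lag example sets this
    return cmd


if __name__ == "__main__":
    print(shlex.join(eval_command("osrp_vector", "/xxx/checkpoint.ckpt")))
```

The resulting list can also be passed directly to `subprocess.run` instead of being printed.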
# Training from scratch on CPU:
python SafeDreamer/train.py --configs osrp --method osrp --task safetygym_SafetyPointGoal1-v0 --jax.platform cpu
# Training from scratch on GPU:
# Online Safety-Reward Planning (OSRP):
python SafeDreamer/train.py --configs osrp --method osrp --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0
# Online Safety-Reward Planning with Lagrangian (OSRP-Lag):
python SafeDreamer/train.py --configs osrp_lag --method osrp_lag --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0
# Background Safety-Reward Planning with Lagrangian (BSRP-Lag):
python SafeDreamer/train.py --configs bsrp_lag --method bsrp_lag --task safetygym_SafetyPointGoal1-v0 --jax.logical_gpus 0
# Online Safety-Reward Planning (OSRP) for low-dimensional input:
python SafeDreamer/train.py --configs osrp_vector --method osrp_vector --task safetygymcoor_SafetyPointGoal1-v0 --jax.logical_gpus 0
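The dotted flags in these commands (e.g. `--run.steps`, `--jax.logical_gpus`, `--pid.init_penalty`) appear to address nested keys in the YAML configuration. The sketch below illustrates that mapping under the assumption of a nested-dict config; it is not the repository's own flag parser, it ignores `--configs` preset selection, and the default values in the usage example are hypothetical.

```python
import copy


def apply_overrides(config: dict, argv: list) -> dict:
    """Return a copy of `config` with each `--a.b value` pair written to config[a][b].

    Values are cast to the type of the existing entry (bool, int, float, or str);
    unknown keys raise KeyError, as a strict parser would.
    """
    result = copy.deepcopy(config)
    for flag, raw in zip(argv[0::2], argv[1::2]):
        keys = flag.lstrip("-").split(".")
        node = result
        for key in keys[:-1]:
            node = node[key]
        old = node[keys[-1]]
        if isinstance(old, bool):  # check bool before int: bool subclasses int
            node[keys[-1]] = raw.lower() in ("true", "1", "yes")
        elif isinstance(old, (int, float)):
            node[keys[-1]] = type(old)(raw)
        else:
            node[keys[-1]] = raw
    return result


if __name__ == "__main__":
    # Hypothetical defaults standing in for entries of configs.yaml.
    defaults = {"run": {"steps": 100, "script": "train"},
                "jax": {"platform": "gpu", "logical_gpus": 0},
                "pid": {"init_penalty": 0.0}}
    print(apply_overrides(defaults, ["--run.steps", "10000", "--jax.platform", "cpu"]))
```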
- All configuration options are documented in `configs.yaml` and can be overridden from the command line.
- If you encounter CUDA errors, scroll up through the error messages: the root cause is often an earlier issue, such as running out of memory or incompatible versions of JAX and CUDA.
- To customize the GPU memory requirement, modify the `os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']` variable in `jaxagent.py`. This lets you adjust the memory allocation to your specific needs.
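The memory setting is an ordinary environment variable that XLA reads when JAX initializes its GPU backend, so it has to be set before that happens; assigning it before `import jax` is the safe choice. The hypothetical helper below shows the assignment with a range check. Since `jaxagent.py` sets this variable itself, a value exported elsewhere (for example in your shell) may be overwritten by that assignment, so edit the value there as described above.

```python
import os


def limit_gpu_memory(fraction: float) -> None:
    """Cap the share of GPU memory that JAX/XLA preallocates (hypothetical helper).

    Call this before JAX initializes its GPU backend, ideally before `import jax`.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"


limit_gpu_memory(0.5)  # e.g. leave roughly half of the GPU for other processes
# import jax  # import JAX only after the variable is set
```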
SafeDreamer is released under Apache License 2.0.
- DreamerV3: Our codebase is built upon DreamerV3.