SSubhnil/MWM-bench

Forked from younggyoseo/MWM.

Benchmarking MWM in confounded environments


Masked World Models for Visual Control

Implementation of MWM in TensorFlow 2. Our code is based on the implementations of DreamerV2 and APV. We plan to release the raw data behind our main experimental results in the data directory.

Method

Masked World Models (MWM) is a visual model-based RL algorithm that decouples visual representation learning from dynamics learning. The key idea of MWM is to train an autoencoder with convolutional feature masking that reconstructs visual observations, and to learn a latent dynamics model on top of the autoencoder.

[Figure: overview of MWM]
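To make "convolutional feature masking" concrete, here is a minimal pure-Python sketch under simplifying assumptions: a convolutional stem has already produced an H x W grid of feature vectors, each grid cell is treated as one token, and a random subset of tokens is dropped before the (not shown) encoder, in the style of masked autoencoders. The helper names and the mask_ratio value are hypothetical and do not come from this repository; the sketch omits the decoder, the reconstruction loss, and the latent dynamics model.

```python
import random
from typing import List, Tuple


def flatten_feature_map(feature_map: List[List[list]]) -> List[list]:
    """Turn an H x W grid of C-dim feature vectors into H*W tokens (row-major)."""
    return [cell for row in feature_map for cell in row]


def random_token_mask(num_tokens: int, mask_ratio: float,
                      seed: int = 0) -> Tuple[List[int], List[int]]:
    """Randomly split token indices into (visible, masked), MAE-style.

    mask_ratio is a hypothetical knob for this sketch, not MWM's setting.
    """
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    rng = random.Random(seed)
    order = list(range(num_tokens))
    rng.shuffle(order)
    num_visible = int(round(num_tokens * (1.0 - mask_ratio)))
    return sorted(order[:num_visible]), sorted(order[num_visible:])


# Usage: a toy 8x8 feature map with 4 channels gives 64 tokens.
feature_map = [[[0.0] * 4 for _ in range(8)] for _ in range(8)]
tokens = flatten_feature_map(feature_map)
visible, masked = random_token_mask(len(tokens), mask_ratio=0.75)
encoder_input = [tokens[i] for i in visible]  # only visible tokens are encoded
print(len(tokens), len(visible), len(masked))  # 64 16 48
```

Masking operates on feature-map cells rather than raw pixel patches here, which is the distinction the method description draws; a fixed seed keeps the split reproducible for inspection.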

Instructions

Get dependencies:

pip install tensorflow==2.6.0 tensorflow_text==2.6.0 tensorflow_estimator==2.6.0 tensorflow_probability==0.14.1 ruamel.yaml 'gym[atari]' dm_control tfimm git+https://github.com/rlworkgroup/metaworld.git@a0009ed9a208ff9864a5c1368c04c273bb20dd06#egg=metaworld
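Because the pins above are strict, it can help to confirm that the installed versions match before launching a long run. The following is a small sketch using the standard library's importlib.metadata; the script and its function names are hypothetical, not part of the repository. It assumes the distribution names as written in the pip command resolve through importlib.metadata (dist-info directories for these wheels use underscores).

```python
from importlib import metadata
from typing import Callable, Dict

# Pins copied from the pip command above.
PINNED = {
    "tensorflow": "2.6.0",
    "tensorflow_text": "2.6.0",
    "tensorflow_estimator": "2.6.0",
    "tensorflow_probability": "0.14.1",
}


def check_pins(pins: Dict[str, str],
               get_version: Callable[[str], str] = metadata.version) -> Dict[str, str]:
    """Return {package: problem} for every pin that is missing or mismatched."""
    problems = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if found != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins(PINNED)
    for pkg, msg in issues.items():
        print(f"{pkg}: {msg}")
    print("all pins satisfied" if not issues else f"{len(issues)} problem(s)")
```

The get_version parameter is only there so the check can be exercised without TensorFlow installed; in normal use the default is fine.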

Below are scripts to reproduce our experimental results. You can run experiments with or without early convolution and reward prediction by setting the mae.early_conv and mae.reward_pred arguments.
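For an ablation over these two switches, one option is to enumerate every on/off combination and append each to a base command. This is a hypothetical helper sketch, not a script shipped with the repository; it assumes the config parser accepts the literal strings True and False, as the commands below suggest.

```python
from itertools import product
from typing import List

# The two switches named in the README; the rest of the command is unchanged.
ABLATION_FLAGS = ("--mae.reward_pred", "--mae.early_conv")


def ablation_args() -> List[List[str]]:
    """Return one argument list per on/off combination of the two MAE switches."""
    grid = []
    for values in product((True, False), repeat=len(ABLATION_FLAGS)):
        args = []
        for flag, value in zip(ABLATION_FLAGS, values):
            args += [flag, str(value)]
        grid.append(args)
    return grid


for args in ablation_args():
    print(" ".join(args))
# --mae.reward_pred True --mae.early_conv True
# --mae.reward_pred True --mae.early_conv False
# --mae.reward_pred False --mae.early_conv True
# --mae.reward_pred False --mae.early_conv False
```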

Meta-world experiments

TF_XLA_FLAGS=--tf_xla_auto_jit=2 python mwm/train.py --logdir logs --configs metaworld --task metaworld_peg_insert_side --steps 502000 --mae.reward_pred True --mae.early_conv True

RLBench experiments

TF_XLA_FLAGS=--tf_xla_auto_jit=2 python mwm/train.py --logdir logs --configs rlbench --task rlbench_reach_target --steps 502000 --mae.reward_pred True --mae.early_conv True

DeepMind Control Suite experiments

TF_XLA_FLAGS=--tf_xla_auto_jit=2 python mwm/train.py --logdir logs --configs dmc_vision --task dmc_manip_reach_duplo --steps 252000 --mae.reward_pred True --mae.early_conv True
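The three commands differ only in config, task, and step budget, so they can be driven from one small launcher. Below is a sketch with hypothetical names (SUITES, build_command, launch); the config, task, and step values are copied from the commands above, and it assumes it is run from the repository root. By default it only builds the command (dry run) and does not start training.

```python
import os
import subprocess
from typing import Dict, List, Tuple

# (config, example task, steps) taken from the three commands above.
SUITES: Dict[str, Tuple[str, str, int]] = {
    "metaworld": ("metaworld", "metaworld_peg_insert_side", 502000),
    "rlbench": ("rlbench", "rlbench_reach_target", 502000),
    "dmc": ("dmc_vision", "dmc_manip_reach_duplo", 252000),
}


def build_command(suite: str, logdir: str = "logs", reward_pred: bool = True,
                  early_conv: bool = True) -> Tuple[Dict[str, str], List[str]]:
    """Return (environment, argv) for one training run; nothing is executed."""
    config, task, steps = SUITES[suite]
    # Same XLA auto-jit setting as the shell commands, added to the current env.
    env = dict(os.environ, TF_XLA_FLAGS="--tf_xla_auto_jit=2")
    argv = [
        "python", "mwm/train.py",
        "--logdir", logdir,
        "--configs", config,
        "--task", task,
        "--steps", str(steps),
        "--mae.reward_pred", str(reward_pred),
        "--mae.early_conv", str(early_conv),
    ]
    return env, argv


def launch(suite: str, dry_run: bool = True, **kwargs) -> List[str]:
    """Build the command and, unless dry_run is set, run it to completion."""
    env, argv = build_command(suite, **kwargs)
    if not dry_run:
        subprocess.run(argv, env=env, check=True)
    return argv


print(" ".join(launch("dmc")))
```

Passing a distinct logdir per run avoids mixing logs from different suites or ablations in the shared logs directory.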

Tips

  • Use TF_XLA_FLAGS=--tf_xla_auto_jit=2 to accelerate training. This requires your CUDA and cuDNN paths to be set correctly on your machine. You can check this by verifying that which ptxas returns a path inside the CUDA bin directory on your machine.

  • Also see the tips available in DreamerV2 repository.
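The ptxas check from the first tip can also be scripted, for example at the start of a job script. This is a hypothetical sketch: find_ptxas mirrors which ptxas via shutil.which, and in_cuda_bin is only a heuristic (it assumes the CUDA install path contains "cuda", which may not hold for every setup, such as some conda environments).

```python
import shutil
from pathlib import Path
from typing import Optional


def find_ptxas() -> Optional[str]:
    """Python equivalent of `which ptxas`: the resolved path, or None."""
    return shutil.which("ptxas")


def in_cuda_bin(path: str) -> bool:
    """Heuristic: ptxas should sit in a `bin` directory under a CUDA install."""
    p = Path(path)
    return p.parent.name == "bin" and any("cuda" in part.lower() for part in p.parts)


if __name__ == "__main__":
    path = find_ptxas()
    if path is None:
        print("ptxas not found on PATH; XLA compilation may fail")
    elif in_cuda_bin(path):
        print(f"ptxas found at {path}")
    else:
        print(f"ptxas found at {path}, which does not look like a CUDA bin directory")
```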
