Repository for environment encoder.
The idea is to train an RL agent to play games from a multimodal model's embeddings instead of the raw environment frames. I believe embeddings are a more universal representation, and learning from them might make it easier for an agent to generalize to unseen environments.
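A minimal sketch of the idea, with everything hypothetical: the embedding size, the stand-in `embed_frame` encoder (a random projection in place of a real frozen VLM such as CLIP), and the linear policy are placeholders for illustration, not this repo's code.

```python
import numpy as np

rng = np.random.default_rng(0)

EMB_DIM = 512    # hypothetical VLM embedding size
N_ACTIONS = 4    # hypothetical action-space size

def embed_frame(frame: np.ndarray) -> np.ndarray:
    """Stand-in for a frozen multimodal encoder: frame -> fixed-size vector.

    A real implementation would run the frame through a VLM and return its
    image embedding; here a random projection only illustrates the shape.
    """
    flat = frame.astype(np.float64).ravel()
    proj = rng.standard_normal((flat.size, EMB_DIM))
    return flat @ proj

# The agent's policy consumes the embedding, not the raw pixels.
policy_w = rng.standard_normal((EMB_DIM, N_ACTIONS))

frame = rng.integers(0, 256, size=(84, 84), dtype=np.uint8)  # fake game frame
z = embed_frame(frame)
action = int(np.argmax(z @ policy_w))
```

The point is only that the observation the policy sees is a fixed-size embedding vector, so the same policy interface works across visually different environments.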
Run the following to install dependencies:

    apt-get install swig
    conda env create -f conda_env.yml
    pip install flash-attn --no-build-isolation
- Set up RL training code.
- Set up code to extract VLM embeddings from Gym environment frames.
- Train RL agent on embeddings instead of game frames.
- Improve efficiency of generating VLM embeddings through caching and batching.
- Set up code to autotune hyperparameters for the RL agent.
- Train autoencoder on Atari game frame VLM embeddings.
- Replace the RL agent's embedding input with the autoencoder's dense representation.
- Train RL agent on autoencoder dense representation.
- Demonstrate improved performance and generalizability.
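On the caching-and-batching step above: one way to sketch it, with hypothetical names (`EmbeddingCache`, `frame_key`, and the toy embedder are illustrations, not this repo's API), is to key embeddings by a hash of the frame bytes and run a single batched forward pass for the cache misses only.

```python
import hashlib
import numpy as np

def frame_key(frame: np.ndarray) -> str:
    # Hash the raw bytes so identical frames hit the cache.
    return hashlib.sha1(frame.tobytes()).hexdigest()

class EmbeddingCache:
    """Cache VLM embeddings per frame; embed only cache misses, in one batch."""

    def __init__(self, embed_batch):
        self.embed_batch = embed_batch  # callable: list of frames -> (N, D) array
        self.cache = {}

    def __call__(self, frames):
        keys = [frame_key(f) for f in frames]
        # Collect unique frames not yet cached, preserving order.
        missing = {}
        for k, f in zip(keys, frames):
            if k not in self.cache and k not in missing:
                missing[k] = f
        if missing:
            new = self.embed_batch(list(missing.values()))  # one batched pass
            for k, emb in zip(missing, new):
                self.cache[k] = emb
        return np.stack([self.cache[k] for k in keys])

# Toy embedder that counts calls, to show batching and caching at work.
calls = {"n": 0}
def toy_embed(frames):
    calls["n"] += 1
    return np.stack([f.astype(np.float64).ravel()[:8] for f in frames])

cache = EmbeddingCache(toy_embed)
a = np.zeros((4, 4), dtype=np.uint8)
b = np.ones((4, 4), dtype=np.uint8)
out1 = cache([a, b, a])   # one batched call embeds the two unique frames
out2 = cache([a, b])      # fully cached: no new embedder call
```

Since Atari frames repeat heavily (and the VLM is frozen, so embeddings never change), this turns most rollout steps into dictionary lookups rather than forward passes.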
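On the autoencoder steps above: a minimal sketch, assuming a plain linear autoencoder trained by gradient descent on mean squared reconstruction error. The dimensions, learning rate, and training data are all made up for illustration; the repo's actual architecture may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
EMB_DIM, DENSE_DIM = 64, 8   # hypothetical embedding / bottleneck sizes

# Linear autoencoder: encoder W_e compresses, decoder W_d reconstructs.
W_e = rng.standard_normal((EMB_DIM, DENSE_DIM)) * 0.1
W_d = rng.standard_normal((DENSE_DIM, EMB_DIM)) * 0.1

X = rng.standard_normal((256, EMB_DIM))  # stand-in for a batch of VLM embeddings
lr = 1e-2
losses = []
for _ in range(200):
    Z = X @ W_e                 # dense bottleneck representation
    X_hat = Z @ W_d             # reconstruction of the embeddings
    err = X_hat - X
    losses.append(float(np.mean(err ** 2)))
    # Gradients of the mean squared reconstruction error.
    g_d = Z.T @ err / len(X)
    g_e = X.T @ (err @ W_d.T) / len(X)
    W_d -= lr * g_d
    W_e -= lr * g_e

dense = X @ W_e  # what the RL agent would consume instead of the full embedding
```

The RL agent then takes `dense` (here 8-dimensional) as its observation instead of the full embedding, which shrinks the policy network's input without retraining the VLM.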
Big thanks to CleanRL for providing the basis of the RL training and hyperparameter tuning code; it made my life a lot easier :)