
PyTorch implementation of Deep Reinforcement Learning: Policy Gradient methods (TRPO, PPO, A2C) and Generative Adversarial Imitation Learning (GAIL). Fast Fisher vector product TRPO.


gjzheng93/PyTorch-RL

PyTorch implementation of reinforcement learning algorithms

This repository contains:

  1. policy gradient methods (TRPO, PPO, A2C)
  2. Generative Adversarial Imitation Learning (GAIL)

Important notes

  • The code now works with PyTorch 0.4. For PyTorch 0.3, please check out the 0.3 branch.
  • To run MuJoCo environments, first install mujoco-py and gym.
  • If you have a GPU, I recommend setting the environment variable OMP_NUM_THREADS to 1. PyTorch creates additional threads when performing computations, and these threads can hurt multiprocessing performance. The problem is most serious on Linux, where multiprocessing can end up even slower than a single thread:
export OMP_NUM_THREADS=1
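The same setting can also be applied from inside a Python launcher script. The following is a minimal stdlib-only sketch. It assumes the variable is set before torch (or any other OpenMP-backed library) is imported, because the OpenMP runtime typically reads it once at startup. `worker_thread_setting` is a hypothetical helper that only demonstrates that worker processes inherit the value. If a process has already started, `torch.set_num_threads(1)` is the in-code alternative.

```python
import multiprocessing as mp
import os

# Set before importing torch/numpy: the OpenMP runtime typically reads
# OMP_NUM_THREADS once, when it initializes.
os.environ["OMP_NUM_THREADS"] = "1"


def worker_thread_setting(_: int) -> str:
    # Hypothetical helper: worker processes inherit the parent's environment.
    return os.environ.get("OMP_NUM_THREADS", "")


if __name__ == "__main__":
    with mp.Pool(processes=4) as pool:
        print(pool.map(worker_thread_setting, range(4)))  # ['1', '1', '1', '1']
```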

Features

  • CUDA support (10x faster than the CPU implementation).
  • Support for both discrete and continuous action spaces.
  • Multiprocessing support, so the agent can collect samples in multiple environments simultaneously (8x faster than a single thread).
  • Fast Fisher vector product calculation. Ankur kindly wrote a blog post explaining the implementation details.
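TRPO uses Fisher vector products to solve F x = g with conjugate gradient, so the Fisher matrix F is never formed explicitly. The sketch below is not this repository's implementation. As a stand-in, it uses the closed-form Fisher of a single categorical distribution with respect to its logits, F = diag(p) - p p^T, plus a damping term. The `damping` value and all function names are assumptions for illustration.

```python
import math
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_fvp(probs: Vector, damping: float) -> Callable[[Vector], Vector]:
    """Return v -> (F + damping * I) v with F = diag(p) - p p^T, never built as a matrix."""
    def fvp(v: Vector) -> Vector:
        pv = dot(probs, v)
        return [p * vi - p * pv + damping * vi for p, vi in zip(probs, v)]
    return fvp


def conjugate_gradient(avp: Callable[[Vector], Vector], b: Vector,
                       iters: int = 10, tol: float = 1e-10) -> Vector:
    """Approximately solve A x = b using only matrix-vector products v -> A v."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x for the initial guess x = 0
    d = list(b)  # current search direction
    rr = dot(r, r)
    for _ in range(iters):
        if rr < tol:  # squared residual norm is small enough
            break
        ad = avp(d)
        alpha = rr / dot(d, ad)
        x = [xi + alpha * di for xi, di in zip(x, d)]
        r = [ri - alpha * adi for ri, adi in zip(r, ad)]
        rr_new = dot(r, r)
        d = [ri + (rr_new / rr) * di for ri, di in zip(r, d)]
        rr = rr_new
    return x


probs = softmax([0.5, -1.0, 2.0])
fvp = categorical_fvp(probs, damping=0.1)
g = [1.0, 0.0, -1.0]
step = conjugate_gradient(fvp, g)
```

Damping keeps F + damping * I positive definite, which conjugate gradient requires. Without it, diag(p) - p p^T is singular. In TRPO, the resulting direction is then rescaled so that the update satisfies the KL-divergence constraint.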

Policy gradient methods

Example

  • python examples/ppo_gym.py --env-name Hopper-v2
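PPO optimizes a clipped surrogate objective. The probability ratio between the new and old policies is clipped to [1 - eps, 1 + eps], and the smaller of the clipped and unclipped terms is used. The following is a minimal stdlib sketch of that loss. The function name and the `clip_epsilon` default of 0.2 are assumptions, not necessarily what `examples/ppo_gym.py` uses.

```python
import math
from typing import Sequence


def ppo_clip_loss(
    new_log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    clip_epsilon: float = 0.2,
) -> float:
    """Negated PPO clipped surrogate objective, averaged over a batch.

    Minimizing this value maximizes
    E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)], where r = pi_new / pi_old.
    """
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        clipped_ratio = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
        # Pessimistic bound: keep the smaller of the unclipped and clipped terms.
        total += min(ratio * adv, clipped_ratio * adv)
    return -total / len(advantages)


print(ppo_clip_loss([0.0, 0.0], [0.0, 0.0], [1.0, -3.0]))  # 1.0
```

When the policy has not changed, the ratio is 1 and the loss is simply the negated mean advantage. Clipping only matters once the ratio leaves [0.8, 1.2].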

Reference

Generative Adversarial Imitation Learning (GAIL)

To save expert trajectories

  • python gail/save_expert_traj.py --model-path assets/expert_traj/Hopper-v2_ppo.p
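The `.p` extension suggests that trajectories are stored with Python's `pickle`. The following is a minimal sketch under that assumption. It also assumes each saved row is a state vector followed by an action vector. This layout and the helper names are hypothetical, so check `gail/save_expert_traj.py` for the actual format.

```python
import os
import pickle
import tempfile
from typing import List, Sequence, Tuple

# Hypothetical layout: one row per step, state features followed by action features.
Transition = Tuple[Sequence[float], Sequence[float]]


def to_rows(transitions: List[Transition]) -> List[List[float]]:
    return [list(state) + list(action) for state, action in transitions]


def save_expert_rows(path: str, transitions: List[Transition]) -> None:
    with open(path, "wb") as f:
        pickle.dump(to_rows(transitions), f)


def load_expert_rows(path: str) -> List[List[float]]:
    with open(path, "rb") as f:
        return pickle.load(f)


def roundtrip_demo() -> List[List[float]]:
    # Two made-up transitions: 2-D states and 1-D actions.
    transitions = [([0.1, 0.2], [1.0]), ([0.3, 0.4], [-1.0])]
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "expert_traj.p")
        save_expert_rows(path, transitions)
        return load_expert_rows(path)
```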

To run imitation learning

  • python gail/gail_gym.py --env-name Hopper-v2 --expert-traj-path assets/expert_traj/Hopper-v2_expert_traj.p
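GAIL trains a discriminator to distinguish the learner's state-action pairs from the expert's, then uses the discriminator's output as the policy's reward. The following is a minimal stdlib sketch. It assumes a logistic-regression discriminator as a stand-in for whatever network `gail/gail_gym.py` uses. It also assumes the convention that D(s, a) is the probability that a pair came from the learner, which gives the reward -log D(s, a); some code bases flip the labels and use log D instead. All names here are hypothetical.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


class LinearDiscriminator:
    """Hypothetical logistic-regression discriminator.

    prob(x) estimates the probability that state-action row x was produced
    by the learner policy rather than by the expert.
    """

    def __init__(self, dim: int):
        self.w: List[float] = [0.0] * dim
        self.b = 0.0

    def prob(self, x: Sequence[float]) -> float:
        return sigmoid(sum(wi * xi for wi, xi in zip(self.w, x)) + self.b)

    def update(self, learner_rows, expert_rows, lr: float = 0.1) -> None:
        # One gradient-descent step on binary cross-entropy,
        # with learner rows labeled 1 and expert rows labeled 0.
        batch = [(x, 1.0) for x in learner_rows] + [(x, 0.0) for x in expert_rows]
        grad_w = [0.0] * len(self.w)
        grad_b = 0.0
        for x, label in batch:
            err = self.prob(x) - label  # derivative of BCE w.r.t. the logit
            for i, xi in enumerate(x):
                grad_w[i] += err * xi
            grad_b += err
        n = len(batch)
        self.w = [wi - lr * g / n for wi, g in zip(self.w, grad_w)]
        self.b -= lr * grad_b / n


def gail_reward(disc: LinearDiscriminator, x: Sequence[float], eps: float = 1e-8) -> float:
    # Reward is high when the discriminator thinks x looks like expert data (D near 0).
    return -math.log(disc.prob(x) + eps)
```

In a full GAIL loop, discriminator updates alternate with policy-gradient updates (for example, PPO) that treat `gail_reward` as the environment reward.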
