
CUDA implementation of Generalized Advantage Estimation (GAE)


garrett4wade/cugae


Introduction

Generalized Advantage Estimation (GAE) is widely used in reinforcement learning (RL), especially with PPO. Computing GAE requires a loop that runs backward over each trajectory, because the advantage at step t depends on the advantage at step t+1. In Python this loop is slow and can become the training bottleneck.
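For reference, the backward loop described above can be written as the minimal pure-Python sketch below. It is not the repository's code. The function name gae_reference, the argument names, the dones convention (a terminal step cuts both the bootstrap and the recursion), and the defaults gamma=0.99 and lam=0.95 are all assumptions made for illustration.

```python
from typing import List


def gae_reference(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    bootstrap_value: float,
    gamma: float = 0.99,
    lam: float = 0.95,
) -> List[float]:
    """Hypothetical reference GAE for a single trajectory of length T.

    values[t] is V(s_t); bootstrap_value is V(s_T), the value of the state
    after the last step. dones[t] marks that step t ended an episode
    (an assumed convention), which stops bootstrapping across the boundary.
    """
    T = len(rewards)
    advantages = [0.0] * T
    next_value = bootstrap_value
    next_adv = 0.0
    # Iterate backward: A_t needs A_{t+1}, so this loop is inherently sequential in t.
    for t in reversed(range(T)):
        not_done = 0.0 if dones[t] else 1.0
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        # Recursion: A_t = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
        next_adv = delta + gamma * lam * not_done * next_adv
        advantages[t] = next_adv
        next_value = values[t]
    return advantages
```

With lam=0 the result reduces to the one-step TD residuals, and with lam=1 it becomes the discounted return minus the value baseline.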

This repository provides a simple CUDA implementation of GAE that can achieve up to 2000x higher throughput than a Python implementation.
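The recursion is sequential only along time; different trajectories in a batch are independent, which is what makes a GPU implementation attractive. One common mapping, assumed here and not confirmed for cugae (see cugae.py for the actual interface), is one GPU thread per trajectory over a row-major [B, T] buffer. The hypothetical gae_kernel_emulated below mimics that layout in pure Python: the outer loop stands in for the threads and the index b * T + t is the offset a kernel would compute.

```python
from typing import List


def gae_kernel_emulated(
    rewards: List[float],    # flattened [B * T], row-major
    values: List[float],     # flattened [B * T]
    dones: List[float],      # flattened [B * T], 1.0 where a step ends an episode
    bootstrap: List[float],  # [B], V(s_T) for each row
    B: int,
    T: int,
    gamma: float,
    lam: float,
) -> List[float]:
    """Hypothetical CPU emulation of a one-thread-per-trajectory GAE kernel."""
    adv = [0.0] * (B * T)
    # On a GPU each value of b could run in its own thread, since rows never
    # read each other's data; here the "threads" simply run one after another.
    for b in range(B):
        next_value = bootstrap[b]
        next_adv = 0.0
        for t in range(T - 1, -1, -1):
            i = b * T + t  # row-major offset, as a kernel would compute it
            not_done = 1.0 - dones[i]
            delta = rewards[i] + gamma * next_value * not_done - values[i]
            next_adv = delta + gamma * lam * not_done * next_adv
            adv[i] = next_adv
            next_value = values[i]
    return adv
```

Because each row touches only its own slice, the total work is O(B * T) but the sequential depth is only O(T), which is where a batched GPU version gains over a Python loop.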

Usage

Installation requires a CUDA-capable GPU, the CUDA compiler (nvcc), and PyTorch (torch).

git clone https://github.com/garrett4wade/cugae
cd cugae && pip3 install -e .

After installation, run pytest -q -s test_cugae.py to execute the tests and validate your installation.

See cugae.py for detailed documentation of each implemented function.
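When validating a GAE implementation, it helps to compare it against a reference derived independently of the backward loop. For a trajectory with no episode boundaries, GAE equals the discounted sum A_t = sum_l (gamma * lam)^l * delta_{t+l}. The sketch below uses that closed form; gae_closed_form and check_gae_impl are hypothetical helpers, and the candidate signature (rewards, values, bootstrap_value, gamma, lam) is an assumption, so a function from cugae.py would need a small wrapper matching its documented interface.

```python
import math
import random
from typing import Callable, List, Sequence


def gae_closed_form(
    rewards: List[float],
    values: List[float],
    bootstrap_value: float,
    gamma: float,
    lam: float,
) -> List[float]:
    """O(T^2) closed-form GAE for one trajectory without episode ends."""
    T = len(rewards)
    next_values = values[1:] + [bootstrap_value]
    deltas = [rewards[t] + gamma * next_values[t] - values[t] for t in range(T)]
    return [sum((gamma * lam) ** l * deltas[t + l] for l in range(T - t)) for t in range(T)]


def check_gae_impl(
    candidate: Callable[[List[float], List[float], float, float, float], Sequence[float]],
    trials: int = 100,
    T: int = 16,
    gamma: float = 0.99,
    lam: float = 0.95,
    tol: float = 1e-6,
) -> bool:
    """Return True if candidate matches the closed form on random trajectories."""
    rng = random.Random(0)  # fixed seed so failures are reproducible
    for _ in range(trials):
        rewards = [rng.uniform(-1.0, 1.0) for _ in range(T)]
        values = [rng.uniform(-1.0, 1.0) for _ in range(T)]
        bootstrap_value = rng.uniform(-1.0, 1.0)
        expected = gae_closed_form(rewards, values, bootstrap_value, gamma, lam)
        got = list(candidate(rewards, values, bootstrap_value, gamma, lam))
        if len(got) != len(expected):
            return False
        if not all(math.isclose(e, g, rel_tol=tol, abs_tol=tol) for e, g in zip(expected, got)):
            return False
    return True
```

The closed form is quadratic in T, so it suits short test trajectories only. GPU kernels usually compute in float32, so a looser tol may be needed there.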

Benchmark Results

The benchmark was run with Python 3.10.12 and CUDA 12.2 under WSL2 (Ubuntu 22.04) on a laptop with an Intel i7-12700H CPU and an NVIDIA RTX 3070 GPU.

[Figure: Benchmark Results]
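To run a similar comparison on other hardware, a small timing harness such as the hypothetical measure_throughput below can be used. It reports elements processed per second from the median of several timed runs after a few warm-up calls. CUDA kernel launches are asynchronous, so when fn runs GPU code it should end with torch.cuda.synchronize(); otherwise the timer captures only the launch, not the computation.

```python
import time
from typing import Callable


def measure_throughput(
    fn: Callable[[], object],
    num_elements: int,
    warmup: int = 3,
    repeats: int = 10,
) -> float:
    """Hypothetical harness: elements per second for one call of fn.

    fn should process num_elements items (e.g. B * T timesteps) per call.
    If fn launches CUDA work, it must synchronize before returning.
    """
    # Warm-up runs absorb one-time costs such as lazy initialization or JIT.
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    timings.sort()
    median = timings[len(timings) // 2]
    # Guard against a zero reading from a trivially fast fn.
    return num_elements / max(median, 1e-9)
```

Using the median rather than the mean keeps occasional scheduler or driver hiccups from skewing the reported throughput.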
