[Algorithm] Discrete CQL #1666

Merged: 28 commits, Nov 10, 2023

Commits
f71e2e7
init discrete cql objective
BY571 Oct 30, 2023
49d98c2
add converging base version
BY571 Nov 2, 2023
ecc0454
fixes
BY571 Nov 3, 2023
60cf8ec
update example tests
BY571 Nov 3, 2023
97dd57f
cleanup add tests
BY571 Nov 3, 2023
0800132
update loss docstring
BY571 Nov 3, 2023
a0f7189
fix warning
BY571 Nov 3, 2023
6a30a7d
Update examples/cql/discrete_cql_online.py
BY571 Nov 6, 2023
20bbd8a
Update test/test_cost.py
BY571 Nov 6, 2023
c5341c3
Update test/test_cost.py
BY571 Nov 6, 2023
e8847c4
Update torchrl/objectives/cql.py
BY571 Nov 6, 2023
f9427ef
objective fixes
BY571 Nov 6, 2023
dcba00c
Merge branch 'discrete_CQL' of https://github.com/BY571/rl into discr…
BY571 Nov 6, 2023
368d2fb
fix
BY571 Nov 6, 2023
b310858
init
vmoens Nov 6, 2023
8a7f75a
update categorical cql loss case
BY571 Nov 7, 2023
2fd5812
Merge remote-tracking branch 'origin/main' into discrete_CQL
vmoens Nov 7, 2023
e113951
Merge branch 'discrete_CQL' into discrete_CQL_refact
BY571 Nov 8, 2023
bfd7189
Merge pull request #1 from vmoens/discrete_CQL_refact
BY571 Nov 8, 2023
ab92393
fix loss sum
BY571 Nov 8, 2023
9af9f99
example test fixes
BY571 Nov 8, 2023
53ef411
fix categorical action in cql loss
BY571 Nov 8, 2023
dfaf89f
Merge remote-tracking branch 'origin/main' into discrete_CQL
vmoens Nov 8, 2023
6bcf240
lint
vmoens Nov 8, 2023
1bf7d0d
Merge branch 'main' into discrete_CQL
BY571 Nov 9, 2023
4885005
Merge branch 'discrete_CQL' of https://github.com/BY571/rl into discr…
BY571 Nov 9, 2023
054a8cb
add doc
vmoens Nov 10, 2023
9941055
Merge remote-tracking branch 'origin/main' into discrete_CQL
vmoens Nov 10, 2023
10 changes: 10 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
@@ -106,6 +106,16 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn.py \
  record_video=True \
  record_frames=4 \
  buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \
  collector.total_frames=48 \
  collector.init_random_frames=10 \
  optim.batch_size=10 \
  collector.frames_per_batch=16 \
  collector.env_per_collector=2 \
  collector.device=cuda:0 \
  optim.optim_steps_per_batch=1 \
  replay_buffer.size=120 \
  logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
  num_workers=4 \
  collector.total_frames=48 \
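The block added above runs the discrete CQL example as a short CI smoke test, launching it with tiny Hydra-style dotted overrides of the config file introduced below. As a rough illustration of how such dotted overrides compose with the YAML defaults (this is not part of the CI script; it assumes omegaconf is installed and the repository layout shown below):

from omegaconf import OmegaConf

# Load the example's default config and apply a few of the CI overrides.
base = OmegaConf.load("examples/cql/discrete_cql_config.yaml")
overrides = OmegaConf.from_dotlist(
    [
        "collector.total_frames=48",
        "collector.frames_per_batch=16",
        "optim.batch_size=10",
        "replay_buffer.size=120",
    ]
)
cfg = OmegaConf.merge(base, overrides)
print(cfg.collector.total_frames, cfg.optim.batch_size)  # 48 10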
1 change: 1 addition & 0 deletions docs/source/reference/objectives.rst
@@ -136,6 +136,7 @@ CQL
    :template: rl_template_noinherit.rst

    CQLLoss
    DiscreteCQLLoss

DT
----
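DiscreteCQLLoss is the new entry in this table. Its signature is not part of this diff, so the following is only a minimal usage sketch, assuming it is wired up through a QValueActor in the same way as DQNLoss; the exact constructor arguments should be checked against the class docstring:

from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import MLP, QValueActor
from torchrl.objectives import DiscreteCQLLoss

# A tiny Q-network for a CartPole-like task: 4 observations, 2 discrete actions.
n_obs, n_act = 4, 2
value_net = MLP(in_features=n_obs, out_features=n_act)
actor = QValueActor(
    value_net,
    in_keys=["observation"],
    spec=OneHotDiscreteTensorSpec(n_act),
)

# The training script in this PR reads "loss_qvalue" and "loss_cql" from the
# TensorDict returned by this module and sums them.
loss_module = DiscreteCQLLoss(actor, loss_function="l2")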
57 changes: 57 additions & 0 deletions examples/cql/discrete_cql_config.yaml
@@ -0,0 +1,57 @@
# Task and env
env:
  name: CartPole-v1
  task: ""
  library: gym
  exp_name: cql_cartpole_gym
  n_samples_stats: 1000
  max_episode_steps: 200
  seed: 0

# Collector
collector:
  frames_per_batch: 200
  total_frames: 20000
  multi_step: 0
  init_random_frames: 1000
  env_per_collector: 1
  device: cpu
  max_frames_per_traj: 200
  annealing_frames: 10000
  eps_start: 1.0
  eps_end: 0.01
# logger
logger:
  backend: wandb
  log_interval: 5000 # record interval in frames
  eval_steps: 200
  mode: online
  eval_iter: 1000

# Buffer
replay_buffer:
  prb: 0
  buffer_prefetch: 64
  size: 1_000_000
  scratch_dir: ${env.exp_name}_${env.seed}

# Optimization
optim:
  utd_ratio: 1
  device: cuda:0
  lr: 1e-3
  weight_decay: 0.0
  batch_size: 256
  lr_scheduler: ""
  optim_steps_per_batch: 200

# Policy and model
model:
  hidden_sizes: [256, 256]
  activation: relu

# loss
loss:
  loss_function: l2
  gamma: 0.99
  tau: 0.005
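As a quick sanity check on the optimization cadence these defaults imply: the training script below derives the number of gradient updates per collected batch from env_per_collector, frames_per_batch and utd_ratio, so the values above give 200 updates for every 200 collected frames (an update-to-data ratio of 1):

# Worked example of the update cadence implied by the defaults above,
# mirroring the num_updates computation in discrete_cql_online.py below.
env_per_collector = 1   # collector.env_per_collector
frames_per_batch = 200  # collector.frames_per_batch
utd_ratio = 1           # optim.utd_ratio

num_updates = int(env_per_collector * frames_per_batch * utd_ratio)
print(num_updates)  # 200 gradient steps per batch of 200 collected frames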
199 changes: 199 additions & 0 deletions examples/cql/discrete_cql_online.py
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Discrete (DQN) CQL Example.

This is a simple self-contained example of a discrete CQL training script.

It supports state-based environments like gym and gymnasium.

The helper functions are coded in the utils.py associated with this script.
"""

import time

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm

from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
    log_metrics,
    make_collector,
    make_cql_optimizer,
    make_discretecql_model,
    make_discreteloss,
    make_environment,
    make_replay_buffer,
)


@hydra.main(version_base="1.1", config_path=".", config_name="discrete_cql_config")
def main(cfg: "DictConfig"): # noqa: F821
    device = torch.device(cfg.optim.device)

    # Create logger
    exp_name = generate_exp_name("DiscreteCQL", cfg.env.exp_name)
    logger = None
    if cfg.logger.backend:
        logger = get_logger(
            logger_type=cfg.logger.backend,
            logger_name="discretecql_logging",
            experiment_name=exp_name,
            wandb_kwargs={"mode": cfg.logger.mode, "config": cfg},
        )

    # Set seeds
    torch.manual_seed(cfg.env.seed)
    np.random.seed(cfg.env.seed)

    # Create environments
    train_env, eval_env = make_environment(cfg)

    # Create agent
    model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)

    # Create loss
    loss_module, target_net_updater = make_discreteloss(cfg.loss, model)

    # Create off-policy collector
    collector = make_collector(cfg, train_env, explore_policy)

    # Create replay buffer
    replay_buffer = make_replay_buffer(
        batch_size=cfg.optim.batch_size,
        prb=cfg.replay_buffer.prb,
        buffer_size=cfg.replay_buffer.size,
        buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
        device="cpu",
    )

    # Create optimizers
    optimizer = make_cql_optimizer(cfg, loss_module)

    # Main loop
    collected_frames = 0
    pbar = tqdm.tqdm(total=cfg.collector.total_frames)

    init_random_frames = cfg.collector.init_random_frames
    num_updates = int(
        cfg.collector.env_per_collector
        * cfg.collector.frames_per_batch
        * cfg.optim.utd_ratio
    )
    prb = cfg.replay_buffer.prb
    eval_rollout_steps = cfg.env.max_episode_steps
    eval_iter = cfg.logger.eval_iter
    frames_per_batch = cfg.collector.frames_per_batch

    start_time = sampling_start = time.time()
    for tensordict in collector:
        sampling_time = time.time() - sampling_start

        # Update exploration policy
        explore_policy[1].step(tensordict.numel())

        # Update weights of the inference policy
        collector.update_policy_weights_()

        pbar.update(tensordict.numel())

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        collected_frames += current_frames

        # Optimization steps
        training_start = time.time()
        if collected_frames >= init_random_frames:
            (
                q_losses,
                cql_losses,
            ) = ([], [])
            for _ in range(num_updates):

                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()
                if sampled_tensordict.device != device:
                    sampled_tensordict = sampled_tensordict.to(
                        device, non_blocking=True
                    )
                else:
                    sampled_tensordict = sampled_tensordict.clone()

                # Compute loss
                loss_dict = loss_module(sampled_tensordict)

                q_loss = loss_dict["loss_qvalue"]
                cql_loss = loss_dict["loss_cql"]
                loss = q_loss + cql_loss

                # Update model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                q_losses.append(q_loss.item())
                cql_losses.append(cql_loss.item())

                # Update target params
                target_net_updater.step()
                # Update priority
                if prb:
                    replay_buffer.update_priority(sampled_tensordict)

        training_time = time.time() - training_start
        episode_end = (
            tensordict["next", "done"]
            if tensordict["next", "done"].any()
            else tensordict["next", "truncated"]
        )
        episode_rewards = tensordict["next", "episode_reward"][episode_end]

        # Logging
        metrics_to_log = {}
        if len(episode_rewards) > 0:
            episode_length = tensordict["next", "step_count"][episode_end]
            metrics_to_log["train/reward"] = episode_rewards.mean().item()
            metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
                episode_length
            )
        metrics_to_log["train/epsilon"] = explore_policy[1].eps

        if collected_frames >= init_random_frames:
            metrics_to_log["train/q_loss"] = np.mean(q_losses)
            metrics_to_log["train/cql_loss"] = np.mean(cql_losses)
            metrics_to_log["train/sampling_time"] = sampling_time
            metrics_to_log["train/training_time"] = training_time

        # Evaluation
        if abs(collected_frames % eval_iter) < frames_per_batch:
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_start = time.time()
                eval_rollout = eval_env.rollout(
                    eval_rollout_steps,
                    model,
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )
                eval_time = time.time() - eval_start
                eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
                metrics_to_log["eval/reward"] = eval_reward
                metrics_to_log["eval/time"] = eval_time
        if logger is not None:
            log_metrics(logger, metrics_to_log, collected_frames)
        sampling_start = time.time()

    collector.shutdown()
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
    main()
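The loss internals are not shown in this diff, but the loss_cql term optimized above corresponds to the standard discrete-action CQL regularizer from the original CQL paper (Kumar et al., 2020): a logsumexp over the Q-values of all actions minus the Q-value of the action stored in the buffer. A standalone sketch of that term (not TorchRL's actual implementation):

import torch


def discrete_cql_regularizer(q_values: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
    """Sketch of the discrete CQL penalty.

    q_values: shape [batch, n_actions], Q(s, .) from the trained network.
    action:   shape [batch], integer indices of the actions in the buffer.
    """
    # Penalize large Q-values across all actions (a soft maximum)...
    pushed_down = torch.logsumexp(q_values, dim=-1)
    # ...while keeping the Q-value of the dataset action high.
    pushed_up = q_values.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    return (pushed_down - pushed_up).mean()


# Dummy usage with CartPole-like shapes (batch of 32, 2 actions).
q = torch.randn(32, 2)
a = torch.randint(0, 2, (32,))
print(discrete_cql_regularizer(q, a))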