[RLlib] Replace "seq_lens" w/ SampleBatch.SEQ_LENS. (ray-project#17928)
sven1977 committed Aug 21, 2021
1 parent b969aa3 commit 494ddd9
Showing 28 changed files with 200 additions and 191 deletions.
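The change is mechanical: every hard-coded "seq_lens" string key is replaced by the SampleBatch.SEQ_LENS class constant. A minimal sketch of the access pattern this standardizes on, assuming SampleBatch.SEQ_LENS resolves to the same "seq_lens" key (the batch below is a plain-dict stand-in, not a real SampleBatch):

from ray.rllib.policy.sample_batch import SampleBatch

# Assumption: the constant maps to the legacy "seq_lens" string key.
# Using the constant avoids typos and lets the key be renamed in one place.
train_batch = {SampleBatch.SEQ_LENS: [2, 3]}  # plain-dict stand-in
assert train_batch[SampleBatch.SEQ_LENS] == train_batch["seq_lens"]
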
4 changes: 2 additions & 2 deletions rllib/agents/a3c/a3c_tf_policy.py
@@ -75,8 +75,8 @@ def actor_critic_loss(policy: Policy, model: ModelV2,
     model_out, _ = model.from_batch(train_batch)
     action_dist = dist_class(model_out, model)
     if policy.is_recurrent():
-        max_seq_len = tf.reduce_max(train_batch["seq_lens"])
-        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
+        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
+        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
         mask = tf.reshape(mask, [-1])
     else:
         mask = tf.ones_like(train_batch[SampleBatch.REWARDS])

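The recurrent branch above builds a flat boolean mask from the per-sequence lengths so that 0-padded timesteps drop out of the loss. A toy, self-contained sketch of that masking step (illustration only, not RLlib code):

import tensorflow as tf

# Two sequences of lengths 2 and 3, padded to max_seq_len = 3.
seq_lens = tf.constant([2, 3])
max_seq_len = tf.reduce_max(seq_lens)                # -> 3
mask = tf.sequence_mask(seq_lens, max_seq_len)       # [[T, T, F], [T, T, T]]
mask = tf.reshape(mask, [-1])                        # flat: [T, T, F, T, T, T]
per_step_loss = tf.range(6, dtype=tf.float32)        # one value per (padded) step
mean_valid_loss = tf.reduce_mean(tf.boolean_mask(per_step_loss, mask))
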
5 changes: 3 additions & 2 deletions rllib/agents/a3c/a3c_torch_policy.py
@@ -41,8 +41,9 @@ def actor_critic_loss(policy: Policy, model: ModelV2,
     values = model.value_function()

     if policy.is_recurrent():
-        max_seq_len = torch.max(train_batch["seq_lens"])
-        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
+        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
+        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
+                                  max_seq_len)
         valid_mask = torch.reshape(mask_orig, [-1])
     else:
         valid_mask = torch.ones_like(values, dtype=torch.bool)

7 changes: 4 additions & 3 deletions rllib/agents/dqn/r2d2_tf_policy.py
@@ -81,7 +81,7 @@ def r2d2_loss(policy: Policy, model, _,
         model,
         train_batch,
         state_batches=state_batches,
-        seq_lens=train_batch.get("seq_lens"),
+        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
         explore=False,
         is_training=True)

@@ -91,7 +91,7 @@ def r2d2_loss(policy: Policy, model, _,
         policy.target_model,
         train_batch,
         state_batches=state_batches,
-        seq_lens=train_batch.get("seq_lens"),
+        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
         explore=False,
         is_training=True)

@@ -140,7 +140,8 @@ def r2d2_loss(policy: Policy, model, _,
         config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

     # Seq-mask all loss-related terms.
-    seq_mask = tf.sequence_mask(train_batch["seq_lens"], T)[:, :-1]
+    seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS],
+                                T)[:, :-1]
     # Mask away also the burn-in sequence at the beginning.
     burn_in = policy.config["burn_in"]
     # Making sure, this works for both static graph and eager.

6 changes: 3 additions & 3 deletions rllib/agents/dqn/r2d2_torch_policy.py
@@ -90,7 +90,7 @@ def r2d2_loss(policy: Policy, model, _,
         model,
         train_batch,
         state_batches=state_batches,
-        seq_lens=train_batch.get("seq_lens"),
+        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
         explore=False,
         is_training=True)

@@ -100,7 +100,7 @@ def r2d2_loss(policy: Policy, model, _,
         target_model,
         train_batch,
         state_batches=state_batches,
-        seq_lens=train_batch.get("seq_lens"),
+        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
         explore=False,
         is_training=True)

@@ -148,7 +148,7 @@ def r2d2_loss(policy: Policy, model, _,
         config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

     # Seq-mask all loss-related terms.
-    seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
+    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
     # Mask away also the burn-in sequence at the beginning.
     burn_in = policy.config["burn_in"]
     if burn_in > 0 and burn_in < T:

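Besides the padding mask, the R2D2 losses above also zero out the first burn_in timesteps, which only serve to warm up the recurrent state. A small pure-PyTorch sketch of that combined masking (illustrative shapes and names, not the actual R2D2 code):

import torch

T, burn_in = 5, 2
seq_lens = torch.tensor([3, 5])                 # B = 2 sequences
steps = torch.arange(T).unsqueeze(0)            # [1, T]
seq_mask = steps < seq_lens.unsqueeze(1)        # [B, T] padding mask
seq_mask[:, :burn_in] = False                   # burn-in steps carry no loss
td_error = torch.randn(2, T)
loss = (td_error ** 2 * seq_mask.float()).sum() / seq_mask.float().sum()
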
10 changes: 5 additions & 5 deletions rllib/agents/impala/vtrace_tf_policy.py
@@ -167,8 +167,8 @@ def build_vtrace_loss(policy, model, dist_class, train_batch):
         output_hidden_shape = 1

     def make_time_major(*args, **kw):
-        return _make_time_major(policy, train_batch.get("seq_lens"), *args,
-                                **kw)
+        return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
+                                *args, **kw)

     actions = train_batch[SampleBatch.ACTIONS]
     dones = train_batch[SampleBatch.DONES]
@@ -181,8 +181,8 @@ def make_time_major(*args, **kw):
     values = model.value_function()

     if policy.is_recurrent():
-        max_seq_len = tf.reduce_max(train_batch["seq_lens"])
-        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
+        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
+        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
         mask = tf.reshape(mask, [-1])
     else:
         mask = tf.ones_like(rewards)
@@ -223,7 +223,7 @@ def make_time_major(*args, **kw):
 def stats(policy, train_batch):
     values_batched = _make_time_major(
         policy,
-        train_batch.get("seq_lens"),
+        train_batch.get(SampleBatch.SEQ_LENS),
         policy.model.value_function(),
         drop_last=policy.config["vtrace"])

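make_time_major() above folds the flat, batch-major tensors into the [T, B, ...] layout that the vtrace computation expects. A rough sketch of that reshaping under simplified assumptions (fixed T, no drop_last or trajectory-view handling):

import tensorflow as tf

B, T = 2, 5
flat = tf.reshape(tf.range(B * T, dtype=tf.float32), [B * T])  # batch-major, flattened
batch_major = tf.reshape(flat, [B, T])                         # [B, T]
time_major = tf.transpose(batch_major, perm=[1, 0])            # [T, B]
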
11 changes: 6 additions & 5 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -125,8 +125,8 @@ def build_vtrace_loss(policy, model, dist_class, train_batch):
         output_hidden_shape = 1

     def _make_time_major(*args, **kw):
-        return make_time_major(policy, train_batch.get("seq_lens"), *args,
-                               **kw)
+        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
+                               *args, **kw)

     actions = train_batch[SampleBatch.ACTIONS]
     dones = train_batch[SampleBatch.DONES]
@@ -145,8 +145,9 @@ def _make_time_major(*args, **kw):
     values = model.value_function()

     if policy.is_recurrent():
-        max_seq_len = torch.max(train_batch["seq_lens"])
-        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
+        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
+        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
+                                  max_seq_len)
         mask = torch.reshape(mask_orig, [-1])
     else:
         mask = torch.ones_like(rewards)
@@ -186,7 +187,7 @@ def _make_time_major(*args, **kw):
     policy.loss = loss
     values_batched = make_time_major(
         policy,
-        train_batch.get("seq_lens"),
+        train_batch.get(SampleBatch.SEQ_LENS),
         values,
         drop_last=policy.config["vtrace"])
     policy._vf_explained_var = explained_variance(

10 changes: 5 additions & 5 deletions rllib/agents/ppo/appo_tf_policy.py
@@ -114,8 +114,8 @@ def appo_surrogate_loss(

     # TODO: (sven) deprecate this when trajectory view API gets activated.
     def make_time_major(*args, **kw):
-        return _make_time_major(policy, train_batch.get("seq_lens"), *args,
-                                **kw)
+        return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
+                                *args, **kw)

     actions = train_batch[SampleBatch.ACTIONS]
     dones = train_batch[SampleBatch.DONES]
@@ -131,8 +131,8 @@ def make_time_major(*args, **kw):
     policy.target_model_vars = policy.target_model.variables()

     if policy.is_recurrent():
-        max_seq_len = tf.reduce_max(train_batch["seq_lens"])
-        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
+        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
+        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
         mask = tf.reshape(mask, [-1])
         mask = make_time_major(mask, drop_last=policy.config["vtrace"])

@@ -282,7 +282,7 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
     """
     values_batched = _make_time_major(
         policy,
-        train_batch.get("seq_lens"),
+        train_batch.get(SampleBatch.SEQ_LENS),
         policy.model.value_function(),
         drop_last=policy.config["vtrace"])

10 changes: 5 additions & 5 deletions rllib/agents/ppo/appo_torch_policy.py
@@ -69,9 +69,9 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2,
         is_multidiscrete = False
         output_hidden_shape = 1

-    def _make_time_major(*args, **kw):
-        return make_time_major(policy, train_batch.get("seq_lens"), *args,
-                               **kw)
+    def _make_time_major(*args, **kwargs):
+        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
+                               *args, **kwargs)

     actions = train_batch[SampleBatch.ACTIONS]
     dones = train_batch[SampleBatch.DONES]
@@ -85,8 +85,8 @@ def _make_time_major(*args, **kw):
     values_time_major = _make_time_major(values)

     if policy.is_recurrent():
-        max_seq_len = torch.max(train_batch["seq_lens"])
-        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
+        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
+        mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
         mask = torch.reshape(mask, [-1])
         mask = _make_time_major(mask, drop_last=policy.config["vtrace"])
         num_valid = torch.sum(mask)

4 changes: 2 additions & 2 deletions rllib/agents/ppo/ppo_tf_policy.py
@@ -60,10 +60,10 @@ def ppo_surrogate_loss(
         # Derive max_seq_len from the data itself, not from the seq_lens
         # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still
         # 0-padded up to T=5 (as it's the case for attention nets).
-        B = tf.shape(train_batch["seq_lens"])[0]
+        B = tf.shape(train_batch[SampleBatch.SEQ_LENS])[0]
         max_seq_len = tf.shape(logits)[0] // B

-        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
+        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
         mask = tf.reshape(mask, [-1])

         def reduce_mean_valid(t):

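The comment in this hunk is the reason B comes from the seq_lens tensor while max_seq_len comes from the data: with attention nets the batch can be padded past max(seq_lens). A quick numeric check of that logic (toy shapes, not the PPO code):

import tensorflow as tf

seq_lens = tf.constant([2, 3])                  # real lengths
logits = tf.zeros([10, 4])                      # B * T = 10 padded rows, 4 actions
B = tf.shape(seq_lens)[0]                       # -> 2
max_seq_len = tf.shape(logits)[0] // B          # -> 5, not max(seq_lens) = 3
mask = tf.reshape(tf.sequence_mask(seq_lens, max_seq_len), [-1])  # 10 booleans
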
4 changes: 2 additions & 2 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -47,10 +47,10 @@ def ppo_surrogate_loss(

     # RNN case: Mask away 0-padded chunks at end of time axis.
     if state:
-        B = len(train_batch["seq_lens"])
+        B = len(train_batch[SampleBatch.SEQ_LENS])
         max_seq_len = logits.shape[0] // B
         mask = sequence_mask(
-            train_batch["seq_lens"],
+            train_batch[SampleBatch.SEQ_LENS],
             max_seq_len,
             time_major=model.is_time_major())
         mask = torch.reshape(mask, [-1])

4 changes: 2 additions & 2 deletions rllib/agents/sac/rnnsac_torch_policy.py
@@ -213,7 +213,7 @@ def actor_critic_loss(
         state_batches.append(train_batch["state_in_{}".format(i)])
         i += 1
     assert state_batches
-    seq_lens = train_batch.get("seq_lens")
+    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)

     model_out_t, state_in_t = model({
         "obs": train_batch[SampleBatch.CUR_OBS],
@@ -343,7 +343,7 @@ def actor_critic_loss(
     # BURNIN #
     B = state_batches[0].shape[0]
     T = q_t_selected.shape[0] // B
-    seq_mask = sequence_mask(train_batch["seq_lens"], T)
+    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
     # Mask away also the burn-in sequence at the beginning.
     burn_in = policy.config["burn_in"]
     if burn_in > 0 and burn_in < T:

16 changes: 8 additions & 8 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -73,7 +73,7 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
            episode_id (EpisodeID): Unique ID for the episode we are adding the
                initial observation for.
            agent_index (int): Unique int index (starting from 0) for the agent
-               within its episode.
+               within its episode. Not to be confused with AGENT_ID (Any).
            env_id (EnvID): The environment index (in a vectorized setup).
            t (int): The time step (episode length - 1). The initial obs has
                ts=-1(!), then an action/reward/next-obs at t=0, etc..
@@ -85,14 +85,14 @@ def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
                 single_row={
                     SampleBatch.OBS: init_obs,
                     SampleBatch.AGENT_INDEX: agent_index,
-                    "env_id": env_id,
-                    "t": t,
+                    SampleBatch.ENV_ID: env_id,
+                    SampleBatch.T: t,
                 })
         self.buffers[SampleBatch.OBS].append(init_obs)
         self.episode_id = episode_id
         self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
-        self.buffers["env_id"].append(env_id)
-        self.buffers["t"].append(t)
+        self.buffers[SampleBatch.ENV_ID].append(env_id)
+        self.buffers[SampleBatch.T].append(t)

     def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
             None:
@@ -279,7 +279,7 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
                 continue
             shift = self.shift_before - (1 if col in [
                 SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
-                "env_id", "t"
+                SampleBatch.ENV_ID, SampleBatch.T
             ] else 0)
             # Python primitive, tensor, or dict (e.g. INFOs).
             self.buffers[col] = [data for _ in range(shift)]
@@ -546,8 +546,8 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
             # Create the batch of data from the different buffers.
             data_col = view_req.data_col or view_col
             delta = -1 if data_col in [
-                SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID,
-                SampleBatch.AGENT_INDEX
+                SampleBatch.OBS, SampleBatch.ENV_ID, SampleBatch.EPS_ID,
+                SampleBatch.AGENT_INDEX, SampleBatch.T
             ] else 0
             # Range of shifts, e.g. "-100:0". Note: This includes index 0!
             if view_req.shift_from is not None:

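The collector changes above swap the remaining raw "env_id" and "t" keys for the SampleBatch.ENV_ID and SampleBatch.T constants, so collectors, view requirements, and policies all agree on column names. A hedged sketch of a single collected row using those constants (values are made up; the exact string each constant maps to is assumed, not shown in this diff):

from ray.rllib.policy.sample_batch import SampleBatch

row = {
    SampleBatch.OBS: [0.0, 0.0],
    SampleBatch.AGENT_INDEX: 0,
    SampleBatch.ENV_ID: 0,   # previously keyed by the literal "env_id"
    SampleBatch.T: -1,       # previously keyed by the literal "t" (initial obs step)
}
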
20 changes: 11 additions & 9 deletions rllib/evaluation/sampler.py
@@ -24,6 +24,7 @@
 from ray.rllib.models.preprocessors import Preprocessor
 from ray.rllib.offline import InputReader
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.debug import summarize
 from ray.rllib.utils.deprecation import deprecation_warning
@@ -833,19 +834,20 @@ def _process_observations(
             else:
                 # Add actions, rewards, next-obs to collectors.
                 values_dict = {
-                    "t": episode.length - 1,
-                    "env_id": env_id,
-                    "agent_index": episode._agent_index(agent_id),
+                    SampleBatch.T: episode.length - 1,
+                    SampleBatch.ENV_ID: env_id,
+                    SampleBatch.AGENT_INDEX: episode._agent_index(agent_id),
                     # Action (slot 0) taken at timestep t.
-                    "actions": episode.last_action_for(agent_id),
+                    SampleBatch.ACTIONS: episode.last_action_for(agent_id),
                     # Reward received after taking a at timestep t.
-                    "rewards": rewards[env_id].get(agent_id, 0.0),
+                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                     # After taking action=a, did we reach terminal?
-                    "dones": (False if (no_done_at_end
-                                        or (hit_horizon and soft_horizon)) else
-                              agent_done),
+                    SampleBatch.DONES: (False
+                                        if (no_done_at_end
+                                            or (hit_horizon and soft_horizon))
+                                        else agent_done),
                     # Next observation.
-                    "new_obs": filtered_obs,
+                    SampleBatch.NEXT_OBS: filtered_obs,
                 }
                 # Add extra-action-fetches to collectors.
                 pol = worker.policy_map[policy_id]

6 changes: 3 additions & 3 deletions rllib/evaluation/tests/test_trajectory_view_api.py
@@ -26,10 +26,10 @@ class MyCallbacks(DefaultCallbacks):
     @override(DefaultCallbacks)
     def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
         assert train_batch.count == 201
-        assert sum(train_batch["seq_lens"]) == 201
+        assert sum(train_batch[SampleBatch.SEQ_LENS]) == 201
         for k, v in train_batch.items():
             if k == "state_in_0":
-                assert len(v) == len(train_batch["seq_lens"])
+                assert len(v) == len(train_batch[SampleBatch.SEQ_LENS])
             else:
                 assert len(v) == 201
         current = None
@@ -403,7 +403,7 @@ def analyze_rnn_batch(batch, max_seq_len):

     # Check after seq-len 0-padding.
     cursor = 0
-    for i, seq_len in enumerate(batch["seq_lens"]):
+    for i, seq_len in enumerate(batch[SampleBatch.SEQ_LENS]):
         state_in_0 = batch["state_in_0"][i]
         state_in_1 = batch["state_in_1"][i]
         for j in range(seq_len):

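analyze_rnn_batch() above walks the 0-padded batch sequence by sequence: roughly, each sequence owns a max_seq_len-sized chunk of rows, of which only the first seq_len are real samples. A plain-Python sketch of that cursor walk (hypothetical data, simplified relative to the test):

max_seq_len = 4
seq_lens = [2, 3]
rows = list(range(len(seq_lens) * max_seq_len))            # flat, padded row indices
cursor = 0
for seq_len in seq_lens:
    valid = rows[cursor:cursor + seq_len]                  # real timesteps
    padding = rows[cursor + seq_len:cursor + max_seq_len]  # 0-padded tail
    cursor += max_seq_len
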
4 changes: 2 additions & 2 deletions rllib/examples/models/modelv3.py
@@ -36,11 +36,11 @@ def __init__(self,

     def call(self, sample_batch):
         dense_out = self.dense(sample_batch["obs"])
-        B = tf.shape(sample_batch["seq_lens"])[0]
+        B = tf.shape(sample_batch[SampleBatch.SEQ_LENS])[0]
         lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]])
         lstm_out, h, c = self.lstm(
             inputs=lstm_in,
-            mask=tf.sequence_mask(sample_batch["seq_lens"]),
+            mask=tf.sequence_mask(sample_batch[SampleBatch.SEQ_LENS]),
             initial_state=[
                 sample_batch["state_in_0"], sample_batch["state_in_1"]
             ],

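The Keras model above also shows the general pattern for feeding time-flattened batches into an LSTM: fold the [B*T, F] tensor back into [B, T, F] using B = number of sequences, then mask padded steps with seq_lens. A toy version with concrete shapes (illustration only, not the example model itself):

import tensorflow as tf

seq_lens = tf.constant([2, 3])
dense_out = tf.zeros([10, 8])                       # B*T = 10 rows, 8 features
B = tf.shape(seq_lens)[0]
lstm_in = tf.reshape(dense_out, [B, -1, dense_out.shape.as_list()[1]])  # [2, 5, 8]
mask = tf.sequence_mask(seq_lens, tf.shape(lstm_in)[1])                 # [2, 5] bools
lstm_out = tf.keras.layers.LSTM(16, return_sequences=True)(lstm_in, mask=mask)
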
5 changes: 3 additions & 2 deletions rllib/models/modelv2.py
@@ -206,7 +206,7 @@ def __call__(
         restored = input_dict.copy(shallow=True)
         # Backward compatibility.
         if seq_lens is None:
-            seq_lens = input_dict.get("seq_lens")
+            seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
         if not state:
             state = []
             i = 0
@@ -260,7 +260,8 @@ def from_batch(self, train_batch: SampleBatch,
         while "state_in_{}".format(i) in input_dict:
             states.append(input_dict["state_in_{}".format(i)])
             i += 1
-        ret = self.__call__(input_dict, states, input_dict.get("seq_lens"))
+        ret = self.__call__(input_dict, states,
+                            input_dict.get(SampleBatch.SEQ_LENS))
         return ret

     def import_from_h5(self, h5_file: str) -> None:

2 changes: 1 addition & 1 deletion rllib/models/tf/attention_net.py
@@ -790,7 +790,7 @@ def __init__(

     def call(self, input_dict: SampleBatch) -> \
             (TensorType, List[TensorType], Dict[str, TensorType]):
-        assert input_dict["seq_lens"] is not None
+        assert input_dict[SampleBatch.SEQ_LENS] is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)
