[RLlib] Fix ope evaluation bug (ray-project#35697)
Signed-off-by: Rohan Potdar <[email protected]>
Rohan138 committed May 24, 2023
1 parent 09e07fd commit 8e49d2a
Showing 4 changed files with 17 additions and 15 deletions.
14 changes: 2 additions & 12 deletions rllib/algorithms/algorithm.py
@@ -61,7 +61,6 @@
     DirectMethod,
     DoublyRobust,
 )
-from ray.rllib.offline.offline_evaluation_utils import remove_time_dim
 from ray.rllib.offline.offline_evaluator import OfflineEvaluator
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples
@@ -666,17 +665,8 @@ def setup(self, config: AlgorithmConfig) -> None:
             # the num worker is set to 0 to avoid creating shards. The dataset will not
             # be repartioned to num_workers blocks.
             logger.info("Creating evaluation dataset ...")
-            ds, _ = get_dataset_and_shards(self.evaluation_config, num_workers=0)
-
-            # Dataset should be in form of one episode per row. in case of bandits each
-            # row is just one time step. To make the computation more efficient later
-            # we remove the time dimension here.
-            parallelism = self.evaluation_config.evaluation_num_workers or 1
-            batch_size = max(ds.count() // parallelism, 1)
-            self.evaluation_dataset = ds.map_batches(
-                remove_time_dim,
-                batch_size=batch_size,
-                batch_format="pandas",
+            self.evaluation_dataset, _ = get_dataset_and_shards(
+                self.evaluation_config, num_workers=0
             )
             logger.info("Evaluation dataset created")
 
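
The hunk above replaces the eager remove_time_dim pass over the evaluation dataset with a plain get_dataset_and_shards call; each off-policy estimator now applies remove_time_dim itself (see the two estimator diffs below). As a minimal sketch of the Ray Data pattern involved, here is a pandas-format map_batches call with a hypothetical strip_time_dim stand-in; the helper and the toy columns are illustrative only, not the actual remove_time_dim implementation.

    import numpy as np
    import pandas as pd
    import ray

    def strip_time_dim(batch: pd.DataFrame) -> pd.DataFrame:
        # Hypothetical stand-in: unwrap length-1 per-episode sequences so each
        # row holds plain scalars (one time step per row, as in the bandit case).
        batch = batch.copy()
        for col in batch.columns:
            batch[col] = batch[col].map(
                lambda v: v[0]
                if isinstance(v, (list, tuple, np.ndarray)) and len(v) == 1
                else v
            )
        return batch

    ds = ray.data.from_items(
        [{"obs": [0.1], "rewards": [1.0]}, {"obs": [0.2], "rewards": [0.5]}]
    )
    ds = ds.map_batches(strip_time_dim, batch_size=2, batch_format="pandas")
    print(ds.take_all())
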
8 changes: 7 additions & 1 deletion rllib/offline/estimators/importance_sampling.py
@@ -5,7 +5,10 @@
 
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.offline.offline_evaluator import OfflineEvaluator
-from ray.rllib.offline.offline_evaluation_utils import compute_is_weights
+from ray.rllib.offline.offline_evaluation_utils import (
+    remove_time_dim,
+    compute_is_weights,
+)
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
 from ray.rllib.policy.sample_batch import SampleBatch
 
@@ -96,6 +99,9 @@ def estimate_on_dataset(
                 the behavior policy.
         """
         batch_size = max(dataset.count() // n_parallelism, 1)
+        dataset = dataset.map_batches(
+            remove_time_dim, batch_size=batch_size, batch_format="pandas"
+        )
         updated_ds = dataset.map_batches(
             compute_is_weights,
             batch_size=batch_size,
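
Both hunks above feed into ImportanceSampling.estimate_on_dataset. A hedged usage sketch of that path follows, assuming the signature implied by the hunk header (a Ray Dataset plus an n_parallelism keyword) and the constructor keywords visible in offline_evaluation_utils.py (policy, gamma, epsilon_greedy); trained_policy and eval_dataset are placeholders the caller must supply.

    from ray.rllib.offline.estimators import ImportanceSampling

    # Placeholders (assumed): a trained ray.rllib.policy.Policy and a Ray
    # Dataset of logged episodes, one episode per row.
    estimator = ImportanceSampling(policy=trained_policy, gamma=0.99)

    # With this change the estimator strips the time dimension itself via
    # remove_time_dim, so the raw evaluation dataset can be passed directly.
    results = estimator.estimate_on_dataset(eval_dataset, n_parallelism=4)
    print(results)
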
8 changes: 7 additions & 1 deletion rllib/offline/estimators/weighted_importance_sampling.py
@@ -6,7 +6,10 @@
 
 from ray.rllib.offline.offline_evaluator import OfflineEvaluator
 from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator
-from ray.rllib.offline.offline_evaluation_utils import compute_is_weights
+from ray.rllib.offline.offline_evaluation_utils import (
+    remove_time_dim,
+    compute_is_weights,
+)
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy import Policy
 from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -152,6 +155,9 @@ def estimate_on_dataset(
         """
         # compute the weights and weighted rewards
         batch_size = max(dataset.count() // n_parallelism, 1)
+        dataset = dataset.map_batches(
+            remove_time_dim, batch_size=batch_size, batch_format="pandas"
+        )
         updated_ds = dataset.map_batches(
             compute_is_weights,
             batch_size=batch_size,
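
The same remove_time_dim preprocessing is added here for weighted importance sampling. For intuition on how it differs from the ordinary estimator above, a toy, self-contained numeric sketch (made-up numbers, not RLlib code): ordinary IS averages weighted returns directly, while weighted IS normalizes by the total weight, trading a small bias for lower variance.

    # Three logged one-step episodes with returns and importance weights
    # w = pi_new(a|s) / pi_old(a|s); numbers are purely illustrative.
    returns = [1.0, 0.0, 2.0]
    weights = [0.5, 2.0, 1.5]

    n = len(returns)
    is_estimate = sum(w * g for w, g in zip(weights, returns)) / n              # ~1.167
    wis_estimate = sum(w * g for w, g in zip(weights, returns)) / sum(weights)  # 0.875
    print(is_estimate, wis_estimate)
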
2 changes: 1 addition & 1 deletion rllib/offline/offline_evaluation_utils.py
@@ -80,7 +80,7 @@ def compute_is_weights(
         The modified batch with the importance sampling weights, weighted rewards, new
         and old propensities added as columns.
     """
-    policy = policy = Policy.from_state(policy_state)
+    policy = Policy.from_state(policy_state)
     estimator = estimator_class(policy=policy, gamma=0, epsilon_greedy=0)
     sample_batch = SampleBatch(
         {
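
The one-line change above removes a duplicated assignment in compute_is_weights, whose docstring says it adds importance sampling weights, weighted rewards, and new/old propensities as columns. As a rough self-contained sketch of that idea (column names assumed for illustration, not the RLlib implementation):

    import pandas as pd

    def add_is_columns(batch: pd.DataFrame) -> pd.DataFrame:
        # Assumed columns: "new_prob" = target-policy action probability,
        # "old_prob" = logged behavior-policy action probability, "rewards".
        batch = batch.copy()
        batch["weights"] = batch["new_prob"] / batch["old_prob"]
        batch["weighted_rewards"] = batch["weights"] * batch["rewards"]
        return batch

    print(add_is_columns(pd.DataFrame({
        "new_prob": [0.8, 0.1],
        "old_prob": [0.5, 0.5],
        "rewards": [1.0, 0.0],
    })))
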
