
Commit 152ad21

added examples for random goals, and option to not log in base sampler
cbfinn committed Jan 14, 2017
1 parent 8cbd437 commit 152ad21
Showing 3 changed files with 53 additions and 17 deletions.
33 changes: 33 additions & 0 deletions examples/sens_vpg_point.py
@@ -0,0 +1,33 @@
#from rllab.algos.vpg import VPG
from sandbox.rocky.tf.algos.sensitive_vpg import SensitiveVPG
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
from examples.point_env import PointEnv
from examples.point_env_randgoal import PointEnvRandGoal
from rllab.envs.normalized_env import normalize
from rllab.misc.instrument import stub, run_experiment_lite
#from rllab.policies.gaussian_mlp_policy import GaussianMLPPolicy
from sandbox.rocky.tf.policies.sens_minimal_gauss_mlp_policy import SensitiveGaussianMLPPolicy
from sandbox.rocky.tf.envs.base import TfEnv

stub(globals())

#env = TfEnv(normalize(PointEnv()))
env = TfEnv(normalize(PointEnvRandGoal()))
policy = SensitiveGaussianMLPPolicy(
name="policy",
env_spec=env.spec,
)
baseline = LinearFeatureBaseline(env_spec=env.spec)
algo = SensitiveVPG(
    env=env,
    policy=policy,
    baseline=baseline,
    #plot=True,
)
run_experiment_lite(
    algo.train(),
    n_parallel=1,
    snapshot_mode="last",
    seed=1,
    #plot=True,
)
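
Note: the example above (and the updated examples/vpg_point.py below) gets its task distribution from PointEnvRandGoal in examples/point_env_randgoal.py, which is not part of this commit. As rough orientation, a random-goal point environment written against rllab's Env/Step interface might look like the sketch below; the class name, goal range, reward, and termination threshold are illustrative assumptions, not the repository's actual code.

# Sketch only: the real examples/point_env_randgoal.py is not shown in this commit.
import numpy as np

from rllab.envs.base import Env, Step
from rllab.spaces import Box


class PointEnvRandGoalSketch(Env):
    @property
    def observation_space(self):
        return Box(low=-np.inf, high=np.inf, shape=(2,))

    @property
    def action_space(self):
        return Box(low=-0.1, high=0.1, shape=(2,))

    def reset(self):
        # a fresh goal is sampled for every episode ("random goals");
        # the [-0.5, 0.5] range is an assumption for illustration
        self._goal = np.random.uniform(-0.5, 0.5, size=(2,))
        self._state = np.zeros(2)
        return np.copy(self._state)

    def step(self, action):
        self._state = self._state + action
        dist = np.linalg.norm(self._state - self._goal)
        # reward is the negative distance to the sampled goal
        return Step(observation=np.copy(self._state), reward=-dist, done=dist < 0.01)

Resampling the goal on every reset is what provides a distribution of point-reaching tasks rather than the single fixed goal of PointEnv.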
4 changes: 2 additions & 2 deletions examples/vpg_point.py
@@ -12,8 +12,8 @@
 
 stub(globals())
 
-env = TfEnv(normalize(PointEnv()))
-#env = TfEnv(normalize(PointEnvRandGoal()))
+#env = TfEnv(normalize(PointEnv()))
+env = TfEnv(normalize(PointEnvRandGoal()))
 policy = GaussianMLPPolicy(
     name="policy",
     env_spec=env.spec,
33 changes: 18 additions & 15 deletions rllab/sampler/base.py
@@ -45,7 +45,7 @@ def __init__(self, algo):
         """
         self.algo = algo
 
-    def process_samples(self, itr, paths):
+    def process_samples(self, itr, paths, log=True):
         baselines = []
         returns = []
 
@@ -160,23 +160,26 @@ def process_samples(self, itr, paths):
                 paths=paths,
             )
 
-        logger.log("fitting baseline...")
+        if log:
+            logger.log("fitting baseline...")
         if hasattr(self.algo.baseline, 'fit_with_samples'):
             self.algo.baseline.fit_with_samples(paths, samples_data)
         else:
             self.algo.baseline.fit(paths)
-        logger.log("fitted")
-
-        logger.record_tabular('Iteration', itr)
-        logger.record_tabular('AverageDiscountedReturn',
-                              average_discounted_return)
-        logger.record_tabular('AverageReturn', np.mean(undiscounted_returns))
-        logger.record_tabular('ExplainedVariance', ev)
-        logger.record_tabular('NumTrajs', len(paths))
-        logger.record_tabular('Entropy', ent)
-        logger.record_tabular('Perplexity', np.exp(ent))
-        logger.record_tabular('StdReturn', np.std(undiscounted_returns))
-        logger.record_tabular('MaxReturn', np.max(undiscounted_returns))
-        logger.record_tabular('MinReturn', np.min(undiscounted_returns))
+        if log:
+            logger.log("fitted")
+
+        if log:
+            #logger.record_tabular('Iteration', itr)
+            #logger.record_tabular('AverageDiscountedReturn',
+            #                      average_discounted_return)
+            logger.record_tabular('AverageReturn', np.mean(undiscounted_returns))
+            #logger.record_tabular('ExplainedVariance', ev)
+            #logger.record_tabular('NumTrajs', len(paths))
+            #logger.record_tabular('Entropy', ent)
+            #logger.record_tabular('Perplexity', np.exp(ent))
+            logger.record_tabular('StdReturn', np.std(undiscounted_returns))
+            #logger.record_tabular('MaxReturn', np.max(undiscounted_returns))
+            #logger.record_tabular('MinReturn', np.min(undiscounted_returns))
 
         return samples_data

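The log flag added to process_samples keeps the sample processing and baseline fitting but skips the logger.log messages and tabular records when the caller passes log=False; callers that omit the argument get the old behaviour, since log defaults to True. A hypothetical usage sketch, where sampler stands for any BaseSampler subclass and task_paths for a list of per-task rollout lists (both names are illustrative, not defined in this commit):

def process_task_samples(sampler, itr, task_paths):
    # With log=False, the baseline-fitting messages and tabular records in
    # BaseSampler.process_samples are suppressed for each per-task call.
    return [sampler.process_samples(itr, paths, log=False) for paths in task_paths]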