Skip to content

Commit

Permalink
BUGFIX+TESTS: fixed env unwrapped bug. Paths were using wrapped envs.…
Browse files Browse the repository at this point in the history
… By default path_utils uses unwrapped envs now
  • Loading branch information
vikashplus committed Mar 9, 2024
1 parent 4898218 commit 76b3455
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 11 deletions.
47 changes: 46 additions & 1 deletion robohive/tests/test_examine_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,39 @@
import unittest
from robohive.utils.examine_env import main as examine_env
import os
import glob
import time


class TestExamineEnv(unittest.TestCase):

def delete_recent_file(self, filename_pattern, directory='.', age=5):

# Get the current time
current_time = time.time()

# Use glob to find files matching the pattern in the specified directory
matching_files = glob.glob(os.path.join(directory, filename_pattern))

# Iterate over the matching files
for file_path in matching_files:
try:
# Get the creation time of the file
creation_time = os.path.getctime(file_path)

# Calculate the time difference between current time and creation time
time_difference = current_time - creation_time

# If the file was created within the last 5 seconds, delete it
if time_difference <= 5:
os.remove(file_path)
print(f"Deleted file created within {age} seconds: {file_path}")
else:
print(f"File not deleted: {file_path}, created {time_difference} seconds ago.")
except Exception as e:
print(f"Error deleting file: {file_path} - {e}")


def test_main(self):
# Call your function and test its output/assertions
print("Testing env with random policy")
Expand All @@ -32,7 +62,22 @@ def test_offscreen_rendering(self):
print("EXCEPTION", result.exception, flush=True)
# print(result.output.strip())
self.assertEqual(result.exception, None, result.exception)
os.remove('random_policy0.mp4')
self.delete_recent_file(filename_pattern="random_policy*.mp4")

def test_paths_plotting(self):
# Call your function and test its output/assertions
print("Testing plotting paths")
runner = click.testing.CliRunner()
result = runner.invoke(examine_env, ["--env_name", "door-v1", \
"--num_episodes", 1, \
"--render", "none",\
"--plot_paths", True])
print("OUTPUT", result.output.strip(), flush=True)
print("RESULT", result, flush=True)
print("EXCEPTION", result.exception, flush=True)
# print(result.output.strip())
self.assertEqual(result.exception, None, result.exception)
self.delete_recent_file(filename_pattern="random_policy*Trial*.pdf")

def no_test_scripted_policy_loading(self):
# Call your function and test its output/assertions
Expand Down
64 changes: 54 additions & 10 deletions robohive/utils/paths_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,20 @@
from robohive.utils.dict_utils import flatten_dict, dict_numpify
import json

#TODO: Harmonize names, remove rollout_paths, use path for one and paths for multiple

# Useful to check the horizon for teleOp / Hardware experiments
# Check the horizon for teleOp / Hardware experiments
def plot_horizon(paths, env, fileName_prefix=None):
"""
Check the horizon for teleOp / Hardware experiments
Args:
paths: paths to examine
env: unwrapped env
fileName_prefix (str): prefix to use in the filename
Saves:
fileName_prefix + '_horizon.pdf'
"""
import matplotlib as mpl
mpl.use('TkAgg')
import matplotlib.pyplot as plt
Expand All @@ -30,7 +41,7 @@ def plot_horizon(paths, env, fileName_prefix=None):
# plot timesteps
plt.clf()

rl_dt_ideal = env.env.frame_skip * env.env.model.opt.timestep
rl_dt_ideal = env.frame_skip * env.model.opt.timestep
for i, path in enumerate(paths):
dt = path['env_infos']['time'][1:] - path['env_infos']['time'][:-1]
horizon[i] = path['env_infos']['time'][-1] - path['env_infos'][
Expand Down Expand Up @@ -75,8 +86,19 @@ def plot_horizon(paths, env, fileName_prefix=None):
print("Saved:", file_name)


# Plot paths to a pdf file
# 2D-plot of paths detailing obs, act, rwds across time
def plot(paths, env=None, fileName_prefix=''):
"""
2D-plot of paths detailing obs, act, rwds across time
Args:
paths: paths to examine
env: unwrapped env
fileName_prefix: prefix to use in the filename
Saves:
fileName_prefix + path_name + '.pdf'
"""
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -108,7 +130,7 @@ def plot(paths, env=None, fileName_prefix=''):
nplt2 = 3
ax = plt.subplot(nplt2, 2, 2)
ax.set_prop_cycle(None)
# h4 = plt.plot(path['env_infos']['time'], env.env.act_mid + path['actions']*env.env.act_rng, '-', label='act') # plot scaled actions
# h4 = plt.plot(path['env_infos']['time'], env.act_mid + path['actions']*env.act_rng, '-', label='act') # plot scaled actions
h4 = plt.plot(
path['env_infos']['time'], path['actions'], '-',
label='act') # plot normalized actions
Expand Down Expand Up @@ -143,13 +165,14 @@ def plot(paths, env=None, fileName_prefix=''):
ax.axes.xaxis.set_ticklabels([])
plt.ylabel('rewards')
ax.yaxis.tick_right()
if env and hasattr(env.env, "rwd_keys_wt"):

if env and hasattr(env, "rwd_keys_wt"):
ax = plt.subplot(nplt2, 2, 6)
ax.set_prop_cycle(None)
for key in sorted(env.env.rwd_keys_wt.keys()):
for key in sorted(env.rwd_keys_wt.keys()):
plt.plot(
path['env_infos']['time'],
path['env_infos']['rwd_dict'][key]*env.env.rwd_keys_wt[key],
path['env_infos']['rwd_dict'][key]*env.rwd_keys_wt[key],
label=key)
plt.legend(
loc='upper left',
Expand All @@ -167,9 +190,30 @@ def plot(paths, env=None, fileName_prefix=''):

# Render frames/videos
def render(rollout_path, render_format:str="mp4", cam_names:list=["left"]):
# rollout_path: Absolute path of the rollout (h5/pickle)', default=None
# format: Format to save. Choice['rgb', 'mp4']
# cam: list of cameras to render. Example ['left', 'right', 'top', 'Franka_wrist']
"""
Render the frames from a given rollout.
Parameters:
rollout_path (str): Absolute path of the rollout (h5/pickle).
render_format (str, optional): Format to save the rendered frames. Default is "mp4".
cam_names (list, optional): List of cameras to render. Default is ["left"]. Example ['left', 'right', 'top', 'Franka_wrist']
Returns:
None
Raises:
TypeError: If the path format is unknown.
Notes:
- The frames are saved in the specified render format.
- The rendered frames can be saved as an mp4 video or as individual RGB images.
- The frames are rendered for each camera specified in the cam_names list.
- The frames are saved in the same directory as the rollout path.
- The output file names are generated based on the rollout name and the camera names.
Example:
render(rollout_path="/path/to/rollout.h5", render_format="mp4", cam_names=["left", "right"])
"""

output_dir = os.path.dirname(rollout_path)
rollout_name = os.path.split(rollout_path)[-1]
Expand Down

0 comments on commit 76b3455

Please sign in to comment.