diff --git a/robohive/tests/test_examine_env.py b/robohive/tests/test_examine_env.py index dd7ded87..d4e531ca 100644 --- a/robohive/tests/test_examine_env.py +++ b/robohive/tests/test_examine_env.py @@ -3,9 +3,39 @@ import unittest from robohive.utils.examine_env import main as examine_env import os +import glob +import time class TestExamineEnv(unittest.TestCase): + + def delete_recent_file(self, filename_pattern, directory='.', age=5): + + # Get the current time + current_time = time.time() + + # Use glob to find files matching the pattern in the specified directory + matching_files = glob.glob(os.path.join(directory, filename_pattern)) + + # Iterate over the matching files + for file_path in matching_files: + try: + # Get the creation time of the file + creation_time = os.path.getctime(file_path) + + # Calculate the time difference between current time and creation time + time_difference = current_time - creation_time + + # If the file was created within the last 5 seconds, delete it + if time_difference <= 5: + os.remove(file_path) + print(f"Deleted file created within {age} seconds: {file_path}") + else: + print(f"File not deleted: {file_path}, created {time_difference} seconds ago.") + except Exception as e: + print(f"Error deleting file: {file_path} - {e}") + + def test_main(self): # Call your function and test its output/assertions print("Testing env with random policy") @@ -32,7 +62,22 @@ def test_offscreen_rendering(self): print("EXCEPTION", result.exception, flush=True) # print(result.output.strip()) self.assertEqual(result.exception, None, result.exception) - os.remove('random_policy0.mp4') + self.delete_recent_file(filename_pattern="random_policy*.mp4") + + def test_paths_plotting(self): + # Call your function and test its output/assertions + print("Testing plotting paths") + runner = click.testing.CliRunner() + result = runner.invoke(examine_env, ["--env_name", "door-v1", \ + "--num_episodes", 1, \ + "--render", "none",\ + "--plot_paths", True]) + print("OUTPUT", result.output.strip(), flush=True) + print("RESULT", result, flush=True) + print("EXCEPTION", result.exception, flush=True) + # print(result.output.strip()) + self.assertEqual(result.exception, None, result.exception) + self.delete_recent_file(filename_pattern="random_policy*Trial*.pdf") def no_test_scripted_policy_loading(self): # Call your function and test its output/assertions diff --git a/robohive/utils/paths_utils.py b/robohive/utils/paths_utils.py index 7b2e5f36..32c44e7d 100644 --- a/robohive/utils/paths_utils.py +++ b/robohive/utils/paths_utils.py @@ -16,9 +16,20 @@ from robohive.utils.dict_utils import flatten_dict, dict_numpify import json +#TODO: Harmonize names, remove rollout_paths, use path for one and paths for multiple -# Useful to check the horizon for teleOp / Hardware experiments +# Check the horizon for teleOp / Hardware experiments def plot_horizon(paths, env, fileName_prefix=None): + """ + Check the horizon for teleOp / Hardware experiments + + Args: + paths: paths to examine + env: unwrapped env + fileName_prefix (str): prefix to use in the filename + Saves: + fileName_prefix + '_horizon.pdf' + """ import matplotlib as mpl mpl.use('TkAgg') import matplotlib.pyplot as plt @@ -30,7 +41,7 @@ def plot_horizon(paths, env, fileName_prefix=None): # plot timesteps plt.clf() - rl_dt_ideal = env.env.frame_skip * env.env.model.opt.timestep + rl_dt_ideal = env.frame_skip * env.model.opt.timestep for i, path in enumerate(paths): dt = path['env_infos']['time'][1:] - path['env_infos']['time'][:-1] horizon[i] = path['env_infos']['time'][-1] - path['env_infos'][ @@ -75,8 +86,19 @@ def plot_horizon(paths, env, fileName_prefix=None): print("Saved:", file_name) -# Plot paths to a pdf file +# 2D-plot of paths detailing obs, act, rwds across time def plot(paths, env=None, fileName_prefix=''): + """ + 2D-plot of paths detailing obs, act, rwds across time + + Args: + paths: paths to examine + env: unwrapped env + fileName_prefix: prefix to use in the filename + + Saves: + fileName_prefix + path_name + '.pdf' + """ import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt @@ -108,7 +130,7 @@ def plot(paths, env=None, fileName_prefix=''): nplt2 = 3 ax = plt.subplot(nplt2, 2, 2) ax.set_prop_cycle(None) - # h4 = plt.plot(path['env_infos']['time'], env.env.act_mid + path['actions']*env.env.act_rng, '-', label='act') # plot scaled actions + # h4 = plt.plot(path['env_infos']['time'], env.act_mid + path['actions']*env.act_rng, '-', label='act') # plot scaled actions h4 = plt.plot( path['env_infos']['time'], path['actions'], '-', label='act') # plot normalized actions @@ -143,13 +165,14 @@ def plot(paths, env=None, fileName_prefix=''): ax.axes.xaxis.set_ticklabels([]) plt.ylabel('rewards') ax.yaxis.tick_right() - if env and hasattr(env.env, "rwd_keys_wt"): + + if env and hasattr(env, "rwd_keys_wt"): ax = plt.subplot(nplt2, 2, 6) ax.set_prop_cycle(None) - for key in sorted(env.env.rwd_keys_wt.keys()): + for key in sorted(env.rwd_keys_wt.keys()): plt.plot( path['env_infos']['time'], - path['env_infos']['rwd_dict'][key]*env.env.rwd_keys_wt[key], + path['env_infos']['rwd_dict'][key]*env.rwd_keys_wt[key], label=key) plt.legend( loc='upper left', @@ -167,9 +190,30 @@ def plot(paths, env=None, fileName_prefix=''): # Render frames/videos def render(rollout_path, render_format:str="mp4", cam_names:list=["left"]): - # rollout_path: Absolute path of the rollout (h5/pickle)', default=None - # format: Format to save. Choice['rgb', 'mp4'] - # cam: list of cameras to render. Example ['left', 'right', 'top', 'Franka_wrist'] + """ + Render the frames from a given rollout. + + Parameters: + rollout_path (str): Absolute path of the rollout (h5/pickle). + render_format (str, optional): Format to save the rendered frames. Default is "mp4". + cam_names (list, optional): List of cameras to render. Default is ["left"]. Example ['left', 'right', 'top', 'Franka_wrist'] + + Returns: + None + + Raises: + TypeError: If the path format is unknown. + + Notes: + - The frames are saved in the specified render format. + - The rendered frames can be saved as an mp4 video or as individual RGB images. + - The frames are rendered for each camera specified in the cam_names list. + - The frames are saved in the same directory as the rollout path. + - The output file names are generated based on the rollout name and the camera names. + + Example: + render(rollout_path="/path/to/rollout.h5", render_format="mp4", cam_names=["left", "right"]) + """ output_dir = os.path.dirname(rollout_path) rollout_name = os.path.split(rollout_path)[-1]