Skip to content

Commit

Permalink
fix meta_rl evironmental setup (#76)
Browse files Browse the repository at this point in the history
* fix meta_rl evironmental setup

* Update makefile

* Update torch version

* modify torch-cpu version requirement

* Update torch version

* ..

* ..

* ..

* ..

* ..

* ..

* ..

Co-authored-by: dongminlee94 <[email protected]>
  • Loading branch information
Clyde21c and dongminlee94 committed Jun 16, 2022
1 parent 69cfdd6 commit 5f3ba5b
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 31 deletions.
16 changes: 6 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
GPU := $(shell which nvcc)
ifdef GPU
DEVICE="gpu"
else
DEVICE="cpu"
endif

STAGED := $(shell git diff --cached --name-only --diff-filter=ACMR -- 'src/***.py' | sed 's| |\\ |g')

all: format lint
Expand Down Expand Up @@ -34,10 +27,13 @@ endif

init:
pip install -U pip
pip install -U setuptools
pip install -e .
pip install -r requirements-common.txt
pip install -r requirements-$(DEVICE).txt
pip install -r requirements.txt
python3 ./scripts/download-torch.py
conda install -y tensorboard
jupyter contrib nbextension install --user
jupyter nbextensions_configurator enable --user
python3 -m ipykernel install --user
bash ./hooks/install.sh

init-dev:
Expand Down
6 changes: 0 additions & 6 deletions requirements-common.txt

This file was deleted.

3 changes: 0 additions & 3 deletions requirements-cpu.txt

This file was deleted.

3 changes: 0 additions & 3 deletions requirements-gpu.txt

This file was deleted.

9 changes: 9 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
autopep8==1.5.0
GPUtil==1.4.0
gym>=0.24.1
jupyter==1.0.0
jupyter-contrib-nbextensions==0.5.1
jupyter-nbextensions-configurator==0.4.1
mujoco>=2.2.0
torchmeta>=1.8.0
tqdm>=4.62.3
21 changes: 21 additions & 0 deletions scripts/download-torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
import os
import sys

import GPUtil


def main() -> None:
if sys.platform == "win32" or sys.platform == "linux":
if GPUtil.getAvailable():
cli = "pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html"
else:
cli = "pip install torch==1.8.1+cpu torchvision==0.9.1+cpu -f https://download.pytorch.org/whl/torch_stable.html"
elif sys.platform == "darwin":
cli = "pip install torch==1.8.1 torchvision==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html"
print(cli)
os.system(cli)


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion src/meta_rl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def register_env_fn(filename):
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("src.envs." + module)
importlib.import_module("meta_rl.envs." + module)
12 changes: 9 additions & 3 deletions src/meta_rl/envs/half_cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,25 @@
https://github.com/katerakelly/oyster/blob/master/rlkit/envs/half_cheetah.py
"""

from typing import List, Union
from typing import List, Optional, Union

import numpy as np
from gym import utils
from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_
from gym.envs.mujoco import mujoco_env


class HalfCheetahEnv(HalfCheetahEnv_):
def __init__(self):
mujoco_env.MujocoEnv.__init__(self, "half_cheetah.xml", 5)
utils.EzPickle.__init__(self)

def _get_obs(self) -> np.ndarray:
return (
np.concatenate(
[
self.sim.data.qpos.flat[1:],
self.sim.data.qvel.flat,
self.data.qpos.flat[1:],
self.data.qvel.flat,
self.get_body_com("torso").flat,
],
)
Expand Down
4 changes: 2 additions & 2 deletions src/meta_rl/envs/half_cheetah_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def __init__(self, num_tasks: int) -> None:
super().__init__()

def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, Dict[str, Any]]:
xposbefore = self.sim.data.qpos[0]
xposbefore = self.data.qpos[0]
self.do_simulation(action, self.frame_skip)
xposafter = self.sim.data.qpos[0]
xposafter = self.data.qpos[0]

progress = (xposafter - xposbefore) / self.dt
run_cost = self._goal_dir * progress
Expand Down
4 changes: 2 additions & 2 deletions src/meta_rl/envs/half_cheetah_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def __init__(self, num_tasks: int) -> None:
super().__init__()

def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, Dict[str, Any]]:
xposbefore = self.sim.data.qpos[0]
xposbefore = self.data.qpos[0]
self.do_simulation(action, self.frame_skip)
xposafter = self.sim.data.qpos[0]
xposafter = self.data.qpos[0]

progress = (xposafter - xposbefore) / self.dt
run_cost = progress - self._goal_vel
Expand Down
2 changes: 1 addition & 1 deletion src/meta_rl/rl2/rl2_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
tasks: List[int] = env.get_all_task_idx()

# Set a random seed
env.seed(experiment_config["seed"])
env.reset(seed=experiment_config["seed"])
np.random.seed(experiment_config["seed"])
torch.manual_seed(experiment_config["seed"])

Expand Down

0 comments on commit 5f3ba5b

Please sign in to comment.