ACKTR model crashes using CnnLnLstmPolicy #387

Open
MartinBertran opened this issue Jun 25, 2019 · 8 comments
Labels: custom gym env (Issue related to Custom Gym Env)

Comments

MartinBertran commented Jun 25, 2019

Describe the bug
The ACKTR example code crashes when modified to use CnnLnLstmPolicy. This appears to be a bug in the KFAC code.

Code example

import gym
import vizdoomgym  # registers the VizDoom gym environments
from stable_baselines.common.policies import CnnLnLstmPolicy, MlpLnLstmPolicy, MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines import ACKTR

n_cpu = 4

if __name__ == "__main__":
    # 4 VizDoom environments running in parallel subprocesses
    env = SubprocVecEnv([lambda: gym.make('VizdoomCorridor-v0') for i in range(n_cpu)])
    # Crashes here, while building the graph in setup_model()
    model = ACKTR(CnnLnLstmPolicy, env, verbose=1)

    model.learn(total_timesteps=20000000)

This results in the following traceback:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1658   try:
-> 1659     c_op = c_api.TF_FinishOperation(op_desc)
   1660   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 2 but is rank 1 for 'kfac/MatMul_2' (op: 'MatMul') with input shapes: [32], [32].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-2-12e263ec93c1> in <module>
     11 
     12     env = SubprocVecEnv([lambda: gym.make('VizdoomCorridor-v0') for i in range(n_cpu)])
---> 13     model = ACKTR(CnnLnLstmPolicy, env, verbose=1)
     14 
     15     model.learn(total_timesteps=20000000)

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/acktr_disc.py in __init__(self, policy, env, gamma, nprocs, n_steps, ent_coef, vf_coef, vf_fisher_coef, learning_rate, max_grad_norm, kfac_clip, lr_schedule, verbose, tensorboard_log, _init_setup_model, async_eigen_decomp, policy_kwargs, full_tensorboard_log)
    101 
    102         if _init_setup_model:
--> 103             self.setup_model()
    104 
    105     def _get_pretrain_placeholders(self):

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/acktr_disc.py in setup_model(self)
    195 
    196                         print(self.joint_fisher)
--> 197                         optim.compute_and_apply_stats(self.joint_fisher, var_list=params)
    198 
    199                 self.train_model = train_model

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/kfac.py in compute_and_apply_stats(self, loss_sampled, var_list)
    332             varlist = tf.trainable_variables()
    333 
--> 334         stats = self.compute_stats(loss_sampled, var_list=varlist)
    335         return self.apply_stats(stats)
    336 

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/kfac.py in compute_stats(self, loss_sampled, var_list)
    475 
    476                     cov_b = tf.matmul(bprop_factor, bprop_factor,
--> 477                                       transpose_a=True) / tf.cast(tf.shape(bprop_factor)[0], tf.float32)
    478 
    479                     update_ops.append(cov_b)

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py in matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, name)
   2453     else:
   2454       return gen_math_ops.mat_mul(
-> 2455           a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
   2456 
   2457 

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py in mat_mul(a, b, transpose_a, transpose_b, name)
   5331   _, _, _op = _op_def_lib._apply_op_helper(
   5332         "MatMul", a=a, b=b, transpose_a=transpose_a, transpose_b=transpose_b,
-> 5333                   name=name)
   5334   _result = _op.outputs[:]
   5335   _inputs_flat = _op.inputs

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    786         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    787                          input_types=input_types, attrs=attr_protos,
--> 788                          op_def=op_def)
    789       return output_structure, op_def.is_stateful, op
    790 

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in create_op(***failed resolving arguments***)
   3298           input_types=input_types,
   3299           original_op=self._default_original_op,
-> 3300           op_def=op_def)
   3301       self._create_op_helper(ret, compute_device=compute_device)
   3302     return ret

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1821           op_def, inputs, node_def.attr)
   1822       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1823                                 control_input_ops)
   1824 
   1825     # Initialize self._outputs.

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1660   except errors.InvalidArgumentError as e:
   1661     # Convert to ValueError for backwards compatibility.
-> 1662     raise ValueError(str(e))
   1663 
   1664   return c_op

ValueError: Shape must be rank 2 but is rank 1 for 'kfac/MatMul_2' (op: 'MatMul') with input shapes: [32], [32].

System Info
Describe the characteristics of your environment:

  • Library installed via: pip install from a cloned repo
  • GPU models and configuration: GeForce GTX 1080
  • Python version: 3.6.8
  • Tensorflow version: 1.13.1
  • Versions of any other relevant libraries

Additional context
The KFAC code seems to expect bprop_factor to be a (batch, channels) tensor, but here it is a rank-1 (batch,) tensor. The failure stems from optim.compute_and_apply_stats(self.joint_fisher, var_list=params), where joint_fisher is a (838980, 32) tensor.
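For anyone digging into kfac.py, here is a minimal standalone sketch of the failing op plus one possible guard. Whether promoting a rank-1 bprop_factor to a [batch, 1] tensor is the right semantics for the Fisher block (rather than fixing the shape upstream) is an assumption I have not verified:

import tensorflow as tf

# Hypothetical illustration only, not the actual stable-baselines code.
# kfac.py expects bprop_factor to be [batch, channels]; with CnnLnLstmPolicy it
# arrives here as a rank-1 [batch] tensor (shape [32] in the traceback above).
bprop_factor = tf.placeholder(tf.float32, shape=[32])

# Possible guard: promote rank-1 factors to [batch, 1] before the covariance
if len(bprop_factor.get_shape()) == 1:
    bprop_factor = tf.expand_dims(bprop_factor, axis=-1)

# Same covariance expression as in kfac.py (around line 477); valid once rank 2
cov_b = tf.matmul(bprop_factor, bprop_factor,
                  transpose_a=True) / tf.cast(tf.shape(bprop_factor)[0], tf.float32)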

araffin (Collaborator) commented Jun 25, 2019

Hello,

It seems that it may be related to your custom environment.
The following code works on my machine:

from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines import ACKTR

env = make_atari_env("BreakoutNoFrameskip-v4", num_env=2, seed=1)
# Reduce number of steps to avoid memory issue
model = ACKTR("CnnLnLstmPolicy", env, n_steps=4, verbose=1)
model.learn(1000)

  • Master version of stable baselines (2.6.1a0)
  • TF 1.8.0 (CPU)
  • Python 3.6

araffin added the custom gym env label on Jun 25, 2019
MartinBertran (Author) commented:

That code snippet does not work for me:

Process ForkProcess-1:
Traceback (most recent call last):
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/vec_env/subproc_vec_env.py", line 13, in _worker
env = env_fn_wrapper.var()
File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/cmd_util.py", line 38, in _thunk
env = make_atari(env_id)
File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/atari_wrappers.py", line 284, in make_atari
env = gym.make(env_id)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 156, in make
return registry.make(id, **kwargs)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 101, in make
env = spec.make(**kwargs)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 73, in make
env = cls(**_kwargs)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/atari/atari_env.py", line 69, in init
self.seed()
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/atari/atari_env.py", line 93, in seed
modes = self.ale.getAvailableModes()
AttributeError: 'ALEInterface' object has no attribute 'getAvailableModes'
Process ForkProcess-2:
Traceback (most recent call last):
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/vec_env/subproc_vec_env.py", line 13, in _worker
env = env_fn_wrapper.var()
File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/cmd_util.py", line 38, in _thunk
env = make_atari(env_id)
File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/atari_wrappers.py", line 284, in make_atari
env = gym.make(env_id)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 156, in make
return registry.make(id, **kwargs)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 101, in make
env = spec.make(**kwargs)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 73, in make
env = cls(**_kwargs)
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/atari/atari_env.py", line 69, in init
self.seed()
File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/atari/atari_env.py", line 93, in seed
modes = self.ale.getAvailableModes()
AttributeError: 'ALEInterface' object has no attribute 'getAvailableModes'

ConnectionResetError Traceback (most recent call last)
in
19 from stable_baselines import ACKTR
20
---> 21 env = make_atari_env("BreakoutNoFrameskip-v4", num_env=2, seed=1)
22 # Reduce number of steps to avoid memory issue
23 model = ACKTR("CnnLnLstmPolicy", env, n_steps=4, verbose=1)

~/ReinforcementLearning/stable-baselines/stable_baselines/common/cmd_util.py in make_atari_env(env_id, num_env, seed, wrapper_kwargs, start_index, allow_early_resets, start_method)
49
50 return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)],
---> 51 start_method=start_method)
52
53

~/ReinforcementLearning/stable-baselines/stable_baselines/common/vec_env/subproc_vec_env.py in __init__(self, env_fns, start_method)
91
92 self.remotes[0].send(('get_spaces', None))
---> 93 observation_space, action_space = self.remotes[0].recv()
94 VecEnv.__init__(self, len(env_fns), observation_space, action_space)
95

~/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/connection.py in recv(self)
248 self._check_closed()
249 self._check_readable()
--> 250 buf = self._recv_bytes()
251 return _ForkingPickler.loads(buf.getbuffer())
252

~/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/connection.py in _recv_bytes(self, maxsize)
405
406 def _recv_bytes(self, maxsize=None):
--> 407 buf = self._recv(4)
408 size, = struct.unpack("!i", buf.getvalue())
409 if maxsize is not None and size > maxsize:

~/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/connection.py in _recv(self, size, read)
377 remaining = size
378 while remaining > 0:
--> 379 chunk = read(handle, remaining)
380 n = len(chunk)
381 if n == 0:

ConnectionResetError: [Errno 104] Connection reset by peer

The same happens for other Atari environments when using ACKTR with CNN LSTM policies, for example:

env = SubprocVecEnv([lambda: gym.make('Breakout-v0') for i in range(n_cpu)])
model = ACKTR(CnnLnLstmPolicy, env, verbose=False, tensorboard_log="./test/")

But it works fine with MlpLnLstmPolicy:

if __name__=="__main__":
    env = SubprocVecEnv([lambda: gym.make('CartPole-v0') for i in range(n_cpu)])
    model = ACKTR(MlpLnLstmPolicy, env, verbose=False, tensorboard_log="./test/")

This seems to be an ACKTR-specific issue for me; PPO2 works for all the listed examples.
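For reference, this is the kind of PPO2 counterpart I am referring to: the failing ACKTR snippet from above with only the algorithm swapped (the reduced timestep count here is arbitrary, just for a quick check):

import gym
import vizdoomgym
from stable_baselines.common.policies import CnnLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import PPO2

n_cpu = 4

if __name__ == "__main__":
    env = SubprocVecEnv([lambda: gym.make('VizdoomCorridor-v0') for i in range(n_cpu)])
    # Identical setup to the failing ACKTR example, only the algorithm differs
    model = PPO2(CnnLnLstmPolicy, env, verbose=1)
    model.learn(total_timesteps=10000)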

araffin (Collaborator) commented Jun 25, 2019

What is your gym version? (and associated packages, like atari-py)

MartinBertran (Author) commented:

These are all the ones I could think of:
stable_baselines.__version__ = '2.6.1a0'
atari-py==0.1.15
gym==0.13.0
tensorboard==1.14.0
tensorflow==1.13.1
tensorflow-estimator==1.14.0
tensorflow-gpu==1.14.0
vizdoom==1.1.7

araffin (Collaborator) commented Oct 3, 2019

The error seems related to the TensorFlow version (I could reproduce the bug in Google Colab).
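A quick way to confirm which TensorFlow build is actually active in a given environment (purely a sanity check, nothing stable-baselines specific):

import tensorflow as tf

# The crash above was reported on TF 1.13/1.14; the working version is given below
print(tf.__version__)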

ChengYen-Tang commented:

@araffin
I also got this error, may I ask which version of tensorflow you are using?

araffin (Collaborator) commented Feb 1, 2020

I also got this error, may I ask which version of tensorflow you are using?

tensorflow==1.8.0

So the 1.8.0 CPU version.

ChengYen-Tang commented:

Thank you.
