Hello, my issue concerns the use of the net_arch parameter inside LstmPolicy. Supporting it would make it possible to implement a custom CnnLstmPolicy.
Currently, LstmPolicy from stable_baselines.common.policies contains the following code, which raises NotImplementedError() when net_arch is not None and feature_extraction is "cnn":
class LstmPolicy(RecurrentActorCriticPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
                 net_arch=None, act_fun=tf.tanh, cnn_extractor=nature_cnn, layer_norm=False, feature_extraction="cnn",
                 **kwargs):
        # state_shape = [n_lstm * 2] dim because of the cell and hidden states of the LSTM
        super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                                         state_shape=(2 * n_lstm, ), reuse=reuse,
                                         scale=(feature_extraction == "cnn"))
        self._kwargs_check(feature_extraction, kwargs)
        if net_arch is None:  # Legacy mode
            if layers is None:
                layers = [64, 64]
            else:
                warnings.warn("The layers parameter is deprecated. Use the net_arch parameter instead.")
            with tf.variable_scope("model", reuse=reuse):
                if feature_extraction == "cnn":
                    extracted_features = cnn_extractor(self.processed_obs, **kwargs)
                else:
                    extracted_features = tf.layers.flatten(self.processed_obs)
                    for i, layer_size in enumerate(layers):
                        extracted_features = act_fun(linear(extracted_features, 'pi_fc' + str(i), n_hidden=layer_size,
                                                            init_scale=np.sqrt(2)))
                input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
                masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
                rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
                                             layer_norm=layer_norm)
                rnn_output = seq_to_batch(rnn_output)
                value_fn = linear(rnn_output, 'vf', 1)
                self._proba_distribution, self._policy, self.q_value = \
                    self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output)
            self._value_fn = value_fn
        else:  # Use the new net_arch parameter
            if layers is not None:
                warnings.warn("The new net_arch parameter overrides the deprecated layers parameter.")
            if feature_extraction == "cnn":
                raise NotImplementedError()
            ....
So the solution is simple: handle it by analogy with the net_arch=None case. When feature_extraction is "cnn", use the provided cnn_extractor to turn the input images into a flat feature tensor, which then goes through the net_arch layers.
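To make the proposal concrete, here is a TensorFlow-free sketch of the control flow the net_arch branch could follow. It is a model, not the library code: `fake_cnn_extractor`, `fake_flatten` and `build_lstm_policy_trace` are hypothetical names, and each "layer" only appends a label to a trace so the construction order can be checked. It assumes the net_arch format used for LstmPolicy (integers for shared layers, a single `'lstm'` entry, and a final dict with `pi` and `vf` layer sizes, e.g. `[8, 'lstm', dict(vf=[5, 10], pi=[10])]`).

```python
from typing import Callable, Dict, List, Union

# Assumed net_arch entry types: int (shared layer), "lstm", or {"pi": [...], "vf": [...]}.
Layer = Union[int, str, Dict[str, List[int]]]


def fake_cnn_extractor(trace: List[str]) -> List[str]:
    """Hypothetical stand-in for nature_cnn: records that the CNN ran."""
    return trace + ["cnn"]


def fake_flatten(trace: List[str]) -> List[str]:
    """Hypothetical stand-in for tf.layers.flatten."""
    return trace + ["flatten"]


def build_lstm_policy_trace(
    net_arch: List[Layer],
    feature_extraction: str = "cnn",
    cnn_extractor: Callable[[List[str]], List[str]] = fake_cnn_extractor,
) -> Dict[str, List[str]]:
    # Proposed fix: run the CNN extractor for "cnn" input instead of raising
    # NotImplementedError, mirroring the legacy (net_arch is None) branch.
    if feature_extraction == "cnn":
        latent = cnn_extractor(["obs"])
    else:
        latent = fake_flatten(["obs"])

    lstm_built = False
    pi_sizes: List[int] = []
    vf_sizes: List[int] = []
    for idx, layer in enumerate(net_arch):
        if isinstance(layer, int):
            # Shared fully connected layer, applied before the pi/vf split.
            latent = latent + [f"shared_fc{idx}({layer})"]
        elif layer == "lstm":
            if lstm_built:
                raise ValueError("net_arch may contain only one 'lstm' entry")
            latent = latent + ["lstm"]
            lstm_built = True
        elif isinstance(layer, dict):
            # Policy/value split; treated as the last entry in this sketch.
            pi_sizes = layer.get("pi", [])
            vf_sizes = layer.get("vf", [])
            break
        else:
            raise ValueError(f"unsupported net_arch entry: {layer!r}")

    if not lstm_built:
        raise ValueError("net_arch must contain an 'lstm' entry")

    # Separate heads built on top of the shared (CNN + LSTM) latent trace.
    pi = latent + [f"pi_fc{i}({s})" for i, s in enumerate(pi_sizes)]
    vf = latent + [f"vf_fc{i}({s})" for i, s in enumerate(vf_sizes)]
    return {"pi": pi, "vf": vf}


trace = build_lstm_policy_trace([8, "lstm", {"vf": [5, 10], "pi": [10]}])
print(trace["pi"])  # ['obs', 'cnn', 'shared_fc0(8)', 'lstm', 'pi_fc0(10)']
```

In the real policy, under this proposal, the only change would be that first if/else: call `cnn_extractor(self.processed_obs, **kwargs)` for "cnn" and `tf.layers.flatten(self.processed_obs)` otherwise, then feed the result into the existing net_arch loop unchanged.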