Skip to content

Commit

Permalink
Rename actions_input to action_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamStelmaszczyk committed Jun 8, 2018
1 parent 47ed2a7 commit 1493849
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def one_hot_encode(n, action):


def predict(env, model, observations):
actions_input = np.ones((len(observations), env.action_space.n))
return model.predict(x=[observations, actions_input])
action_mask = np.ones((len(observations), env.action_space.n))
return model.predict(x=[observations, action_mask])


def fit_batch(env, model, target_model, batch):
Expand All @@ -47,7 +47,6 @@ def fit_batch(env, model, target_model, batch):
next_q_values[dones] = 0.0
# The Q values of each start state is the reward + gamma * the max next state Q value
q_values = rewards + DISCOUNT_FACTOR_GAMMA * np.max(next_q_values, axis=1)
# Passing the actions as the mask and multiplying the targets by the actions masks.
one_hot_actions = np.array([one_hot_encode(env.action_space.n, action) for action in actions])
history = model.fit(
x=[observations, one_hot_actions],
Expand All @@ -62,12 +61,12 @@ def create_model(env):
n_actions = env.action_space.n
obs_shape = env.observation_space.shape
observations_input = keras.layers.Input(obs_shape, name='observations_input')
actions_input = keras.layers.Input((n_actions,), name='actions_input')
action_mask = keras.layers.Input((n_actions,), name='action_mask')
hidden = keras.layers.Dense(32, activation='relu')(observations_input)
hidden_2 = keras.layers.Dense(32, activation='relu')(hidden)
output = keras.layers.Dense(n_actions)(hidden_2)
filtered_output = keras.layers.multiply([output, actions_input])
model = keras.models.Model([observations_input, actions_input], filtered_output)
filtered_output = keras.layers.multiply([output, action_mask])
model = keras.models.Model([observations_input, action_mask], filtered_output)
optimizer = keras.optimizers.Adam(lr=LEARNING_RATE, clipnorm=1.0)
model.compile(optimizer, loss='mean_squared_error')
return model
Expand Down

0 comments on commit 1493849

Please sign in to comment.