Skip to content

Commit

Permalink
Add target normalization, rename sarsa to td (more appropriate?)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rick committed Dec 19, 2018
1 parent c77347d commit 771f864
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
parser.add_argument('--new', default=False, help='Create a new model instead of training the old one', action='store_true')
parser.add_argument('--val_episodes', default=0, type=float, help='Number of validation episodes (default:0)')
parser.add_argument('--val_total', default=25, type=int, help='Total number of validations (default:25)')
parser.add_argument('--sarsa', default=False, help='Sarsa update', action='store_true')
parser.add_argument('--save_loss', default=False, help='Save loss history', action='store_true')
parser.add_argument('--save_interval', default=100, type=int, help='Number of iterations between save_loss')
parser.add_argument('--shuffle', default=False, help='Shuffle dataset', action='store_true')
parser.add_argument('--target_normalization', default=False, help='Standardizes the targets', action='store_true')
parser.add_argument('--td', default=False, help='Temporal difference update', action='store_true')
args = parser.parse_args()

backend = args.backend
Expand All @@ -42,10 +43,11 @@
new = args.new
val_episodes = args.val_episodes
val_total = args.val_total
sarsa = args.sarsa
save_loss = args.save_loss
save_interval = args.save_interval
shuffle = args.shuffle
target_normalization = args.target_normalization
td = args.td

#========================
"""
Expand All @@ -64,7 +66,7 @@

loader = DataLoader(list_of_data)

if sarsa:
if td:
if eligibility_trace:
_child_stats = loader.child_stats
n = _child_stats[:,0]
Expand All @@ -88,7 +90,6 @@
else:
values = loader.value
variance = loader.variance
print(values.shape, variance.shape)
else:
values = np.zeros((len(loader.score), ), dtype=np.float32)
while idx < len(loader.episode):
Expand Down Expand Up @@ -134,6 +135,20 @@

batch_train = [states[t_idx], values[t_idx], variance[t_idx], policy[t_idx]]

#=========================
"""
TARGET NORMALIZATION
"""
if target_normalization:
v_mean = batch_train[1].mean()
v_std = batch_train[1].std()
var_mean = batch_train[2].mean()
var_std = batch_train[2].std()
batch_train[1] = (batch_train[1] - v_mean) / v_std
batch_train[2] = (batch_train[2] - var_mean) / var_std
batch_val[1] = (batch_val[1] - v_mean) / v_std
batch_val[2] = (batch_val[2] - var_mean) / var_std

#=========================
"""
MODEL SETUP
Expand All @@ -147,6 +162,10 @@
train_step = lambda batch, step: m.train(batch)
compute_loss = lambda batch: m.compute_loss(batch)
scheduler_step = lambda val_loss: m.update_scheduler(val_loss)
m.v_mean = v_mean
m.v_std = v_std
m.var_mean = var_mean
m.var_std = var_std
elif backend == 'tensorflow':
from model.model import Model
import tensorflow as tf
Expand Down

0 comments on commit 771f864

Please sign in to comment.