rl.py
from pylab import *
from stuff import *
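
# Least-Squares Policy Iteration (LSPI): alternate between evaluating the
# current greedy policy with LSTD-Q (lstdq) and improving it greedily
# (greedy_policy), until the weight vector omega stops changing.
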
GAMMA = 0.9   # Discount factor
LAMBDA = 0    # .1  # Regularization coefficient for LSTDQ
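
# greedy_policy: given a weight vector omega and feature map phi, return a
# vectorized policy that picks, for each state, the action in A maximizing
# the linear Q-value estimate dot(omega.T, phi(s, a)).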
def greedy_policy(omega, phi, A, s_dim=2):
    def policy(*args):
        state_actions = [hstack(args + (a,)) for a in A]
        q_value = lambda sa: float(dot(omega.transpose(), phi(sa)))
        best_action = argmax(state_actions, q_value)[-1]  # FIXME6: does not work for multi-dimensional actions
        return best_action
    vpolicy = non_scalar_vectorize(policy, (s_dim,), (1, 1))
    return lambda state: vpolicy(state).reshape(state.shape[:-1] + (1,))
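
# lstdq: least-squares temporal-difference learning for the Q-function.
# Accumulates A = sum_t phi(s_t, a_t) . (phi(s_t, a_t) - GAMMA * phi(s'_t, a'_t))^T
# and b = sum_t phi(s_t, a_t) * r_t, then returns the solution of
# (A + LAMBDA * I) omega = b.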
def lstdq(phi_sa, phi_sa_dash, rewards, phi_dim=1):
    # print("shapes of phi_sa, phi_sa_dash, rewards: "
    #       + str(phi_sa.shape) + str(phi_sa_dash.shape) + str(rewards.shape))
    A = zeros((phi_dim, phi_dim))
    b = zeros((phi_dim, 1))
    for phi_t, phi_t_dash, reward in zip(phi_sa, phi_sa_dash, rewards):
        A = A + dot(phi_t, (phi_t - GAMMA * phi_t_dash).transpose())
        b = b + phi_t * reward
    return dot(inv(A + LAMBDA * identity(phi_dim)), b)
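
# lspi: the policy-iteration loop. Each row of `data` is a sampled transition
# laid out as [s (s_dim columns), a (a_dim columns), s' (s_dim columns), r],
# matching the slicing below; phi maps a state-action vector to a
# (phi_dim, 1) feature column.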
def lspi(data, s_dim=1, a_dim=1, A=[0], phi=None, phi_dim=1, epsilon=0.01,
         iterations_max=30, plot_func=None):
    nb_iterations = 0
    sa = data[:, 0:s_dim + a_dim]
    phi_sa = phi(sa)
    s_dash = data[:, s_dim + a_dim:s_dim + a_dim + s_dim]
    rewards = data[:, s_dim + a_dim + s_dim]
    omega = zeros((phi_dim, 1))
    # omega = genfromtxt("../Code/InvertedPendulum/omega_E.mat").reshape(30, 1)
    diff = float("inf")
    cont = True
    policy = greedy_policy(omega, phi, A, s_dim=s_dim)
    while cont:
        if plot_func:
            plot_func(omega)
        sa_dash = hstack([s_dash, policy(s_dash)])
        phi_sa_dash = phi(sa_dash)
        omega_dash = lstdq(phi_sa, phi_sa_dash, rewards, phi_dim=phi_dim)
        diff = norm(omega_dash - omega)
        omega = omega_dash
        policy = greedy_policy(omega, phi, A, s_dim=s_dim)
        nb_iterations += 1
        print("LSPI, iter: " + str(nb_iterations) + ", diff: " + str(diff))
        if nb_iterations > iterations_max or diff < epsilon:
            cont = False
    sa_dash = hstack([s_dash, policy(s_dash)])
    phi_sa_dash = phi(sa_dash)
    omega = lstdq(phi_sa, phi_sa_dash, rewards, phi_dim=phi_dim)  # omega is the Q-value of pi, but pi is not the greedy policy w.r.t. omega
    return policy, omega
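
# argmax over an explicit set of candidates: returns the element of `items`
# that maximizes func. Note this shadows the argmax imported from pylab,
# which operates on array indices rather than arbitrary iterables.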
def argmax(items, func):
    return max(zip(items, map(func, items)), key=lambda x: x[1])[0]
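
# Usage sketch, assuming a feature map `phi` built with non_scalar_vectorize
# (so it accepts both single state-action vectors and batches); the action
# set and PHI_DIM below are placeholders:
#   policy, omega = lspi(data, s_dim=1, a_dim=1, A=[-1, 0, 1],
#                        phi=phi, phi_dim=PHI_DIM, epsilon=0.01)
#   actions = policy(states)  # states: array of shape (N, s_dim)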