Skip to content

Commit

Permalink
Fixed Vanilla agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Rick committed Dec 4, 2018
1 parent 425f9a1 commit b6cdcf7
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions agents/Vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
n_actions = 6

class Vanilla(Agent):
def __init__(self,conf,sims,tau=None,gamma=0.99):
def __init__(self, conf, sims, tau=None, env=None, env_args=None):

super().__init__(sims=sims,backend=None)
super().__init__(sims=sims, backend=None, env=env, env_args=env_args)

self.g_tmp = env(*env_args)

def mcts(self,root_index):
trace = select_index(root_index,self.arrs['child'],self.arrs['node_stats'])
Expand All @@ -20,12 +22,12 @@ def mcts(self,root_index):

value = leaf_game.getScore() #- self.game_arr[root_index].getScore()


if not leaf_game.end:

#v, p = self.evaluate_state(leaf_game.getState())

_g = leaf_game.clone()
_g = self.g_tmp
_g.copy_from(leaf_game)

while not _g.end:
_act = randint(0,n_actions-1)
Expand All @@ -34,7 +36,7 @@ def mcts(self,root_index):
value = _g.getScore()

for i in range(n_actions):
_g = leaf_game.clone()
_g.copy_from(leaf_game)
_g.play(i)
_n = self.new_node(_g)

Expand Down

0 comments on commit b6cdcf7

Please sign in to comment.