
Working model
Rick committed Dec 4, 2018
1 parent 386b485 commit 425f9a1
Showing 2 changed files with 38 additions and 33 deletions.
9 changes: 4 additions & 5 deletions agents/core.py
@@ -40,7 +40,7 @@ def select_index(index,child,node_stats):

has_unvisited_node = False

_stats = np.zeros((2,len_c))
_stats = np.zeros((2, len_c), dtype=np.float32)

_max = 1.0

@@ -201,8 +201,8 @@ def _tmp_func(stats, act, node_stats, childs):
stats[1][act] += childs[i][1] * node[1] / node[0]
stats[2][act] += childs[i][1] * node[4] * np.sqrt( 1 / node[0] )
q_max = max(q_max, node[4])
stats[1][act] /= stats[0][act]
stats[2][act] /= stats[0][act]
stats[1][act] /= (stats[0][act]+eps)
stats[2][act] /= (stats[0][act]+eps)
stats[3][act] = len(childs)

return q_max
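Adding eps to the denominator guards the per-action normalization against a zero accumulated weight; eps itself is not defined in this hunk and is presumably a small module-level constant in agents/core.py. A rough sketch of the effect, assuming eps = 1e-8:

    import numpy as np

    eps = 1e-8                    # assumed value; the real constant is defined elsewhere in agents/core.py
    visits = np.float32(0.0)      # plays the role of stats[0][act] when nothing has been visited yet
    value_sum = np.float32(0.0)   # plays the role of stats[1][act]

    # Old code: value_sum / visits          -> nan (0/0) plus a RuntimeWarning
    # New code: value_sum / (visits + eps)  -> 0.0, a neutral score for the action
    print(value_sum / (visits + eps))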
@@ -232,8 +232,7 @@ def select_index_2(game, node_dict, node_stats, child_info):
_stats_tmp = np.zeros((4, n_actions), dtype=np.float32)

_max = max([_tmp_func(_stats_tmp, i, node_stats, child_info[idx][i]) for i in range(n_actions)])
_a = _tmp_select(_stats_tmp, _max)

_a = _tmp_select(_stats_tmp, _max)
action.append(_a)

game.play(_a)
62 changes: 34 additions & 28 deletions model/model_pytorch.py
@@ -11,57 +11,61 @@

EXP_PATH = './pytorch_model/'

def convOutShape(shape_in,kernel_size,stride):
return ((shape_in[0] - kernel_size) // stride + 1, (shape_in[1] - kernel_size) // stride + 1 )

class Net(nn.Module):

def __init__(self):
super(Net,self).__init__()

kernel_size = 3
stride = 1
filters = 16
self.conv1 = nn.Conv2d(1,filters,kernel_size)
self.conv2 = nn.Conv2d(filters,filters,kernel_size)
self.conv1 = nn.Conv2d(1, filters, kernel_size, stride)
self.bn1 = nn.BatchNorm2d(filters)
_shape = convOutShape((22,10), kernel_size, stride)
self.conv2 = nn.Conv2d(filters, filters, kernel_size, stride)
self.bn2 = nn.BatchNorm2d(filters)
_shape = convOutShape(_shape, kernel_size, stride)

n_convs = 2

self.flat_in = ( IMG_H - ( kernel_size - 1 ) * n_convs ) * ( IMG_W - ( kernel_size - 1 ) * n_convs ) * filters

self.flat_in = _shape[0] * _shape[1] * filters
self.flat_out = 64
self.fc1 = nn.Linear(self.flat_in,self.flat_out)


self.fc_p = nn.Linear(self.flat_out,6)
self.fc_v = nn.Linear(self.flat_out,1)
"""
self.fc_p = nn.Linear(self.flat_in,6)
self.fc_v = nn.Linear(self.flat_in,1)
"""
self.fc1 = nn.Linear(self.flat_in, self.flat_out)

self.fc_p = nn.Linear(self.flat_out, 6)
self.fc_v = nn.Linear(self.flat_out, 1)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = self.bn1(F.relu(self.conv1(x)))
x = self.bn2(F.relu(self.conv2(x)))
x = x.view(-1,self.flat_in)
x = F.relu(self.fc1(x))

policy = F.softmax(self.fc_p(x),dim=1)
policy = F.softmax(self.fc_p(x), dim=1)
#value = F.relu(self.fc_v(x))
value = torch.exp(self.fc_v(x))

return value, policy

def convert(x):
return Variable(torch.from_numpy(x.astype(np.float32)))
return torch.from_numpy(x.astype(np.float32))

class Model:
def __init__(self,new=True):

self.model = Net()
"""
self.optimizer = optim.SGD(self.model.parameters(),lr=1.0,momentum=0.99)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,patience=10,verbose=True)
"""
self.optimizer = optim.Adam(self.model.parameters(),amsgrad=True)

#self.optimizer = optim.SGD(self.model.parameters(), lr=0.05, momentum=0.9, nesterov=True)
#self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,patience=10,verbose=True)

#self.optimizer = optim.Adam(self.model.parameters(), eps=1e-16, amsgrad=True)
self.optimizer = optim.Adam(self.model.parameters())
self.scheduler = None

#print(self.model)
#print(self.optimizer)
#if self.scheduler:
# print(self.scheduler)
def _loss(self,batch):

state = convert(batch[0])
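Besides inserting batch norm after each ReLU-activated convolution, this hunk replaces the old IMG_H/IMG_W arithmetic: the new convOutShape helper returns the spatial size of an unpadded convolution, and flat_in is derived from it. A quick check of the sizes implied by __init__ (22x10 input, 3x3 kernels, stride 1, 16 filters):

    def convOutShape(shape_in, kernel_size, stride):
        # Valid (no-padding) convolution output size, as defined in model/model_pytorch.py
        return ((shape_in[0] - kernel_size) // stride + 1,
                (shape_in[1] - kernel_size) // stride + 1)

    shape = convOutShape((22, 10), 3, 1)   # after conv1: (20, 8)
    shape = convOutShape(shape, 3, 1)      # after conv2: (18, 6)
    flat_in = shape[0] * shape[1] * 16     # 18 * 6 * 16 = 1728 inputs to fc1
    print(shape, flat_in)                  # (18, 6) 1728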
@@ -105,8 +109,9 @@ def inference(self,batch):
self.model.eval()

state = convert(batch)

output = self.model(state)

with torch.no_grad():
output = self.model(state)

return output[0].data.numpy(), output[1].data.numpy()
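Wrapping the forward pass in torch.no_grad() stops autograd from recording a graph during inference, which lowers memory use and speeds up the call; the outputs are unchanged except that they carry no gradient history. The same pattern in isolation (a stand-in module, not the Net above):

    import torch

    model = torch.nn.Linear(4, 2)   # stand-in module; the real code uses Net()
    model.eval()                    # as in inference(): switch training-mode layers off

    x = torch.randn(1, 4)
    with torch.no_grad():           # no graph is recorded inside this block
        out = model(x)
    print(out.requires_grad)        # False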

@@ -144,6 +149,7 @@ def load(self):
checkpoint = torch.load(filename)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

else:
sys.stdout.write('Checkpoint not found, using default model\n')
sys.stdout.flush()
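load() expects the checkpoint to be a dict holding 'model_state_dict' and 'optimizer_state_dict'; the matching save path is outside this diff, but it presumably writes the same keys, along the lines of:

    import torch

    # Hypothetical save counterpart to load(); the real method lives elsewhere in model_pytorch.py.
    def save_checkpoint(model, optimizer, filename):
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, filename)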
