Skip to content

Commit

Permalink
Checking dtype of data and label passed to update method
Browse files Browse the repository at this point in the history
Added missing label parameter to the train global function
  • Loading branch information
viktor-ferenczi committed May 28, 2015
1 parent 3b015c2 commit 6802688
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion wrapper/cxxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ def update(self, data, label = None):
raise Exception('Net.update: label need to be 2 dimension or one dimension ndarray')
if label.shape[0] != data.shape[0]:
raise Exception('Net.update: data size mismatch')
if data.dtype != numpy.float32:
raise Exception('Net.update: data must be of type numpy.float32')
if label.dtype != numpy.float32:
raise Exception('Net.update: label must be of type numpy.float32')
cxnlib.CXNNetUpdateBatch(self.handle,
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
shape2ctypes(data),
Expand Down Expand Up @@ -278,7 +282,7 @@ def get_weight(self, layer_name, tag):
return None
return ctypes2numpyT(ret, [oshape[i] for i in range(odim.value)], 'float32')

def train(cfg, data, num_round, param, eval_data = None):
def train(cfg, data, label, num_round, param, eval_data = None):
net = Net(cfg = cfg)
if isinstance(param, dict):
param = param.items()
Expand Down

0 comments on commit 6802688

Please sign in to comment.