
Commit

Merge pull request dmlc#165 from viktor-ferenczi/viktor-ferenczi
Python wrapper fixes
antinucleon committed May 28, 2015
2 parents df6c7d0 + aac1dfd · commit 7b195c4
Showing 1 changed file with 14 additions and 7 deletions.
wrapper/cxxnet.py: 21 changes (14 additions & 7 deletions)
@@ -12,9 +12,9 @@
 # set this line correctly
 if os.name == 'nt':
     # TODO windows
-    CXXNET_PATH = os.path.dirname(__file__) + '/libcxxnetwrapper.dll'
+    CXXNET_PATH = os.path.join(os.path.dirname(__file__), 'libcxxnetwrapper.dll')
 else:
-    CXXNET_PATH = os.path.dirname(__file__) + '/libcxxnetwrapper.so'
+    CXXNET_PATH = os.path.join(os.path.dirname(__file__), 'libcxxnetwrapper.so')
 
 # load in xgboost library
 cxnlib = ctypes.cdll.LoadLibrary(CXXNET_PATH)
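
Note: switching from string concatenation to os.path.join keeps the directory separator correct on every platform. A minimal sketch of the same idea, assuming the shared library sits next to the wrapper module (the helper name _wrapper_path is hypothetical, not part of this commit):

    import os

    def _wrapper_path():
        # Pick the shared-library name per platform; os.path.join inserts the right separator.
        name = 'libcxxnetwrapper.dll' if os.name == 'nt' else 'libcxxnetwrapper.so'
        return os.path.join(os.path.dirname(__file__), name)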
@@ -35,7 +35,8 @@ def ctypes2numpy(cptr, length, dtype=numpy.float32):
     """convert a ctypes pointer array to numpy array """
     #assert isinstance(cptr, ctypes.POINTER(ctypes.c_float))
     res = numpy.zeros(length, dtype=dtype)
-    assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0])
+    if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
+        raise AssertionError('ctypes.memmove failed')
     return res
 
 def ctypes2numpyT(cptr, shape, dtype=numpy.float32, stride = None):
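
Note: the explicit check replaces a bare assert, which python -O would strip. ctypes.memmove returns the destination address, so a falsy return signals a failed copy. A small stand-alone illustration of the same pattern (not part of the wrapper):

    import ctypes
    import numpy

    src = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
    res = numpy.zeros(4, dtype=numpy.float32)
    # memmove returns the destination address; check it instead of asserting on it.
    if not ctypes.memmove(res.ctypes.data, src, 4 * res.strides[0]):
        raise AssertionError('ctypes.memmove failed')
    print(res)  # [1. 2. 3. 4.]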
@@ -45,11 +46,13 @@ def ctypes2numpyT(cptr, shape, dtype=numpy.float32, stride = None):
         size *= x
     if stride is None:
         res = numpy.zeros(size, dtype=dtype)
-        assert ctypes.memmove(res.ctypes.data, cptr, size * res.strides[0])
+        if not ctypes.memmove(res.ctypes.data, cptr, size * res.strides[0]):
+            raise AssertionError('ctypes.memmove failed')
     else:
         dsize = size / shape[-1] * stride
         res = numpy.zeros(dsize, dtype=dtype)
-        assert ctypes.memmove(res.ctypes.data, cptr, dsize * res.strides[0])
+        if not ctypes.memmove(res.ctypes.data, cptr, dsize * res.strides[0]):
+            raise AssertionError('ctypes.memmove failed')
         res = res.reshape((dsize / shape[-1], shape[-1]))
         res = res[:, 0 :shape[-1]]
     return res.reshape(shape)
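
Note: the stride branch appears intended to copy a buffer whose rows are padded to stride elements and then trim the padding, and its divisions rely on Python 2 integer division. A rough sketch of that intent written with explicit floor division (the helper name copy_strided is illustrative only, not part of this commit):

    import ctypes
    import numpy

    def copy_strided(cptr, shape, stride, dtype=numpy.float32):
        # Each source row holds stride elements, but only shape[-1] of them are valid.
        rows = int(numpy.prod(shape[:-1]))
        res = numpy.zeros(rows * stride, dtype=dtype)
        if not ctypes.memmove(res.ctypes.data, cptr, res.size * res.strides[0]):
            raise AssertionError('ctypes.memmove failed')
        # Drop the padding columns, then restore the requested shape.
        return res.reshape((rows, stride))[:, :shape[-1]].reshape(shape)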
@@ -166,11 +169,15 @@ def update(self, data, label = None):
         if not isinstance(label, numpy.ndarray):
             raise Exception('Net.update: label need to be ndarray')
         if label.ndim == 1:
-            label = label.reshape((label.size(0),1))
+            label = label.reshape(label.shape[0], 1)
         if label.ndim != 2:
             raise Exception('Net.update: label need to be 2 dimension or one dimension ndarray')
         if label.shape[0] != data.shape[0]:
             raise Exception('Net.update: data size mismatch')
+        if data.dtype != numpy.float32:
+            raise Exception('Net.update: data must be of type numpy.float32')
+        if label.dtype != numpy.float32:
+            raise Exception('Net.update: label must be of type numpy.float32')
         cxnlib.CXNNetUpdateBatch(self.handle,
                                  data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
                                  shape2ctypes(data),
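
Note: with the added dtype checks, callers must hand float32 arrays to Net.update, and 1-D labels are reshaped to a column. A minimal illustration of preparing inputs (the array contents are placeholders):

    import numpy

    data = numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype=numpy.float32)
    label = numpy.asarray([0.0, 1.0], dtype=numpy.float32)  # reshaped to (2, 1) by update
    # net.update(data, label) now passes the isinstance, shape and dtype checks above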
@@ -278,7 +285,7 @@ def get_weight(self, layer_name, tag):
             return None
         return ctypes2numpyT(ret, [oshape[i] for i in range(odim.value)], 'float32')
 
-def train(cfg, data, num_round, param, eval_data = None):
+def train(cfg, data, label, num_round, param, eval_data = None):
     net = Net(cfg = cfg)
     if isinstance(param, dict):
         param = param.items()
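
Note: train now takes the labels as an explicit argument. A hypothetical call reflecting the new signature, assuming train returns the Net it builds and that cfg is a cxxnet network configuration string (all values below are placeholders):

    import numpy
    import cxxnet  # wrapper/cxxnet.py, assuming wrapper/ is on sys.path

    cfg = open('example.conf').read()  # hypothetical cxxnet configuration file
    data = numpy.zeros((100, 10), dtype=numpy.float32)
    label = numpy.zeros((100, 1), dtype=numpy.float32)
    net = cxxnet.train(cfg, data, label, num_round=10, param={})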
