Skip to content

Commit

Permalink
Refactoring FastXML -> Trainer, Inferencer
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Stanton committed Jul 16, 2017
1 parent 6d34303 commit 3b415e2
Show file tree
Hide file tree
Showing 4 changed files with 576 additions and 3 deletions.
2 changes: 1 addition & 1 deletion fastxml/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .fastxml import FastXML
from .fastxml import Trainer, Inferencer
48 changes: 46 additions & 2 deletions fastxml/fastxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,51 @@ def dense_rows_iter(dense):

yield ' '.join(dense_lines)

class FastXML(object):
class Inferencer(object):
"""
Loads up a model for inferencing
"""
def __init__(self, dname, gamma=30, blend=0.8, leaf_probs=False):
with file(os.path.join(dname, 'settings')) as f:
self.__dict__.update(json.load(f))

self.gamma = gamma
self.blend = blend
self.leaf_probs = leaf_probs

forest = IForest(dname, self.n_trees, self.n_labels)
if self.leaf_classifiers:
lc = LeafComputer(dname)
predictor = Blender(forest, lc)
else:
predictor = IForestBlender(forest)

self.predictor = predictor

def predict(self, X, fmt='sparse'):
assert fmt in ('sparse', 'dict')
s = []
num = X.shape[0] if isinstance(X, sp.csr_matrix) else len(X)
for i in xrange(num):
Xi = X[i]
mean = self.predictor.predict(Xi.data, Xi.indices, self.blend, self.gamma, self.leaf_probs)

if fmt == 'sparse':
s.append(mean)

else:
od = OrderedDict()
for idx in reversed(mean.data.argsort()):
od[mean.indices[idx]] = mean.data[idx]

s.append(od)

if fmt == 'sparse':
return sp.vstack(s)

return s

class Trainer(object):

def __init__(self, n_trees=1, max_leaf_size=10, max_labels_per_leaf=20,
re_split=0, n_jobs=1, alpha=1e-4, n_epochs=2, n_updates=100, bias=True,
Expand Down Expand Up @@ -298,7 +342,7 @@ def _save_settings(self, dname):
import json
settings = {}
for k, v in self.__dict__.iteritems():
if k == 'roots' or k.endswith('_'):
if k == 'roots' or k == 'predictor' or k.endswith('_'):
continue

settings[k] = v
Expand Down
Empty file added fastxml/inferencer.py
Empty file.
Loading

0 comments on commit 3b415e2

Please sign in to comment.