Skip to content

Commit

Permalink
Restructure model data and algorithm handling
Browse files Browse the repository at this point in the history
  • Loading branch information
nok committed Oct 15, 2017
1 parent 4cdcfde commit 59a0e91
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 154 deletions.
79 changes: 49 additions & 30 deletions examples/estimator/classifier/KNeighborsClassifier/java/basics.py

Large diffs are not rendered by default.

88 changes: 46 additions & 42 deletions examples/estimator/classifier/KNeighborsClassifier/js/basics.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,27 @@ def create_method(self):
temp_arr_ = self.temp('arr[]')
temp_arr__ = self.temp('arr[][]')

temp_method = self.temp('method.predict', n_indents=1, skipping=True)
out = temp_method.format(class_name=self.class_name,
method_name=self.method_name,
distance_computation=distance_comp)
return out

def create_class(self, method):
"""
Build the estimator class.
Returns
-------
:return out : string
The built class as string.
"""

temp_type = self.temp('type')
temp_arr = self.temp('arr')
temp_arr_ = self.temp('arr[]')
temp_arr__ = self.temp('arr[][]')

# Templates
temps = []
for atts in enumerate(self.estimator._fit_X): # pylint: disable=W0212
Expand All @@ -144,28 +165,12 @@ def create_method(self):
classes = temp_arr_.format(type='int', name='y', values=classes,
n=self.n_templates)

temp_method = self.temp('method.predict', n_indents=1, skipping=True)
out = temp_method.format(class_name=self.class_name,
method_name=self.method_name,
n_neighbors=self.n_neighbors,
n_templates=self.n_templates,
n_features=self.n_features,
n_classes=self.n_classes,
distance_computation=distance_comp,
power=self.power_param, X=temps, y=classes)
return out

def create_class(self, method):
"""
Build the estimator class.
Returns
-------
:return out : string
The built class as string.
"""
temp_class = self.temp('class')
out = temp_class.format(class_name=self.class_name,
method_name=self.method_name, method=method,
n_features=self.n_features)
n_features=self.n_features, X=temps, y=classes,
n_neighbors=self.n_neighbors,
n_templates=self.n_templates,
n_classes=self.n_classes,
power=self.power_param)
return out
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@ import java.util.*;

class {class_name} {{

private int nNeighbors;
private int nTemplates;
private int nClasses;
private double power;
private double[][] X;
private int[] y;

public {class_name}(int nNeighbors, int nTemplates, int nClasses, double power, double[][] X, int[] y) {{
this.nNeighbors = nNeighbors;
this.nTemplates = nTemplates;
this.nClasses = nClasses;
this.power = power;
this.X = X;
this.y = y;
}}

private static class Neighbor {{
Integer clazz;
Double dist;
Expand All @@ -15,11 +31,23 @@ class {class_name} {{

public static void main(String[] args) {{
if (args.length == {n_features}) {{
double[] atts = new double[args.length];

// Features:
double[] features = new double[args.length];
for (int i = 0, l = args.length; i < l; i++) {{
atts[i] = Double.parseDouble(args[i]);
features[i] = Double.parseDouble(args[i]);
}}
System.out.println({class_name}.{method_name}(atts));

// Parameters:
{X}
{y}

// Prediction:
{class_name} clf = new {class_name}({n_neighbors}, {n_templates}, {n_classes}, {power}, X, y);
int estimation = clf.{method_name}(features);
System.out.println(estimation);

}}
}}

}}
Original file line number Diff line number Diff line change
@@ -1,46 +1,34 @@
{distance_computation}

public static int {method_name}(double[] atts) {{
if (atts.length != {n_features}) {{
return -1;
}}

{X}
{y}

public int {method_name}(double[] features) {{
int classIdx = -1;
int nNeighbors = {n_neighbors};
int nTemplates = {n_templates};
int nClasses = {n_classes};
double power = {power};

if (nNeighbors == 1) {{
if (this.nNeighbors == 1) {{
double minDist = Double.POSITIVE_INFINITY;
double curDist;
for (int i = 0; i < nTemplates; i++) {{
curDist = {class_name}.compDist(X[i], atts, power);
for (int i = 0; i < this.nTemplates; i++) {{
curDist = {class_name}.compute(this.X[i], features, this.power);
if (curDist <= minDist) {{
minDist = curDist;
classIdx = y[i];
}}
}}
}} else {{
int[] classes = new int[nClasses];
int[] classes = new int[this.nClasses];
ArrayList<Neighbor> dists = new ArrayList<Neighbor>();
for (int i = 0; i < nTemplates; i++) {{
dists.add(new Neighbor(y[i], {class_name}.compDist(X[i], atts, power)));
for (int i = 0; i < this.nTemplates; i++) {{
dists.add(new Neighbor(y[i], {class_name}.compute(this.X[i], features, this.power)));
}}
Collections.sort(dists, new Comparator<Neighbor>() {{
@Override
public int compare(Neighbor n1, Neighbor n2) {{
return n1.dist.compareTo(n2.dist);
}}
}});
for (Neighbor neighbor : dists.subList(0, nNeighbors)) {{
for (Neighbor neighbor : dists.subList(0, this.nNeighbors)) {{
classes[neighbor.clazz]++;
}}
int classVal = -1;
for (int i = 0; i < nClasses; i++) {{
for (int i = 0; i < this.nClasses; i++) {{
if (classes[i] > classVal) {{
classVal = classes[i];
classIdx = i;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
public static double compDist(double[] temp, double[] cand, double q) {
private static double compute(double[] temp, double[] cand, double q) {
double dist = 0.;
double diff;
for (int i = 0, l = temp.length; i < l; i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
// Array.prototype.fill polyfill:
[].fill||(Array.prototype.fill=function(a){{for(var b=Object(this),c=parseInt(b.length,10),d=arguments[1],e=parseInt(d,10)||0,f=0>e?Math.max(c+e,0):Math.min(e,c),g=arguments[2],h=void 0===g?c:parseInt(g)||0,i=0>h?Math.max(c+h,0):Math.min(h,c);i>f;f++)b[f]=a;return b}});
var {class_name} = function(nNeighbors, nTemplates, nClasses, power, X, y) {{

var Neighbor = function(clazz, dist) {{
this.clazz = clazz;
this.dist = dist;
}};
this.nNeighbors = nNeighbors;
this.nTemplates = nTemplates;
this.nClasses = nClasses;
this.power = power;
this.X = X;
this.y = y;

var {class_name} = function() {{
var Neighbor = function(clazz, dist) {{
this.clazz = clazz;
this.dist = dist;
}};

{method}

}};

if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {{
if (process.argv.length - 2 == {n_features}) {{
var argv = process.argv.slice(2);
var prediction = new {class_name}().{method_name}(argv);
if (process.argv.length - 2 === {n_features}) {{

// Features:
var features = process.argv.slice(2);

// Parameters:
{X}
{y}

// Estimator:
var clf = new {class_name}({n_neighbors}, {n_templates}, {n_classes}, {power}, X, y);
var prediction = clf.{method_name}(features);
console.log(prediction);

}}
}}
Original file line number Diff line number Diff line change
@@ -1,45 +1,31 @@
{distance_computation}

this.{method_name} = function(atts) {{

if (atts.length != {n_features}) {{
return -1;
}}

{X}
{y}

var classIdx = -1,
nNeighbors = {n_neighbors},
nTemplates = {n_templates},
nClasses = {n_classes},
power = {power},
i;

if (nNeighbors == 1) {{
this.{method_name} = function(features) {{
var classIdx = -1, i;
if (this.nNeighbors == 1) {{
var minDist = Number.POSITIVE_INFINITY,
curDist;
for (i = 0; i < nTemplates; i++) {{
curDist = compDist(X[i], atts, power);
for (i = 0; i < this.nTemplates; i++) {{
curDist = compute(this.X[i], features, this.power);
if (curDist <= minDist) {{
minDist = curDist;
classIdx = y[i];
classIdx = this.y[i];
}}
}}
}} else {{
var classes = new Array(nClasses).fill(0);
var classes = new Array(this.nClasses).fill(0);
var dists = [];
for (i = 0; i < nTemplates; i++) {{
dists.push(new Neighbor(y[i], compDist(X[i], atts, power)));
for (i = 0; i < this.nTemplates; i++) {{
dists.push(new Neighbor(this.y[i], compute(this.X[i], features, this.power)));
}}
dists.sort(function compare(n1, n2) {{
return (n1.dist < n2.dist) ? -1 : 1;
}});
for (i = 0; i < nNeighbors; i++) {{
for (i = 0; i < this.nNeighbors; i++) {{
classes[dists[i].clazz]++;
}}
var classVal = -1;
for (i = 0; i < nClasses; i++) {{
for (i = 0; i < this.nClasses; i++) {{
if (classes[i] > classVal) {{
classVal = classes[i];
classIdx = i;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
var compDist = function(temp, cand, q) {
var compute = function(temp, cand, q) {
var dist = 0.,
diff;
for (var i = 0, l = temp.length; i < l; i++) {
Expand Down

0 comments on commit 59a0e91

Please sign in to comment.