Skip to content

Commit

Permalink
release/1.0.0: Add Go templates for the LinearSVC estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
nok committed Jan 31, 2020
1 parent 9420ea6 commit 8f30d72
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 12 deletions.
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ This table gives an overview over all supported combinations of estimators, prog
<td></td>
<td>×</td>
<td>✓</td>
<td></td>
<td></td>
<td>×</td>
<td>✓</td>
<td>✓</td>
Expand Down
1 change: 1 addition & 0 deletions sklearn_porter/estimator/LinearSVC/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LinearSVC(EstimatorBase, EstimatorApiABC):
},
enum.Language.GO: {
enum.Template.ATTACHED: {enum.Method.PREDICT},
enum.Template.EXPORTED: {enum.Method.PREDICT},
},
enum.Language.JAVA: {
enum.Template.ATTACHED: {enum.Method.PREDICT},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
)

type {{ class_name }} struct {
coeffs []float64
inters float64
Coeffs []float64
Inters float64
}
{% else %}
import (
Expand All @@ -29,8 +29,8 @@ import (
)

type {{ class_name }} struct {
coeffs [][]float64
inters []float64
Coeffs [][]float64
Inters []float64
}
{% endif %}

Expand All @@ -43,10 +43,10 @@ type Response struct {
{% if is_binary %}
func (svc {{ class_name }}) Predict(features []float64) int {
var prob float64
for i := 0; i < len(svc.coeffs); i++ {
prob = prob + svc.coeffs[i] * features[i]
for i := 0; i < len(svc.Coeffs); i++ {
prob = prob + svc.Coeffs[i] * features[i]
}
if (prob + svc.inters) > 0 {
if (prob + svc.Inters) > 0 {
return 1
}
return 0
Expand All @@ -55,14 +55,14 @@ func (svc {{ class_name }}) Predict(features []float64) int {
func (svc {{ class_name }}) Predict(features []float64) int {
classIdx := 0
classVal := math.Inf(-1)
outerCount, innerCount := len(svc.coeffs), len(svc.coeffs[0])
outerCount, innerCount := len(svc.Coeffs), len(svc.Coeffs[0])
for i := 0; i < outerCount; i++ {
var prob float64
for j := 0; j < innerCount; j++ {
prob = prob + svc.coeffs[i][j] * features[j]
prob = prob + svc.Coeffs[i][j] * features[j]
}
if prob + svc.inters[i] > classVal {
classVal = prob + svc.inters[i]
if prob + svc.Inters[i] > classVal {
classVal = prob + svc.Inters[i]
classIdx = i
}
}
Expand Down
109 changes: 109 additions & 0 deletions sklearn_porter/estimator/LinearSVC/templates/go/exported.class.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
{% extends 'base.exported.class' %}

{% block content %}
package main

{% if is_binary %}
import (
{% if is_test or to_json %}
"encoding/json"
{% endif %}
"fmt"
"io/ioutil"
"os"
"strconv"
)

type {{ class_name }} struct {
Coeffs []float64 `json:"coeffs"`
Inters float64 `json:"inters"`
}
{% else %}
import (
{% if is_test or to_json %}
"encoding/json"
{% endif %}
"fmt"
"io/ioutil"
"math"
"os"
"strconv"
)

type {{ class_name }} struct {
Coeffs [][]float64 `json:"coeffs"`
Inters []float64 `json:"inters"`
}
{% endif %}

{% if is_test or to_json %}
type Response struct {
Predict int `json:"predict"`
}
{% endif %}

{% if is_binary %}
func (svc {{ class_name }}) Predict(features []float64) int {
var prob float64
for i := 0; i < len(svc.Coeffs); i++ {
prob = prob + svc.Coeffs[i] * features[i]
}
if (prob + svc.Inters) > 0 {
return 1
}
return 0
}
{% else %}
func (svc {{ class_name }}) Predict(features []float64) int {
classIdx := 0
classVal := math.Inf(-1)
outerCount, innerCount := len(svc.Coeffs), len(svc.Coeffs[0])
for i := 0; i < outerCount; i++ {
var prob float64
for j := 0; j < innerCount; j++ {
prob = prob + svc.Coeffs[i][j] * features[j]
}
if prob + svc.Inters[i] > classVal {
classVal = prob + svc.Inters[i]
classIdx = i
}
}
return classIdx
}
{% endif %}

func main() {

// Features:
var features []float64
for _, arg := range os.Args[2:] {
if n, err := strconv.ParseFloat(arg, 64); err == nil {
features = append(features, n)
}
}

// Model data:
jsonFile, err := os.Open(os.Args[1])
if err != nil {
fmt.Println(err)
}
defer jsonFile.Close()
byteValue, _ := ioutil.ReadAll(jsonFile)

// Estimator:
var clf {{ class_name }}
json.Unmarshal(byteValue, &clf)

{% if is_test or to_json %}
// Get JSON:
prediction := clf.Predict(features)
res, _ := json.Marshal(&Response{Predict: prediction})
fmt.Println(string(res))
{% else %}
// Get class prediction:
prediction := clf.Predict(features)
fmt.Printf("Predicted class: #%d\n", prediction)
{% endif %}

}
{% endblock %}

0 comments on commit 8f30d72

Please sign in to comment.