
Commit

Change impl
jharrilim committed Apr 21, 2019
1 parent ec3980a commit 0de353a
Showing 7 changed files with 154 additions and 45 deletions.
23 changes: 23 additions & 0 deletions lib/core/Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion lib/core/Cargo.toml
@@ -10,4 +10,5 @@ crate-type = ["cdylib"]
 [dependencies]
 wasm-bindgen = "0.2"
 rand = "0.6"
-nalgebra = "0.18"
+nalgebra = "0.18"
+petgraph = "0.4.13"
88 changes: 65 additions & 23 deletions lib/core/src/activation.rs
@@ -1,45 +1,87 @@
-use nalgebra::max;
-
-pub fn sigmoid(a: f32) -> f32 {
-    1f32 / (1f32 + a.powf(-1f32))
+#[derive(Copy, Clone)]
+pub enum Activation {
+    Sigmoid,
+    Tanh,
+    ReLU,
+    Linear
 }
 
-pub fn sigmoid_prime(a: f32) -> f32 {
-    sigmoid(a) * (1f32 - sigmoid(a))
+pub fn activation_from(activation: Activation) -> impl Fn(f32) -> f32 {
+    match activation {
+        Activation::Sigmoid => Sigmoid::call,
+        Activation::Tanh => Tanh::call,
+        Activation::ReLU => Relu::call,
+        Activation::Linear => Linear::call
+    }
 }
 
-pub fn tanh(a: f32) -> f32 {
-    a.tanh()
+pub fn activation_derivative(activation: Activation) -> impl Fn(f32) -> f32 {
+    match activation {
+        Activation::Sigmoid => Sigmoid::first_prime,
+        Activation::Tanh => Tanh::first_prime,
+        Activation::ReLU => Relu::first_prime,
+        Activation::Linear => Linear::first_prime
+    }
 }
 
-pub fn tanh_prime(a: f32) -> f32 {
-    let tanh_a = a.tanh();
-    1 - tanh_a * tanh_a
+pub trait ActivationFunction {
+    fn call(a: f32) -> f32;
+    fn first_prime(a: f32) -> f32;
 }
 
-pub fn relu(a: f32) -> f32 {
-    max(0f32, a)
+pub struct Sigmoid;
+impl ActivationFunction for Sigmoid {
+    fn call(a: f32) -> f32 {
+        1f32 / (1f32 + a.powf(-1f32))
+    }
+    fn first_prime(a: f32) -> f32 {
+        Sigmoid::call(a) * (1f32 - Sigmoid::call(a))
+    }
 }
 
-pub fn relu_prime (a: f32) -> f32 {
-    if a <= 0f32 { 0f32 } else { 1f32 }
+pub struct Tanh;
+impl ActivationFunction for Tanh {
+    fn call(a: f32) -> f32 {
+        a.tanh()
+    }
+
+    fn first_prime(a: f32) -> f32 {
+        let tanh_a = a.tanh();
+        1f32 - tanh_a * tanh_a
+    }
 }
 
-pub fn linear(a: f32) -> f32 {
-    a
+pub struct Relu;
+impl ActivationFunction for Relu {
+    fn call(a: f32) -> f32 {
+        if a > 0f32 { a } else { 0f32 }
+    }
+
+    fn first_prime(a: f32) -> f32 {
+        if a <= 0f32 { 0f32 } else { 1f32 }
+    }
 }
 
-pub fn linear_prime(a: f32) -> f32 {
-    1f32
+pub struct Linear;
+impl ActivationFunction for Linear {
+    fn call(a: f32) -> f32 {
+        a
+    }
+
+    fn first_prime(a: f32) -> f32 {
+        1f32
+    }
 }
 
+
 #[cfg(test)]
 mod test {
     use super::*;
 
     pub fn test() {
         sigmoid(5f32);
     }
 }
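
With the enum in place, a caller can look up an activation and its derivative at runtime instead of hard-coding one of the old free functions. A rough usage sketch against the API added above (the `demo` function and the sample input are illustrative, not part of the commit):

    use crate::activation::{activation_from, activation_derivative, Activation};

    fn demo() {
        // Pick an activation by value; both lookups return plain f32 -> f32 callables.
        let f = activation_from(Activation::Tanh);
        let df = activation_derivative(Activation::Tanh);

        let x = 0.5f32;
        let y = f(x);   // activation value at x
        let dy = df(x); // first derivative at x
        let _ = (y, dy);
    }
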
1 change: 1 addition & 0 deletions lib/core/src/lib.rs
@@ -1,6 +1,7 @@
 extern crate wasm_bindgen;
 extern crate rand;
 extern crate nalgebra;
+extern crate petgraph;
 
 use wasm_bindgen::prelude::*;
27 changes: 17 additions & 10 deletions lib/core/src/network.rs
@@ -1,21 +1,18 @@
-use nalgebra::Vector2;
-
 use crate::neuron::Neuron;
-use nalgebra::base::Matrix;
-use nalgebra::base::dimension::Dynamic;
+use crate::activation::Activation;
 
 pub struct Network {
     neurons: Vec<Neuron>,
-    center: Vector2<f32>,
-    hidden_layer_matrix: Vec<Vec<Neuron>>
+    hidden_layer_matrix: Vec<Vec<Neuron>>,
+    epoch: u64
 }
 
 impl Network {
     pub fn new() -> Network {
         Network {
             neurons: Vec::<Neuron>::new(),
-            center: Vector2::new(0f32, 0f32),
-            hidden_layer_matrix: Vec::new()
+            hidden_layer_matrix: Vec::new(),
+            epoch: 0u64
         }
     }

@@ -27,12 +24,22 @@ impl Network {
         self.neurons.extend(neurons)
}

-    pub fn add_hidden_layer(&mut self, size: u32) {
+    pub fn add_hidden_layer(&mut self, size: u32, activation: Option<Activation>) {
         if size == 0 {
             self.hidden_layer_matrix.push(Vec::new());
         } else {
-            let layer: Vec<Neuron> = (0..size).map(|_| Neuron::new()).collect::<Vec<Neuron>>();
+            let layer: Vec<Neuron> = (0..size)
+                .map(|_| Neuron::new(
+                    activation
+                        .unwrap_or(Activation::ReLU)
+                )).collect::<Vec<Neuron>>();
             self.hidden_layer_matrix.push(layer);
         }
     }
+
+    pub fn remove_hidden_layer_at(&mut self, index: usize) {
+        self.hidden_layer_matrix.remove(index);
+    }
+
 }
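
Taken together with the activation changes, layer construction now threads the activation choice through an `Option`, falling back to ReLU. A minimal sketch of how the new methods might be called (the `build` function name is illustrative):

    use crate::network::Network;
    use crate::activation::Activation;

    fn build() {
        let mut net = Network::new();
        net.add_hidden_layer(4, Some(Activation::Tanh)); // explicit activation for this layer
        net.add_hidden_layer(3, None);                   // None falls back to Activation::ReLU
        net.remove_hidden_layer_at(0);                   // drop a layer by index
    }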

42 changes: 38 additions & 4 deletions lib/core/src/neuron.rs
@@ -1,13 +1,47 @@
-use nalgebra::base::Vector2;
+use crate::synapse::Synapse;
+use crate::activation::{Activation};
 
 pub struct Neuron {
-    pub location: Vector2<f32>
+    inputs: Vec<Synapse>,
+    outputs: Vec<Synapse>,
+    bias: f32,
+    input_total: u32,
+    output_total: u32,
+
+    /// Error derivative with respect to the node's output.
+    output_derivative: f32,
+
+    /// Error derivative with respect to this node's total input.
+    input_derivative: f32,
+
+    /// Accumulated error derivative with respect to this node's total input since
+    /// the last update. This derivative equals dE/db where b is the node's
+    /// bias term.
+    accumulated_input_derivative: f32,
+
+    /// Number of accumulated err. derivatives with respect to the total input
+    /// since the last update.
+    accumulated_derivatives: u32,
+
+    /// A function that accepts an input, in this case the sum of the total input, and returns
+    /// an output.
+    activation: Activation
 }
 
 impl Neuron {
-    pub fn new() -> Neuron {
+    pub fn new(activation: Activation) -> Neuron {
         Neuron {
-            location: Vector2::new(0f32, 0f32)
+            inputs: Vec::new(),
+            outputs: Vec::new(),
+            bias: 0.1f32,
+            input_total: 0,
+            output_total: 0,
+            output_derivative: 0f32,
+            input_derivative: 0f32,
+            accumulated_input_derivative: 0f32,
+            accumulated_derivatives: 0,
+            activation
         }
     }
 }
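
The derivative bookkeeping above only pays off at update time. The field docs equate the accumulated derivative with dE/db because the total input is z = w·x + b and dz/db = 1, so dE/dz and dE/db are numerically the same. As a hypothetical illustration of what those fields describe (this helper is not part of the commit), a bias update might average the accumulated value and step against it:

    /// Hypothetical sketch, not from this commit: average the accumulated
    /// derivative described in the field docs and apply it to the bias.
    fn apply_bias_update(bias: &mut f32,
                         accumulated_input_derivative: f32,
                         accumulated_derivatives: u32,
                         learning_rate: f32) {
        if accumulated_derivatives > 0 {
            let avg = accumulated_input_derivative / accumulated_derivatives as f32;
            *bias -= learning_rate * avg; // gradient-descent step on dE/db
        }
    }
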
15 changes: 8 additions & 7 deletions lib/core/src/synapse.rs
@@ -3,14 +3,14 @@ use rand::{ thread_rng, Rng };
 
 use crate::neuron::Neuron;
 
-pub struct Synapse<'n> {
-    a: &'n Neuron,
-    b: &'n Neuron,
+pub struct Synapse {
+    a: Neuron,
+    b: Neuron,
     weight: f32
 }
 
-impl <'n> Synapse<'n> {
-    pub fn new(a: &'n Neuron, b: &'n Neuron) -> Synapse<'n> {
+impl Synapse {
+    pub fn new(a: Neuron, b: Neuron) -> Synapse {
         Synapse { a, b, weight: thread_rng().gen::<f32>() }
     }
 }
@@ -21,10 +21,11 @@ mod tests {
     use super::*;
     use crate::point::Point;
    use nalgebra::geometry::Vector2;
+    use crate::activation::Activation;
 
     pub fn create_synapse() {
-        let n1 = Neuron { location: Vector2::new(3f32, 2f32) };
-        let n2 = Neuron { location: Vector2::new(40f32, 70f32) };
+        let n1 = Neuron::new(Activation::ReLU);
+        let n2 = Neuron::new(Activation::ReLU);
 
         let syn = Synapse::new(&n1, &n2);
     }
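
Since `Synapse::new` now takes its endpoints by value, a caller moves the neurons in rather than borrowing them. A minimal sketch under that assumption (the `connect` function name is illustrative):

    use crate::neuron::Neuron;
    use crate::synapse::Synapse;
    use crate::activation::Activation;

    fn connect() {
        let a = Neuron::new(Activation::ReLU);
        let b = Neuron::new(Activation::ReLU);
        // The synapse owns both neurons and draws a random initial weight.
        let syn = Synapse::new(a, b);
        let _ = syn;
    }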
