Skip to content

Commit

Permalink
Refactor Expr into its own module, ad module for forward-mode on labe…
Browse files Browse the repository at this point in the history
…lled expressions
  • Loading branch information
kfl committed Mar 23, 2017
1 parent 1089dc4 commit f2d2fbc
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 51 deletions.
115 changes: 115 additions & 0 deletions Expr.sml
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
structure Expr =
struct

(* Labled expression with labels of type 'e on sub expressions, and
labels of type 'v on variables. Constants are unlabled.
*)
datatype ('e, 'v) Labled =
X of 'v * int
| Con of real
| Neg of ('e, 'v) Expr
| Plus of ('e, 'v) Expr * ('e, 'v) Expr
| Mult of ('e, 'v) Expr * ('e, 'v) Expr
| Exp of ('e, 'v) Expr (* e^x *)
| Sin of ('e, 'v) Expr (* sin x *)
| Cos of ('e, 'v) Expr (* cos x *)
withtype ('e, 'v) Expr = 'e * ('e, 'v) Labled



fun lookup xs i = Vector.sub(xs, i)

fun eval (_, exp) xs =
case exp of
X(_, i) => lookup xs i
| Con c => c
| Neg e => ~(eval e xs)
| Plus(e1, e2) => eval e1 xs + eval e2 xs
| Mult(e1, e2) => eval e1 xs * eval e2 xs
| Exp e => Math.exp (eval e xs)
| Sin e => Math.sin (eval e xs)
| Cos e => Math.cos (eval e xs)



(* Utility functions *)

fun unit_elab e = ((), e)
val & = unit_elab

val const = & o Con

fun ref_elab e = (ref NONE, e)

fun uvar i = X((), i)
fun rvar i = X(ref 0.0, i)



fun pp (_, exp) =
case exp of
X (_, i) => "X"^Int.toString i
| Con r => Real.toString r
| Neg (l, Con r) => pp(l, Con(~r))
| Neg (e as (_, X _)) => "-"^pp e
| Neg e => "-("^pp e^")"
| Plus(e1, (_, Neg e2)) => concat["(", pp e1, " - ", pp e2, ")"]
| Plus(e1, e2) => concat["(", pp e1, " + ", pp e2, ")"]
| Mult(e1, e2) => concat["(", pp e1, " * ", pp e2, ")"]
| Exp e => concat["e^", pp e]
| Sin e => concat["sin(", pp e,")"]
| Cos e => concat["cos(", pp e,")"]

fun ppv evec = concat(Vector.foldr op:: [] (Vector.mapi (fn(i, e) => concat["df/dX", Int.toString i, " = ", pp e,"\n"]) evec))



fun makeExpr elab vlab n =
let val one = elab(Con 1.0)
fun minus e1 e2 = elab(Plus(e1, elab(Neg e2)))
fun step e = elab(Exp (minus e one))
fun loop 0 acc = acc
| loop i acc = loop (i-1) (step acc)
in loop n (elab(vlab 0)) end

val small = makeExpr unit_elab uvar 2

val wpFunc =
let
fun x i = &(uvar i)
in &(Plus(&(Sin (x 0)), &(Mult(x 0, x 1)))) (* sin(x0) + (x0 * x1) *)
end

fun bigTime ad esize n_vars x0 =
let val xs = Vector.tabulate(n_vars, fn i => if i = 0 then x0 else 0.0)
in lookup (Mosml.time (ad xs) (makeExpr ref_elab uvar esize)) 0
end









fun makeFib elab vlab n =
let infix ++
fun x ++ y = elab(Plus(x, y))
val fib0 = elab(vlab 0)
val fib1 = elab(vlab 1)
fun step (f0, f1) = (f1, f0 ++ f1)
fun loop 0 acc = acc
| loop i acc = loop (i-1) (step acc)
in #2(loop (n-1) (fib0, fib1))
end

fun fib n = makeFib unit_elab uvar n

fun fibTime ad n n_vars =
let val xs = Vector.tabulate(n_vars, fn i => if i < 0 then 1.0 else 0.0)
in lookup (Mosml.time (ad xs) (makeFib ref_elab uvar n)) 0
end


end
73 changes: 73 additions & 0 deletions Forward.sml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
(*
datatype Expr = X of int
| Con of real
| Neg of Expr
| Plus of Expr * Expr
| Mult of Expr * Expr
| Exp of Expr (* e^x *)
| Sin of Expr (* sin x *)
| Cos of Expr (* cos x *)
*)

datatype Labled = datatype Expr.Labled
val & = Expr.&

fun zipWith f xs ys = Vector.mapi (fn (i, x) => f(x, Vector.sub(ys, i))) xs


fun zeroS n = Vector.tabulate(n, fn _ => Expr.const 0.0)
fun directionS n i = Vector.tabulate(n, fn j => if i = j then Expr.const 1.0 else Expr.const 0.0)
fun scalarS x v = Vector.map (fn e => &(Mult(x, e))) v


fun diff (labled as (_, expr)) n =
case expr of
X(_, i) => directionS n i
| Con _ => zeroS n
| Neg e => Vector.map (& o Neg) (diff e n)
| Plus(e1, e2) => zipWith (& o Plus) (diff e1 n) (diff e2 n)
| Mult(e1, e2) => zipWith (& o Plus)
(scalarS e1 (diff e2 n))
(scalarS e2 (diff e1 n))
| Exp e => scalarS labled (diff e n)
| Sin e => scalarS (&(Cos e)) (diff e n)
| Cos e => scalarS (&(Neg(&(Sin e)))) (diff e n)


fun dumbAD xs exp = Vector.map (fn e => Expr.eval e xs) (diff exp (Vector.length xs))

fun zero n = Vector.tabulate(n, fn _ => 0.0)

fun direction n i = Vector.tabulate(n, fn j => if i = j then 1.0 else 0.0)

fun scalar x v = Vector.map (fn e => x * e) v


type dualnum = real * real vector (* the result and the derivative *)


fun forward xs expr =
let val n = Vector.length xs
(* diffEval : Expr -> dualnum *)
fun diffEval (_, expr) =
case expr of
X(_, i) => (Expr.lookup xs i, direction n i)
| Con c => (c, zero n)
| Neg e => let val (ex, ed) = diffEval e
in (~ex, Vector.map ~ ed) end
| Plus(e1, e2) => let val (ex1, ed1) = diffEval e1
val (ex2, ed2) = diffEval e2
in (ex1 + ex2, zipWith op+ ed1 ed2) end
| Mult (e, e') => let val (ex, ed) = diffEval e
val (ex', ed') = diffEval e'
in (ex * ex', zipWith op+ (scalar ex ed')
(scalar ex' ed)) end
| Exp e => let val (ex, ed) = diffEval e
val exp_ex = Math.exp ex
in (exp_ex, scalar exp_ex ed) end
| Sin e => let val (ex, ed) = diffEval e
in (Math.sin ex, scalar (Math.cos ex) ed) end
| Cos e => let val (ex, ed) = diffEval e
in (Math.cos ex, scalar (~(Math.sin ex)) ed) end
in #2(diffEval expr) end

Loading

0 comments on commit f2d2fbc

Please sign in to comment.