-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor Expr into its own module, ad module for forward-mode on labe…
…lled expressions
- Loading branch information
Showing
3 changed files
with
322 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
structure Expr = | ||
struct | ||
|
||
(* Labled expression with labels of type 'e on sub expressions, and | ||
labels of type 'v on variables. Constants are unlabled. | ||
*) | ||
datatype ('e, 'v) Labled = | ||
X of 'v * int | ||
| Con of real | ||
| Neg of ('e, 'v) Expr | ||
| Plus of ('e, 'v) Expr * ('e, 'v) Expr | ||
| Mult of ('e, 'v) Expr * ('e, 'v) Expr | ||
| Exp of ('e, 'v) Expr (* e^x *) | ||
| Sin of ('e, 'v) Expr (* sin x *) | ||
| Cos of ('e, 'v) Expr (* cos x *) | ||
withtype ('e, 'v) Expr = 'e * ('e, 'v) Labled | ||
|
||
|
||
|
||
fun lookup xs i = Vector.sub(xs, i) | ||
|
||
fun eval (_, exp) xs = | ||
case exp of | ||
X(_, i) => lookup xs i | ||
| Con c => c | ||
| Neg e => ~(eval e xs) | ||
| Plus(e1, e2) => eval e1 xs + eval e2 xs | ||
| Mult(e1, e2) => eval e1 xs * eval e2 xs | ||
| Exp e => Math.exp (eval e xs) | ||
| Sin e => Math.sin (eval e xs) | ||
| Cos e => Math.cos (eval e xs) | ||
|
||
|
||
|
||
(* Utility functions *) | ||
|
||
fun unit_elab e = ((), e) | ||
val & = unit_elab | ||
|
||
val const = & o Con | ||
|
||
fun ref_elab e = (ref NONE, e) | ||
|
||
fun uvar i = X((), i) | ||
fun rvar i = X(ref 0.0, i) | ||
|
||
|
||
|
||
fun pp (_, exp) = | ||
case exp of | ||
X (_, i) => "X"^Int.toString i | ||
| Con r => Real.toString r | ||
| Neg (l, Con r) => pp(l, Con(~r)) | ||
| Neg (e as (_, X _)) => "-"^pp e | ||
| Neg e => "-("^pp e^")" | ||
| Plus(e1, (_, Neg e2)) => concat["(", pp e1, " - ", pp e2, ")"] | ||
| Plus(e1, e2) => concat["(", pp e1, " + ", pp e2, ")"] | ||
| Mult(e1, e2) => concat["(", pp e1, " * ", pp e2, ")"] | ||
| Exp e => concat["e^", pp e] | ||
| Sin e => concat["sin(", pp e,")"] | ||
| Cos e => concat["cos(", pp e,")"] | ||
|
||
fun ppv evec = concat(Vector.foldr op:: [] (Vector.mapi (fn(i, e) => concat["df/dX", Int.toString i, " = ", pp e,"\n"]) evec)) | ||
|
||
|
||
|
||
fun makeExpr elab vlab n = | ||
let val one = elab(Con 1.0) | ||
fun minus e1 e2 = elab(Plus(e1, elab(Neg e2))) | ||
fun step e = elab(Exp (minus e one)) | ||
fun loop 0 acc = acc | ||
| loop i acc = loop (i-1) (step acc) | ||
in loop n (elab(vlab 0)) end | ||
|
||
val small = makeExpr unit_elab uvar 2 | ||
|
||
val wpFunc = | ||
let | ||
fun x i = &(uvar i) | ||
in &(Plus(&(Sin (x 0)), &(Mult(x 0, x 1)))) (* sin(x0) + (x0 * x1) *) | ||
end | ||
|
||
fun bigTime ad esize n_vars x0 = | ||
let val xs = Vector.tabulate(n_vars, fn i => if i = 0 then x0 else 0.0) | ||
in lookup (Mosml.time (ad xs) (makeExpr ref_elab uvar esize)) 0 | ||
end | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
fun makeFib elab vlab n = | ||
let infix ++ | ||
fun x ++ y = elab(Plus(x, y)) | ||
val fib0 = elab(vlab 0) | ||
val fib1 = elab(vlab 1) | ||
fun step (f0, f1) = (f1, f0 ++ f1) | ||
fun loop 0 acc = acc | ||
| loop i acc = loop (i-1) (step acc) | ||
in #2(loop (n-1) (fib0, fib1)) | ||
end | ||
|
||
fun fib n = makeFib unit_elab uvar n | ||
|
||
fun fibTime ad n n_vars = | ||
let val xs = Vector.tabulate(n_vars, fn i => if i < 0 then 1.0 else 0.0) | ||
in lookup (Mosml.time (ad xs) (makeFib ref_elab uvar n)) 0 | ||
end | ||
|
||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
(* | ||
datatype Expr = X of int | ||
| Con of real | ||
| Neg of Expr | ||
| Plus of Expr * Expr | ||
| Mult of Expr * Expr | ||
| Exp of Expr (* e^x *) | ||
| Sin of Expr (* sin x *) | ||
| Cos of Expr (* cos x *) | ||
*) | ||
|
||
datatype Labled = datatype Expr.Labled | ||
val & = Expr.& | ||
|
||
fun zipWith f xs ys = Vector.mapi (fn (i, x) => f(x, Vector.sub(ys, i))) xs | ||
|
||
|
||
fun zeroS n = Vector.tabulate(n, fn _ => Expr.const 0.0) | ||
fun directionS n i = Vector.tabulate(n, fn j => if i = j then Expr.const 1.0 else Expr.const 0.0) | ||
fun scalarS x v = Vector.map (fn e => &(Mult(x, e))) v | ||
|
||
|
||
fun diff (labled as (_, expr)) n = | ||
case expr of | ||
X(_, i) => directionS n i | ||
| Con _ => zeroS n | ||
| Neg e => Vector.map (& o Neg) (diff e n) | ||
| Plus(e1, e2) => zipWith (& o Plus) (diff e1 n) (diff e2 n) | ||
| Mult(e1, e2) => zipWith (& o Plus) | ||
(scalarS e1 (diff e2 n)) | ||
(scalarS e2 (diff e1 n)) | ||
| Exp e => scalarS labled (diff e n) | ||
| Sin e => scalarS (&(Cos e)) (diff e n) | ||
| Cos e => scalarS (&(Neg(&(Sin e)))) (diff e n) | ||
|
||
|
||
fun dumbAD xs exp = Vector.map (fn e => Expr.eval e xs) (diff exp (Vector.length xs)) | ||
|
||
fun zero n = Vector.tabulate(n, fn _ => 0.0) | ||
|
||
fun direction n i = Vector.tabulate(n, fn j => if i = j then 1.0 else 0.0) | ||
|
||
fun scalar x v = Vector.map (fn e => x * e) v | ||
|
||
|
||
type dualnum = real * real vector (* the result and the derivative *) | ||
|
||
|
||
fun forward xs expr = | ||
let val n = Vector.length xs | ||
(* diffEval : Expr -> dualnum *) | ||
fun diffEval (_, expr) = | ||
case expr of | ||
X(_, i) => (Expr.lookup xs i, direction n i) | ||
| Con c => (c, zero n) | ||
| Neg e => let val (ex, ed) = diffEval e | ||
in (~ex, Vector.map ~ ed) end | ||
| Plus(e1, e2) => let val (ex1, ed1) = diffEval e1 | ||
val (ex2, ed2) = diffEval e2 | ||
in (ex1 + ex2, zipWith op+ ed1 ed2) end | ||
| Mult (e, e') => let val (ex, ed) = diffEval e | ||
val (ex', ed') = diffEval e' | ||
in (ex * ex', zipWith op+ (scalar ex ed') | ||
(scalar ex' ed)) end | ||
| Exp e => let val (ex, ed) = diffEval e | ||
val exp_ex = Math.exp ex | ||
in (exp_ex, scalar exp_ex ed) end | ||
| Sin e => let val (ex, ed) = diffEval e | ||
in (Math.sin ex, scalar (Math.cos ex) ed) end | ||
| Cos e => let val (ex, ed) = diffEval e | ||
in (Math.cos ex, scalar (~(Math.sin ex)) ed) end | ||
in #2(diffEval expr) end | ||
|
Oops, something went wrong.