Skip to content

Commit

Permalink
ad matrix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel Ullrich committed Jan 16, 2023
1 parent 6a524b2 commit ee33da0
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 4 deletions.
1 change: 1 addition & 0 deletions dialects/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_thorin_dialect(autodiff
clos
affine
direct
matrix
HEADER_DEPENDS
mem
INSTALL
Expand Down
57 changes: 55 additions & 2 deletions dialects/autodiff/autodiff.thorin
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
.import clos;
/// For derivatives:
.import core;
// .import matrix;
.import matrix;
///
///
/// ## Types
Expand Down Expand Up @@ -303,7 +303,60 @@
// };

// (M × N)' = (S × Nᵀ, Mᵀ × S)
// .lam .extern internal_diff_core_wrap_add!s:.Nat -> .Nat -> (.Cn[[.Idx s, .Idx s], .Cn[.Idx s, .Cn[.Idx s, .Cn[.Idx s, .Idx s]]]]) =
.lam .extern internal_diff_matrix_prod![m: .Nat, k: .Nat, l: .Nat, [p: .Nat, e:.Nat]] ->
(.Cn[
[%mem.M,%matrix.Mat (2,(m, k),%math.F (p,e)), %matrix.Mat (2,(k, l),%math.F (p,e))],
.Cn[
[%mem.M,%matrix.Mat (2,(m, l),%math.F (p,e))], // output
.Cn[ // pullback
[%mem.M,%matrix.Mat (2,(m, l),%math.F (p,e))], // out tangent
.Cn [%mem.M,%matrix.Mat (2,(m, k),%math.F (p,e)), %matrix.Mat (2,(k, l),%math.F (p,e))] // input tangent
]
]
]) =
// .lm m: .Nat -> (.Cn[[.Idx s, .Idx s], .Cn[.Idx s, .Cn[.Idx s, .Cn[.Idx s, .Idx s]]]]) =
.cn ![
[mem:%mem.M,m1:%matrix.Mat (2,(m, k),%math.F (p,e)), m2:%matrix.Mat (2,(k, l),%math.F (p,e))],
ret :
.Cn[
[%mem.M,%matrix.Mat (2,(m, l),%math.F (p,e))], // output
.Cn[ // pullback
[%mem.M,%matrix.Mat (2,(m, l),%math.F (p,e))], // out tangent
.Cn [%mem.M,%matrix.Mat (2,(m, k),%math.F (p,e)), %matrix.Mat (2,(k, l),%math.F (p,e))] // input tangent
]
]] = {
.con pb [
[mem:%mem.M,ms:%matrix.Mat (2,(m, l),%math.F (p,e))],
pb_ret:.Cn [%mem.M,%matrix.Mat (2,(m, k),%math.F (p,e)), %matrix.Mat (2,(k, l),%math.F (p,e))]
] = {
.let (mem1, m1_t) = %matrix.transpose ((m,k), %math.F (p,e)) (mem ,m1);
.let (mem2, m2_t) = %matrix.transpose ((k,l), %math.F (p,e)) (mem1,m2);
.let (mem3, m1_s) = %matrix.prod (m,l,k, (p,e)) (mem2,ms,m2_t);
.let (mem4, m2_s) = %matrix.prod (k,m,l, (p,e)) (mem3,m1_t,ms);
.let (mem5, result) = %matrix.prod (m,k,l, (p,e)) (mem4,m1,m2);
pb_ret (mem5,m1_s,m2_s)
};
.let result = %matrix.prod (m,k,l, (p,e)) (mem,m1,m2);
ret (result, pb)
};



// (a b:.Idx s), ret:.Cn[.Idx s, .Cn[.Idx s, .Cn[.Idx s, .Idx s]]]) = {
// .let result = %core.wrap.add s m (a,b);
// ret (result, .cn ![i:(.Idx s), pb_ret:(.Cn [.Idx s, .Idx s])] = pb_ret (i,i))
// };
//
// .lam .extern internal_diff_matrix_prod![m: .Nat, k: .Nat, l: .Nat, [p: .Nat, e:.Nat]] ->
// (.Cn[
// [.Idx s, .Idx s],
// .Cn[
// .Idx s,
// .Cn[.Idx s, .Cn[.Idx s, .Idx s]]
// ]
// ]) =
//
// .lam .extern internal_diff_matrix_prod!s:.Nat -> .Nat -> (.Cn[[.Idx s, .Idx s], .Cn[.Idx s, .Cn[.Idx s, .Cn[.Idx s, .Idx s]]]]) =
// .lm m: .Nat -> (.Cn[[.Idx s, .Idx s], .Cn[.Idx s, .Cn[.Idx s, .Cn[.Idx s, .Idx s]]]]) =
// .cn !((a b:.Idx s), ret:.Cn[.Idx s, .Cn[.Idx s, .Cn[.Idx s, .Idx s]]]) = {
// .let result = %core.wrap.add s m (a,b);
Expand Down
23 changes: 23 additions & 0 deletions dialects/autodiff/auxiliary/autodiff_aux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ const Pi* autodiff_type_fun_pi(const Pi* pi) {
// Performs the type transformation `A` => `A'`.
// This is of special importance for functions: `P->Q` => `P'->Q' * (Q* -> P*)`.
const Def* autodiff_type_fun(const Def* ty) {
// TODO: handle dependencies using memoization

auto& world = ty->world();
// TODO: handle DS (operators)
if (auto pi = ty->isa<Pi>()) { return autodiff_type_fun_pi(pi); }
Expand Down Expand Up @@ -163,8 +165,29 @@ const Def* autodiff_type_fun(const Def* ty) {
return ty;
}

if (auto app = ty->isa<App>()) {
// axiom args
auto callee = app->callee();
auto arg = app->arg();
auto callee_ad = autodiff_type_fun(callee);
if (!callee_ad) return nullptr;
auto arg_ad = autodiff_type_fun(arg);
if (!arg_ad) return nullptr;
return world.app(callee_ad, arg_ad);
}
if (auto axiom = ty->isa<Axiom>()) { return ty; }
if (auto sig = ty->isa<Tuple>()) {
// Type argument
DefArray ops(sig->ops(), [&](const Def* op) { return autodiff_type_fun(op); });
return world.tuple(ops);
}
// TODO: extract
if (auto lit = ty->isa<Lit>()) { return ty; }
if (auto nat = ty->isa<Nat>()) { return ty; }

world.WLOG("no-diff type: {}", ty);
return nullptr;
// return ty;
}

const Def* zero_def(const Def* T) {
Expand Down
3 changes: 2 additions & 1 deletion dialects/autodiff/normalizers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ const Def* normalize_add(const Def* type, const Def* callee, const Def* arg, con
return real_add;
} else if (auto app = T->isa<App>()) {
auto callee = app->callee();
assert(0 && "not handled");
// assert(0 && "not handled");
world.ELOG("not handled: add app {} {} {}", T, a, b);
}
// TODO: mem stays here (only resolved after direct simplification)

Expand Down
2 changes: 1 addition & 1 deletion dialects/matrix/passes/lower_matrix_mediumlevel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ std::pair<Lam*, const Def*> counting_for(const Def* bound, Defs acc, const Def*
auto body = world.nom_lam(world.cn({
world.type_int(32), // iterator
acc_ty, // acc = memory+extra
world.cn({acc_ty}) // exit = return
world.cn(acc_ty) // exit = return
}),
world.dbg(name));
auto for_loop = affine::op_for(world, world.lit_int(32, 0), bound, world.lit_int(32, 1), acc, body, exit);
Expand Down

0 comments on commit ee33da0

Please sign in to comment.