Skip to content

Commit

Permalink
type check lambdas when last op (body) is set
Browse files Browse the repository at this point in the history
* fixed bug in alpha-equiv - relaged to #174?
* fixed bug in lower_matrix_mediumlevel.cpp (wrong filter type)
* still some problems when trying to infer return type of lam
  • Loading branch information
leissa committed May 13, 2023
1 parent 3e67917 commit 01fb7d2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 12 deletions.
5 changes: 3 additions & 2 deletions dialects/matrix/passes/lower_matrix_mediumlevel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) {
iterator[idx] = world.call<core::bitcast>(world.type_idx(dim_nat_def), iter);
auto [new_mem, new_mat] = new_acc->projs<2>();
acc = {new_mem, new_mat};
current_mut->set(dim_nat_def, for_call);
current_mut->set(false, for_call); // TODO correct filter?
current_mut = body;
}

Expand Down Expand Up @@ -292,6 +292,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) {
acc = {current_mem, element_acc};
cont = write_back;

// TODO this is copy&paste code from above
for (auto idx : in_indices) {
char for_name[32];
sprintf(for_name, "forIn_%lu", idx);
Expand All @@ -306,7 +307,7 @@ Ref LowerMatrixMediumLevel::rewrite_(Ref def) {
iterator[idx] = world.call<core::bitcast>(world.type_idx(dim_nat_def), iter);
auto [new_mem, new_element] = new_acc->projs<2>();
acc = {new_mem, new_element};
current_mut->set(dim_nat_def, for_call);
current_mut->set(false, for_call); // TODO
current_mut = body;
}

Expand Down
35 changes: 26 additions & 9 deletions thorin/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,19 @@ bool Checker::equiv_internal(Ref d1, Ref d2) {
if (!equiv(d1->type(), d2->type())) return false;
if (d1->isa<Top>() || d2->isa<Top>()) return equiv(d1->type(), d2->type());

struct Pop {
~Pop() {
if (vars) vars->pop_back();
}

Vars* vars = nullptr;
} pop;

if (auto n1 = d1->isa_mut()) {
if (auto n2 = d2->isa_mut()) vars_.emplace_back(n1, n2);
if (auto n2 = d2->isa_mut()) {
vars_.emplace_back(n1, n2);
pop.vars = &vars_; // make sure vars_ is popped again
}
}

if (d1->isa<Sigma, Arr>()) {
Expand All @@ -109,10 +120,18 @@ bool Checker::equiv_internal(Ref d1, Ref d2) {

if (d1->node() != d2->node() || d1->flags() != d2->flags() || d1->num_ops() != d2->num_ops()) return false;

if (auto var = d1->isa<Var>()) { // vars are equal if they appeared under the same binder
for (auto [n1, n2] : vars_)
if (var->mut() == n1) return d2->as<Var>()->mut() == n2;
// TODO what if Var is free?
if (auto var1 = d1->isa<Var>()) { // vars are equal if they appeared under the same binder
auto var2 = d2->as<Var>();
bool bound1 = false, bound2 = false;
for (auto [n1, n2] : vars_) {
if (var1->mut() == n1) {
bound1 = true;
return d2->as<Var>()->mut() == n2;
}
assert(var1->mut() != n2);
if (var2->mut() == n1 || var2->mut() == n2) bound2 = true;
}
if (!bound1 && !bound2) return true; // both var1 and var2 are free
return false;
}

Expand Down Expand Up @@ -185,19 +204,17 @@ void Sigma::check() {

void Lam::check() {
auto& w = world();
return; // TODO
if (!w.checker().equiv(filter()->type(), w.type_bool()))
error(filter(), "filter '{}' of lambda is of type '{}' but must be of type '.Bool'", filter(),
filter()->type());
if (!w.checker().equiv(body()->type(), codom()))
error(body(), "body '{}' of lambda is of type '{}' but its codomain is of type '{}'", body(), body()->type(),
codom());
error(body(), "body '{}' of lambda is of type \n'{}' but its codomain is of type \n'{}'", body(),
body()->type(), codom());
}

void Pi::check() {
auto& w = world();
auto t = infer(dom(), codom());

if (!w.checker().equiv(t, type()))
error(type(), "declared sort '{}' of function type does not match inferred one '{}'", type(), t);
}
Expand Down
3 changes: 2 additions & 1 deletion thorin/check.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class Checker {

World* world_;
DefDefMap<Equiv> equiv_;
std::deque<std::pair<Def*, Def*>> vars_;
using Vars = std::deque<std::pair<Def*, Def*>>;
Vars vars_;
};

} // namespace thorin

0 comments on commit 01fb7d2

Please sign in to comment.