Skip to content

Commit

Permalink
Added public functions to allow working with uncanonicalized Ids
Browse files Browse the repository at this point in the history
  • Loading branch information
dewert99 committed Nov 26, 2023
1 parent 9043f3b commit 44a2529
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 66 deletions.
99 changes: 72 additions & 27 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,15 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
}
}

/// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep
pub fn id_to_node(&self, id: Id) -> &L {
if let Some(explain) = &self.explain {
explain.node(id)
} else {
panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get unique expressions per id");
}
}

/// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term.
/// When an eclass listed in the given substitutions is found, it creates a variable.
/// It also adds this variable and the corresponding Id value to the resulting [`Subst`]
Expand Down Expand Up @@ -404,12 +413,20 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
left_expr: &RecExpr<L>,
right_expr: &RecExpr<L>,
) -> Explanation<L> {
let left = self.add_expr_internal(left_expr);
let right = self.add_expr_internal(right_expr);
let left = self.add_expr_uncanonical(left_expr);
let right = self.add_expr_uncanonical(right_expr);

self.explain_id_equivalence(left, right)
}

/// Equivalent to calling [`explain_equivalence`](EGraph::explain_equivalence)`(`[`id_to_expr`](EGraph::id_to_expr)`(left),`
/// [`id_to_expr`](EGraph::id_to_expr)`(right))` but more efficient
fn explain_id_equivalence(&mut self, left: Id, right: Id) -> Explanation<L> {
if self.find(left) != self.find(right) {
panic!(
"Tried to explain equivalence between non-equal terms {:?} and {:?}",
left_expr, right_expr
self.id_to_expr(left),
self.id_to_expr(left)
);
}
if let Some(explain) = &mut self.explain {
Expand All @@ -428,7 +445,13 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Note that this function can be called again to explain any intermediate terms
/// used in the output [`Explanation`].
pub fn explain_existance(&mut self, expr: &RecExpr<L>) -> Explanation<L> {
let id = self.add_expr_internal(expr);
let id = self.add_expr_uncanonical(expr);
self.explain_existance_id(id)
}

/// Equivalent to calling [`explain_existance`](EGraph::explain_existance)`(`[`id_to_expr`](EGraph::id_to_expr)`(id))`
/// but more efficient
fn explain_existance_id(&mut self, id: Id) -> Explanation<L> {
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
} else {
Expand All @@ -442,7 +465,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
pattern: &PatternAst<L>,
subst: &Subst,
) -> Explanation<L> {
let id = self.add_instantiation_internal(pattern, subst);
let id = self.add_instantiation_noncanonical(pattern, subst);
if let Some(explain) = &mut self.explain {
explain.explain_existance(id)
} else {
Expand All @@ -457,8 +480,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
right_pattern: &PatternAst<L>,
subst: &Subst,
) -> Explanation<L> {
let left = self.add_expr_internal(left_expr);
let right = self.add_instantiation_internal(right_pattern, subst);
let left = self.add_expr_uncanonical(left_expr);
let right = self.add_instantiation_noncanonical(right_pattern, subst);

if self.find(left) != self.find(right) {
panic!(
Expand Down Expand Up @@ -549,19 +572,22 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
///
/// [`add_expr`]: EGraph::add_expr()
pub fn add_expr(&mut self, expr: &RecExpr<L>) -> Id {
let id = self.add_expr_internal(expr);
let id = self.add_expr_uncanonical(expr);
self.find(id)
}

/// Adds an expr to the egraph, and returns the uncanonicalized id of the top enode.
fn add_expr_internal(&mut self, expr: &RecExpr<L>) -> Id {
/// Similar to [`add_expr`](EGraph::add_expr) but the `Id` returned may not be canonical
///
/// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling
/// [`id_to_expr`](EGraph::id_to_expr) on this `Id` will correspond to the parameter `enode`
pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
let mut new_node_q = Vec::with_capacity(nodes.len());
for node in nodes {
let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let size_before = self.unionfind.size();
let next_id = self.add_internal(new_node);
let next_id = self.add_uncanonical(new_node);
if self.unionfind.size() > size_before {
new_node_q.push(true);
} else {
Expand All @@ -583,11 +609,16 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Adds a [`Pattern`] and a substitution to the [`EGraph`], returning
/// the eclass of the instantiated pattern.
pub fn add_instantiation(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
let id = self.add_instantiation_internal(pat, subst);
let id = self.add_instantiation_noncanonical(pat, subst);
self.find(id)
}

fn add_instantiation_internal(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
/// Similar to [`add_instantiation`](EGraph::add_instantiation) but the `Id` returned may not be
/// canonical
///
/// Like [`add_uncanonical`](EGraph::add_uncanonical), when explanations are enabled calling
/// [`id_to_expr`](EGraph::id_to_expr) on this `Id` will correspond to the parameter `enode`
fn add_instantiation_noncanonical(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
let nodes = pat.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
let mut new_node_q = Vec::with_capacity(nodes.len());
Expand All @@ -601,7 +632,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
ENodeOrVar::ENode(node) => {
let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let size_before = self.unionfind.size();
let next_id = self.add_internal(new_node);
let next_id = self.add_uncanonical(new_node);
if self.unionfind.size() > size_before {
new_node_q.push(true);
} else {
Expand Down Expand Up @@ -696,12 +727,31 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
///
/// [`add`]: EGraph::add()
pub fn add(&mut self, enode: L) -> Id {
let id = self.add_internal(enode);
let id = self.add_uncanonical(enode);
self.find(id)
}

/// Adds an enode to the egraph and also returns the the enode's id (uncanonicalized).
fn add_internal(&mut self, mut enode: L) -> Id {
/// Similar to [`add`](EGraph::add) but the `Id` returned may not be canonical
///
/// When explanations are enabled calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` will
/// correspond to the parameter `enode`
///
/// # Example
/// ```
/// # use egg::*;
/// let mut egraph: EGraph<SymbolLang, ()> = EGraph::default().with_explanations_enabled();
/// let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
/// let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
/// egraph.union(a, b);
/// egraph.rebuild();
///
/// let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
/// let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
///
/// assert_eq!(egraph.id_to_expr(fa), "(f a)".parse().unwrap());
/// assert_eq!(egraph.id_to_expr(fb), "(f b)".parse().unwrap());
/// ```
pub fn add_uncanonical(&mut self, mut enode: L) -> Id {
let original = enode.clone();
if let Some(existing_id) = self.lookup_internal(&mut enode) {
let id = self.find(existing_id);
Expand Down Expand Up @@ -799,9 +849,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
subst: &Subst,
rule_name: impl Into<Symbol>,
) -> (Id, bool) {
let id1 = self.add_instantiation_internal(from_pat, subst);
let id1 = self.add_instantiation_noncanonical(from_pat, subst);
let size_before = self.unionfind.size();
let id2 = self.add_instantiation_internal(to_pat, subst);
let id2 = self.add_instantiation_noncanonical(to_pat, subst);
let rhs_new = self.unionfind.size() > size_before;

let did_union = self.perform_union(
Expand All @@ -815,12 +865,9 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {

/// Unions two e-classes, using a given reason to justify it.
///
///
/// Unlike `union_instantiations`, this function picks arbitrary representatives
/// from either e-class.
/// When possible, use [`union_instantiations`](EGraph::union_instantiations),
/// since that ensures that the proof rewrites between the terms you are
/// actually proving equivalent.
/// This function picks representatives using [`id_to_expr`](EGraph::id_to_expr) so choosing
/// `Id`s returned by functions like [`add_uncanonical`](EGraph::add_uncanonical) is important
/// to control explanations
pub fn union_trusted(&mut self, from: Id, to: Id, reason: impl Into<Symbol>) -> bool {
self.perform_union(from, to, Some(Justification::Rule(reason.into())), false)
}
Expand All @@ -833,8 +880,6 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
///
/// When explanations are enabled, this function behaves like [`EGraph::union_trusted`],
/// and it lists the call site as the proof reason.
/// You should prefer [`union_instantiations`](EGraph::union_instantiations) when
/// you want the proofs to always be meaningful.
/// See [`explain_equivalence`](Runner::explain_equivalence) for a more detailed
/// explanation of the feature.
#[track_caller]
Expand Down
93 changes: 54 additions & 39 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,9 @@ impl<I: Eq + PartialEq> PartialOrd for HeapState<I> {
}

impl<L: Language> Explain<L> {
pub(crate) fn node(&self, node_id: Id) -> &L {
&self.explainfind[usize::from(node_id)].node
}
fn node_to_explanation(
&self,
node_id: Id,
Expand All @@ -891,7 +894,7 @@ impl<L: Language> Explain<L> {
if let Some(existing) = cache.get(&node_id) {
existing.clone()
} else {
let node = self.explainfind[usize::from(node_id)].node.clone();
let node = self.node(node_id).clone();
let children = node.fold(vec![], |mut sofar, child| {
sofar.push(vec![self.node_to_explanation(child, cache)]);
sofar
Expand All @@ -908,24 +911,20 @@ impl<L: Language> Explain<L> {
self.node_to_recexpr_internal(&mut res, node_id, &mut cache);
res
}

fn node_to_recexpr_internal(
&self,
res: &mut RecExpr<L>,
node_id: Id,
cache: &mut HashMap<Id, Id>,
) {
let new_node = self.explainfind[usize::from(node_id)]
.node
.clone()
.map_children(|child| {
if let Some(existing) = cache.get(&child) {
*existing
} else {
self.node_to_recexpr_internal(res, child, cache);
Id::from(res.as_ref().len() - 1)
}
});
let new_node = self.node(node_id).clone().map_children(|child| {
if let Some(existing) = cache.get(&child) {
*existing
} else {
self.node_to_recexpr_internal(res, child, cache);
Id::from(res.as_ref().len() - 1)
}
});
res.add(new_node);
}

Expand Down Expand Up @@ -954,23 +953,20 @@ impl<L: Language> Explain<L> {
res.add(ENodeOrVar::Var(var));
subst.insert(var, *existing);
} else {
let new_node = self.explainfind[usize::from(node_id)]
.node
.clone()
.map_children(|child| {
if let Some(existing) = cache.get(&child) {
*existing
} else {
self.node_to_pattern_internal(res, child, var_substitutions, subst, cache);
Id::from(res.as_ref().len() - 1)
}
});
let new_node = self.node(node_id).clone().map_children(|child| {
if let Some(existing) = cache.get(&child) {
*existing
} else {
self.node_to_pattern_internal(res, child, var_substitutions, subst, cache);
Id::from(res.as_ref().len() - 1)
}
});
res.add(ENodeOrVar::ENode(new_node));
}
}

fn node_to_flat_explanation(&self, node_id: Id) -> FlatTerm<L> {
let node = self.explainfind[usize::from(node_id)].node.clone();
let node = self.node(node_id).clone();
let children = node.fold(vec![], |mut sofar, child| {
sofar.push(self.node_to_flat_explanation(child));
sofar
Expand Down Expand Up @@ -1123,9 +1119,7 @@ impl<L: Language> Explain<L> {
new_rhs: bool,
) {
if let Justification::Congruence = justification {
assert!(self.explainfind[usize::from(node1)]
.node
.matches(&self.explainfind[usize::from(node2)].node));
assert!(self.node(node1).matches(self.node(node2)));
}
if new_rhs {
self.set_existance_reason(node2, node1)
Expand Down Expand Up @@ -1417,8 +1411,8 @@ impl<L: Language> Explain<L> {
}
Justification::Congruence => {
// add the children proofs to the last explanation
let current_node = &self.explainfind[usize::from(connection.current)].node;
let next_node = &self.explainfind[usize::from(connection.next)].node;
let current_node = self.node(connection.current);
let next_node = self.node(connection.next);
assert!(current_node.matches(next_node));
let mut subproofs = vec![];

Expand Down Expand Up @@ -1558,8 +1552,8 @@ impl<L: Language> Explain<L> {
next: Id,
distance_memo: &mut DistanceMemo,
) -> ProofCost {
let current_node = self.explainfind[usize::from(current)].node.clone();
let next_node = self.explainfind[usize::from(next)].node.clone();
let current_node = self.node(current).clone();
let next_node = self.node(next).clone();
let mut cost: ProofCost = Saturating(0);
for (left_child, right_child) in current_node
.children()
Expand Down Expand Up @@ -1654,8 +1648,8 @@ impl<L: Language> Explain<L> {
// find all congruence nodes
let mut cannon_enodes: HashMap<L, Vec<Id>> = Default::default();
for enode in &enodes {
let cannon = self.explainfind[usize::from(*enode)]
.node
let cannon = self
.node(*enode)
.clone()
.map_children(|child| unionfind.find(child));
if let Some(others) = cannon_enodes.get_mut(&cannon) {
Expand Down Expand Up @@ -1838,8 +1832,8 @@ impl<L: Language> Explain<L> {
std::mem::swap(&mut next, &mut current);
}
if let Justification::Congruence = connection.justification {
let current_node = self.explainfind[usize::from(current)].node.clone();
let next_node = self.explainfind[usize::from(next)].node.clone();
let current_node = self.node(current).clone();
let next_node = self.node(next).clone();
for (left_child, right_child) in current_node
.children()
.iter()
Expand Down Expand Up @@ -1903,11 +1897,11 @@ impl<L: Language> Explain<L> {
for (s_int, others) in congruence_neighbors.iter().enumerate() {
let start = &Id::from(s_int);
for other in others {
for (left, right) in self.explainfind[usize::from(*start)]
.node
for (left, right) in self
.node(*start)
.children()
.iter()
.zip(self.explainfind[usize::from(*other)].node.children().iter())
.zip(self.node(*other).children().iter())
{
if left != right {
if common_ancestor_queries.get(start).is_none() {
Expand Down Expand Up @@ -2095,3 +2089,24 @@ mod tests {
egraph.dot().to_dot("target/foo.dot").unwrap();
}
}

#[test]
fn simple_explain_union_trusted() {
use crate::SymbolLang;
crate::init_logger();
let mut egraph = EGraph::new(()).with_explanations_enabled();

let a = egraph.add_uncanonical(SymbolLang::leaf("a"));
let b = egraph.add_uncanonical(SymbolLang::leaf("b"));
let c = egraph.add_uncanonical(SymbolLang::leaf("c"));
let d = egraph.add_uncanonical(SymbolLang::leaf("d"));
egraph.union_trusted(a, b, "a=b");
egraph.rebuild();
let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a]));
let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b]));
egraph.union_trusted(c, fa, "c=fa");
egraph.union_trusted(d, fb, "d=fb");
egraph.rebuild();
let mut exp = egraph.explain_equivalence(&"c".parse().unwrap(), &"d".parse().unwrap());
assert_eq!(exp.make_flat_explanation().len(), 4)
}

0 comments on commit 44a2529

Please sign in to comment.