diff --git a/src/lib.rs b/src/lib.rs index 20d0515..6c844a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,8 @@ pub mod bimap; pub mod dict; pub mod term; +pub mod substitution; + pub mod lexer; pub mod parser; pub mod unparser; @@ -25,6 +27,7 @@ mod pretty; pub use { dict::*, term::*, + substitution::*, sugar::*, unification::*, morphism::* diff --git a/src/morphism.rs b/src/morphism.rs index 3d47a5f..69853f3 100644 --- a/src/morphism.rs +++ b/src/morphism.rs @@ -57,13 +57,13 @@ impl<M: Morphism + Clone> MorphismInstance<M> { src_type: TypeTerm::Ladder(vec![ self.halo.clone(), self.m.get_type().src_type.clone() - ]).apply_substitution(&|k| self.σ.get(k).cloned()) + ]).apply_subst(&self.σ) .clone(), dst_type: TypeTerm::Ladder(vec![ self.halo.clone(), self.m.get_type().dst_type.clone() - ]).apply_substitution(&|k| self.σ.get(k).cloned()) + ]).apply_subst(&self.σ) .clone(), /* trait_bounds: Vec::new(), diff --git a/src/morphism_path.rs b/src/morphism_path.rs index 817d58a..849603d 100644 --- a/src/morphism_path.rs +++ b/src/morphism_path.rs @@ -57,7 +57,7 @@ impl<'a, M:Morphism+Clone> ShortestPathProblem<'a, M> { || dst_type.contains_var(*varid) { new_σ.insert( k.clone(), - v.clone().apply_substitution(&|k| σ.get(k).cloned()).clone().strip() + v.clone().apply_subst(&σ).clone().strip() ); } } @@ -66,18 +66,12 @@ impl<'a, M:Morphism+Clone> ShortestPathProblem<'a, M> { if let TypeID::Var(varid) = k { if src_type.contains_var(*varid) || dst_type.contains_var(*varid) { - new_σ.insert( - k.clone(), - v.clone().apply_substitution(&|k| σ.get(k).cloned()).clone().strip() - ); + new_σ.insert( k.clone(), v.clone().apply_subst(&σ).clone().strip() ); } } } - n.halo = n.halo.clone().apply_substitution( - &|k| σ.get(k).cloned() - ).clone().strip().param_normalize(); - + n.halo = n.halo.clone().apply_subst(&σ).clone().strip().param_normalize(); n.σ = new_σ; } @@ -98,20 +92,20 @@ impl<'a, M:Morphism+Clone> ShortestPathProblem<'a, M> { for (k,v) in next_morph_inst.σ.iter() { new_σ.insert( k.clone(), - v.clone().apply_substitution(&|k| next_morph_inst.σ.get(k).cloned()).clone() + v.clone().apply_subst(&next_morph_inst.σ).clone() ); } for (k,v) in n.σ.iter() { new_σ.insert( k.clone(), - v.clone().apply_substitution(&|k| next_morph_inst.σ.get(k).cloned()).clone() + v.clone().apply_subst(&next_morph_inst.σ).clone() ); } - n.halo = n.halo.clone().apply_substitution( - &|k| next_morph_inst.σ.get(k).cloned() - ).clone().strip().param_normalize(); + n.halo = n.halo.clone() + .apply_subst( &next_morph_inst.σ ).clone() + .strip().param_normalize(); n.σ = new_σ; } diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs index 714956d..3df0aca 100644 --- a/src/steiner_tree.rs +++ b/src/steiner_tree.rs @@ -38,8 +38,8 @@ impl SteinerTree { // goal reached. for e in self.edges.iter_mut() { - e.src_type = e.src_type.apply_substitution(&|x| σ.get(x).cloned()).clone(); - e.dst_type = e.dst_type.apply_substitution(&|x| σ.get(x).cloned()).clone(); + e.src_type = e.src_type.apply_subst(&σ).clone(); + e.dst_type = e.dst_type.apply_subst(&σ).clone(); } added = true; } else { diff --git a/src/substitution.rs b/src/substitution.rs new file mode 100644 index 0000000..b0f70ff --- /dev/null +++ b/src/substitution.rs @@ -0,0 +1,62 @@ + +use crate::{ + TypeID, + TypeTerm +}; + +//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ + +pub trait Substitution { + fn get(&self, t: &TypeID) -> Option< TypeTerm >; +} + +impl<S: Fn(&TypeID)->Option<TypeTerm>> Substitution for S { + fn get(&self, t: &TypeID) -> Option< TypeTerm > { + (self)(t) + } +} + +impl Substitution for std::collections::HashMap< TypeID, TypeTerm > { + fn get(&self, t: &TypeID) -> Option< TypeTerm > { + (self as &std::collections::HashMap< TypeID, TypeTerm >).get(t).cloned() + } +} + +impl TypeTerm { + /// recursively apply substitution to all subterms, + /// which will replace all occurences of variables which map + /// some type-term in `subst` + pub fn apply_substitution( + &mut self, + σ: &impl Substitution + ) -> &mut Self { + self.apply_subst(σ) + } + + pub fn apply_subst( + &mut self, + σ: &impl Substitution + ) -> &mut Self { + match self { + TypeTerm::TypeID(typid) => { + if let Some(t) = σ.get(typid) { + *self = t; + } + } + + TypeTerm::Ladder(rungs) => { + for r in rungs.iter_mut() { + r.apply_subst(σ); + } + } + TypeTerm::App(args) => { + for r in args.iter_mut() { + r.apply_subst(σ); + } + } + _ => {} + } + + self + } +} diff --git a/src/term.rs b/src/term.rs index ec58afb..4326d67 100644 --- a/src/term.rs +++ b/src/term.rs @@ -92,35 +92,6 @@ impl TypeTerm { } } - /// recursively apply substitution to all subterms, - /// which will replace all occurences of variables which map - /// some type-term in `subst` - pub fn apply_substitution( - &mut self, - subst: &impl Fn(&TypeID) -> Option<TypeTerm> - ) -> &mut Self { - match self { - TypeTerm::TypeID(typid) => { - if let Some(t) = subst(typid) { - *self = t; - } - } - - TypeTerm::Ladder(rungs) => { - for r in rungs.iter_mut() { - r.apply_substitution(subst); - } - } - TypeTerm::App(args) => { - for r in args.iter_mut() { - r.apply_substitution(subst); - } - } - _ => {} - } - - self - } /* strip away empty ladders * & unwrap singletons diff --git a/src/test/substitution.rs b/src/test/substitution.rs index e8906b9..91aa810 100644 --- a/src/test/substitution.rs +++ b/src/test/substitution.rs @@ -1,7 +1,7 @@ use { - crate::{dict::*, term::*, parser::*, unparser::*}, - std::iter::FromIterator + crate::{dict::*, term::*, parser::*, unparser::*, substitution::*}, + std::iter::FromIterator, }; //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ @@ -24,8 +24,7 @@ fn test_subst() { assert_eq!( - dict.parse("<Seq T~U>").unwrap() - .apply_substitution(&|typid|{ σ.get(typid).cloned() }).clone(), + dict.parse("<Seq T~U>").unwrap().apply_subst(&σ).clone(), dict.parse("<Seq ℕ~<Seq Char>>").unwrap() ); } diff --git a/src/test/unification.rs b/src/test/unification.rs index 6021dda..99ea7ed 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -23,8 +23,8 @@ fn test_unify(ts1: &str, ts2: &str, expect_unificator: bool) { let σ = σ.unwrap(); assert_eq!( - t1.apply_substitution(&|v| σ.get(v).cloned()), - t2.apply_substitution(&|v| σ.get(v).cloned()) + t1.apply_subst(&σ), + t2.apply_subst(&σ) ); } else { assert!(! σ.is_ok()); diff --git a/src/unification.rs b/src/unification.rs index 5072ea4..32d45df 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -79,7 +79,7 @@ impl UnificationProblem { let mut new_σ = HashMap::new(); for (v, tt) in self.σ.iter() { let mut tt = tt.clone().normalize(); - tt.apply_substitution(&|v| self.σ.get(v).cloned()); + tt.apply_subst(&self.σ); tt = tt.normalize(); //eprintln!("update σ : {:?} --> {:?}", v, tt); new_σ.insert(v.clone(), tt); @@ -415,7 +415,7 @@ impl UnificationProblem { Ok(halo) => { if halo == TypeTerm::unit() { let mut y = y.clone(); - y.apply_substitution(&|k| self.σ.get(k).cloned()); + y.apply_subst(&self.σ); y = y.strip(); let mut top = y.get_lnf_vec().first().unwrap().clone(); halo_args.push(top.clone()); @@ -426,7 +426,7 @@ impl UnificationProblem { let x = &mut halo_args[n_halos_required-1]; if let TypeTerm::Ladder(argrs) = x { let mut a = a2[n_halos_required-1].clone(); - a.apply_substitution(&|k| self.σ.get(k).cloned()); + a.apply_subst(&self.σ); a = a.get_lnf_vec().first().unwrap().clone(); argrs.push(a); } else { @@ -435,7 +435,7 @@ impl UnificationProblem { a2[n_halos_required-1].clone().get_lnf_vec().first().unwrap().clone() ]); - x.apply_substitution(&|k| self.σ.get(k).cloned()); + x.apply_subst(&self.σ); } } @@ -465,8 +465,8 @@ impl UnificationProblem { pub fn solve(mut self) -> Result<(Vec<TypeTerm>, HashMap<TypeID, TypeTerm>), UnificationError> { // solve equations while let Some( mut equal_pair ) = self.equal_pairs.pop() { - equal_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()); - equal_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()); + equal_pair.lhs.apply_subst(&self.σ); + equal_pair.rhs.apply_subst(&self.σ); self.eval_equation(equal_pair)?; } @@ -474,8 +474,8 @@ impl UnificationProblem { // solve subtypes // eprintln!("------ SOLVE SUBTYPES ---- "); for mut subtype_pair in self.subtype_pairs.clone().into_iter() { - subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()); - subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()); + subtype_pair.lhs.apply_subst(&self.σ); + subtype_pair.rhs.apply_subst(&self.σ); let _halo = self.eval_subtype( subtype_pair.clone() )?.strip(); } @@ -495,8 +495,8 @@ impl UnificationProblem { // eprintln!("------ MAKE HALOS -----"); let mut halo_types = Vec::new(); for mut subtype_pair in self.subtype_pairs.clone().into_iter() { - subtype_pair.lhs = subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip(); - subtype_pair.rhs = subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip(); + subtype_pair.lhs = subtype_pair.lhs.apply_subst(&self.σ).clone().strip(); + subtype_pair.rhs = subtype_pair.rhs.apply_subst(&self.σ).clone().strip(); let halo = self.eval_subtype( subtype_pair.clone() )?.strip(); halo_types.push(halo);