From bd21a602f3778499a6264c8d6c6aa2060bfe47b3 Mon Sep 17 00:00:00 2001 From: Michael Sippel Date: Sat, 11 Nov 2023 16:26:30 +0100 Subject: [PATCH] unification --- src/test/unification.rs | 39 ++++++++++++ src/unification.rs | 135 +++++++++++++++++++--------------------- 2 files changed, 103 insertions(+), 71 deletions(-) diff --git a/src/test/unification.rs b/src/test/unification.rs index 40b7d68..34d355d 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -75,5 +75,44 @@ fn test_unification() { ">~~", true ); + + let mut dict = TypeDict::new(); + + dict.add_varname(String::from("T")); + dict.add_varname(String::from("U")); + dict.add_varname(String::from("V")); + dict.add_varname(String::from("W")); + + assert_eq!( + UnificationProblem::new(vec![ + (dict.parse("U").unwrap(), dict.parse("").unwrap()), + (dict.parse("T").unwrap(), dict.parse("").unwrap()), + ]).solve(), + Ok( + vec![ + // T + (TypeID::Var(0), dict.parse(">").unwrap()), + + // U + (TypeID::Var(1), dict.parse("").unwrap()) + ].into_iter().collect() + ) + ); + + assert_eq!( + UnificationProblem::new(vec![ + (dict.parse("").unwrap(), dict.parse(">").unwrap()), + (dict.parse("").unwrap(), dict.parse("").unwrap()), + ]).solve(), + Ok( + vec![ + // W + (TypeID::Var(3), dict.parse("ℕ").unwrap()), + + // T + (TypeID::Var(0), dict.parse("ℕ~").unwrap()) + ].into_iter().collect() + ) + ); } diff --git a/src/unification.rs b/src/unification.rs index 6d7598a..abbc1fe 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -12,95 +12,88 @@ pub struct UnificationError { pub t2: TypeTerm } -impl UnificationError { - pub fn new(t1: &TypeTerm, t2: &TypeTerm) -> Self { - UnificationError { - addr: vec![], - t1: t1.clone(), - t2: t2.clone() - } - } -} -/* -struct UnificationProblem { - eqs: Vec<(TypeTerm, TypeTerm)>, +pub struct UnificationProblem { + eqs: Vec<(TypeTerm, TypeTerm, Vec)>, σ: HashMap } impl UnificationProblem { - pub fn new() -> Self { + pub fn new(eqs: Vec<(TypeTerm, TypeTerm)>) -> Self { UnificationProblem { - eqs: Vec::new(), + eqs: eqs.iter().map(|(lhs,rhs)| (lhs.clone(),rhs.clone(),vec![])).collect(), σ: HashMap::new() } } - pub fn eval_equation(&mut self, lhs: &TypeTerm, rhs: &TypeTerm) -> Option { - match (lhs, rhs) { - + pub fn eval_equation(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec) -> Result<(), UnificationError> { + match (lhs.clone(), rhs.clone()) { + (TypeTerm::TypeID(TypeID::Var(varid)), t) | + (t, TypeTerm::TypeID(TypeID::Var(varid))) => { + self.σ.insert(TypeID::Var(varid), t.clone()); + + // update all values in substitution + let mut new_σ = HashMap::new(); + for (v, tt) in self.σ.iter() { + let mut tt = tt.clone(); + tt.apply_substitution(&|v| self.σ.get(v).cloned()); + new_σ.insert(v.clone(), tt); + } + self.σ = new_σ; + + Ok(()) + } + + (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { + if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + } + (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { + if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + } + (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { + if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + } + + (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | + (TypeTerm::App(a1), TypeTerm::App(a2)) => { + if a1.len() == a2.len() { + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + let mut new_addr = addr.clone(); + new_addr.push(i); + self.eqs.push((x, y, new_addr)); + } + Ok(()) + } else { + Err(UnificationError{ addr, t1: lhs, t2: rhs }) + } + } + + (TypeTerm::Ladder(l1), TypeTerm::Ladder(l2)) => { + Err(UnificationError{ addr, t1: lhs, t2: rhs }) + } + + _ => Err(UnificationError{ addr, t1: lhs, t2: rhs}) } } - pub fn solve(self) -> Result, UnificationError> { - + pub fn solve(mut self) -> Result, UnificationError> { + while self.eqs.len() > 0 { + while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() { + lhs.apply_substitution(&|v| self.σ.get(v).cloned()); + rhs.apply_substitution(&|v| self.σ.get(v).cloned()); + self.eval_equation(lhs, rhs, addr)?; + } + } + + Ok(self.σ) } } -*/ + pub fn unify( t1: &TypeTerm, t2: &TypeTerm ) -> Result, UnificationError> { - let mut σ = HashMap::new(); - - match (t1, t2) { - (TypeTerm::TypeID(TypeID::Var(varid)), t) | - (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - σ.insert(TypeID::Var(*varid), t.clone()); - Ok(σ) - } - - (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { - if a1 == a2 { Ok(σ) } else { Err(UnificationError::new(&t1, &t2)) } - } - (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { - if n1 == n2 { Ok(σ) } else { Err(UnificationError::new(&t1, &t2)) } - } - (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { - if c1 == c2 { Ok(σ) } else { Err(UnificationError::new(&t1, &t2)) } - } - - (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | - (TypeTerm::App(a1), TypeTerm::App(a2)) => { - if a1.len() == a2.len() { - for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { - let (mut x, mut y) = (x.clone(), y.clone()); - x.apply_substitution(&|v| σ.get(v).cloned()); - y.apply_substitution(&|v| σ.get(v).cloned()); - - match unify(&x, &y) { - Ok(τ) => { - for (v,t) in τ { - σ.insert(v,t); - } - } - Err(mut err) => { - err.addr.insert(0, i); - return Err(err); - } - } - } - Ok(σ) - } else { - Err(UnificationError::new(&t1, &t2)) - } - } - - (TypeTerm::Ladder(l1), TypeTerm::Ladder(l2)) => { - Err(UnificationError::new(&t1, &t2)) - } - - _ => Err(UnificationError::new(t1, t2)) - } + let mut unification = UnificationProblem::new(vec![ (t1.clone(), t2.clone()) ]); + unification.solve() } //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\