From e17a1a9462f8757fe4c5e4127becaf997a078e28 Mon Sep 17 00:00:00 2001 From: Michael Sippel Date: Sun, 9 Feb 2025 16:58:58 +0100 Subject: [PATCH] add subtype unification --- src/test/unification.rs | 57 +++++++++++++++++++++++ src/unification.rs | 101 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/src/test/unification.rs b/src/test/unification.rs index 34d355d..239b384 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -116,3 +116,60 @@ fn test_unification() { ); } + +#[test] +fn test_subtype_unification() { + let mut dict = TypeDict::new(); + + dict.add_varname(String::from("T")); + dict.add_varname(String::from("U")); + dict.add_varname(String::from("V")); + dict.add_varname(String::from("W")); + + assert_eq!( + UnificationProblem::new(vec![ + (dict.parse(" ~ Char>").unwrap(), + dict.parse(" Char ~ Ascii>").unwrap()), + ]).solve_subtype(), + Ok( + vec![ + // T + (TypeID::Var(0), dict.parse("").unwrap()) + ].into_iter().collect() + ) + ); + + assert_eq!( + UnificationProblem::new(vec![ + (dict.parse("U").unwrap(), dict.parse("").unwrap()), + (dict.parse("T").unwrap(), dict.parse("").unwrap()), + ]).solve_subtype(), + Ok( + vec![ + // T + (TypeID::Var(0), dict.parse(">").unwrap()), + + // U + (TypeID::Var(1), dict.parse("").unwrap()) + ].into_iter().collect() + ) + ); + + assert_eq!( + UnificationProblem::new(vec![ + (dict.parse("").unwrap(), + dict.parse(">").unwrap()), + (dict.parse(">").unwrap(), + dict.parse(" W>").unwrap()), + ]).solve_subtype(), + Ok( + vec![ + // W + (TypeID::Var(3), dict.parse("ℕ~").unwrap()), + + // T + (TypeID::Var(0), dict.parse("ℕ~~").unwrap()) + ].into_iter().collect() + ) + ); +} diff --git a/src/unification.rs b/src/unification.rs index ac7ec19..443a9a2 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -25,6 +25,86 @@ impl UnificationProblem { } } + pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec) -> Result<(), UnificationError> { + match (lhs.clone(), rhs.clone()) { + (TypeTerm::TypeID(TypeID::Var(varid)), t) | + (t, TypeTerm::TypeID(TypeID::Var(varid))) => { + self.σ.insert(TypeID::Var(varid), t.clone()); + + // update all values in substitution + let mut new_σ = HashMap::new(); + for (v, tt) in self.σ.iter() { + let mut tt = tt.clone().normalize(); + tt.apply_substitution(&|v| self.σ.get(v).cloned()); + eprintln!("update σ : {:?} --> {:?}", v, tt); + new_σ.insert(v.clone(), tt); + } + self.σ = new_σ; + + Ok(()) + } + + (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { + if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + } + (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { + if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + } + (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { + if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + } + + (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => { + eprintln!("unification: check two ladders"); + for i in 0..a1.len() { + if let Some((_, _)) = a1[i].is_semantic_subtype_of( &a2[0] ) { + for j in 0..a2.len() { + if i+j < a1.len() { + let mut new_addr = addr.clone(); + new_addr.push(i+j); + self.eqs.push((a1[i+j].clone(), a2[j].clone(), new_addr)) + } + } + return Ok(()) + } + } + + Err(UnificationError{ addr, t1: lhs, t2: rhs }) + }, + + (t, TypeTerm::Ladder(a1)) => { + if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) { + Ok(()) + } else { + Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t }) + } + } + + (TypeTerm::Ladder(a1), t) => { + if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) { + Ok(()) + } else { + Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t }) + } + } + + (TypeTerm::App(a1), TypeTerm::App(a2)) => { + if a1.len() == a2.len() { + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + let mut new_addr = addr.clone(); + new_addr.push(i); + self.eqs.push((x, y, new_addr)); + } + Ok(()) + } else { + Err(UnificationError{ addr, t1: lhs, t2: rhs }) + } + } + + _ => Err(UnificationError{ addr, t1: lhs, t2: rhs}) + } + } + pub fn eval_equation(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec) -> Result<(), UnificationError> { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | @@ -72,14 +152,22 @@ impl UnificationProblem { } pub fn solve(mut self) -> Result, UnificationError> { - while self.eqs.len() > 0 { - while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() { - lhs.apply_substitution(&|v| self.σ.get(v).cloned()); - rhs.apply_substitution(&|v| self.σ.get(v).cloned()); - self.eval_equation(lhs, rhs, addr)?; - } + while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() { + lhs.apply_substitution(&|v| self.σ.get(v).cloned()); + rhs.apply_substitution(&|v| self.σ.get(v).cloned()); + self.eval_equation(lhs, rhs, addr)?; } + Ok(self.σ) + } + + pub fn solve_subtype(mut self) -> Result, UnificationError> { + while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() { + lhs.apply_substitution(&|v| self.σ.get(v).cloned()); + rhs.apply_substitution(&|v| self.σ.get(v).cloned()); + eprintln!("eval subtype LHS={:?} || RHS={:?}", lhs, rhs); + self.eval_subtype(lhs, rhs, addr)?; + } Ok(self.σ) } } @@ -93,4 +181,3 @@ pub fn unify( } //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ -