From f05ef075896b66f25b3997a7ae2b4ea1dc048936 Mon Sep 17 00:00:00 2001 From: Michael Sippel Date: Thu, 13 Feb 2025 12:27:48 +0100 Subject: [PATCH] subtype unification --- src/unification.rs | 149 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 125 insertions(+), 24 deletions(-) diff --git a/src/unification.rs b/src/unification.rs index 443a9a2..1c0d2d9 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -25,7 +25,20 @@ impl UnificationProblem { } } - pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec) -> Result<(), UnificationError> { + pub fn reapply_subst(&mut self) { + // update all values in substitution + let mut new_σ = HashMap::new(); + for (v, tt) in self.σ.iter() { + let mut tt = tt.clone().normalize(); + tt.apply_substitution(&|v| self.σ.get(v).cloned()); + tt = tt.normalize(); + //eprintln!("update σ : {:?} --> {:?}", v, tt); + new_σ.insert(v.clone(), tt); + } + self.σ = new_σ; + } + + pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec) -> Result, UnificationError> { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { @@ -45,44 +58,56 @@ impl UnificationProblem { } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { - if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if a1 == a2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } } (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { - if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if n1 == n2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } } (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { - if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if c1 == c2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } } (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => { - eprintln!("unification: check two ladders"); + let mut halo = Vec::new(); for i in 0..a1.len() { - if let Some((_, _)) = a1[i].is_semantic_subtype_of( &a2[0] ) { + if let Ok((r_halo, σ)) = subtype_unify( &a1[i], &a2[0] ) { + //eprintln!("unified ladders at {}, r_halo = {:?}", i, r_halo); + for (k,v) in σ.iter() { + self.σ.insert(k.clone(),v.clone()); + } + for j in 0..a2.len() { if i+j < a1.len() { let mut new_addr = addr.clone(); new_addr.push(i+j); - self.eqs.push((a1[i+j].clone(), a2[j].clone(), new_addr)) + self.eqs.push((a1[i+j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(), + a2[j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(), + new_addr)); } } - return Ok(()) + return Ok(halo) + } else { + halo.push(a1[i].clone()); + //eprintln!("could not unify ladders"); } } Err(UnificationError{ addr, t1: lhs, t2: rhs }) }, - (t, TypeTerm::Ladder(a1)) => { - if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) { - Ok(()) + (t, TypeTerm::Ladder(mut a1)) => { + if let Ok(mut halo) = self.eval_subtype( t.clone(), a1.first().unwrap().clone(), addr.clone() ) { + a1.append(&mut halo); + Ok(a1) } else { - Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t }) + Err(UnificationError{ addr, t1: t, t2: TypeTerm::Ladder(a1) }) } } - (TypeTerm::Ladder(a1), t) => { - if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) { - Ok(()) + (TypeTerm::Ladder(mut a1), t) => { + if let Ok(mut halo) = self.eval_subtype( a1.pop().unwrap(), t.clone(), addr.clone() ) { + a1.append(&mut halo); + Ok(a1) } else { Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t }) } @@ -90,12 +115,31 @@ impl UnificationProblem { (TypeTerm::App(a1), TypeTerm::App(a2)) => { if a1.len() == a2.len() { + let mut halo_args = Vec::new(); + let mut halo_required = false; + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { let mut new_addr = addr.clone(); new_addr.push(i); - self.eqs.push((x, y, new_addr)); + //self.eqs.push((x, y, new_addr)); + + if let Ok(halo) = self.eval_subtype( x.clone(), y.clone(), new_addr ) { + if halo.len() == 0 { + halo_args.push(y.get_lnf_vec().first().unwrap().clone()); + } else { + halo_args.push(TypeTerm::Ladder(halo)); + halo_required = true; + } + } else { + return Err(UnificationError{ addr, t1: x, t2: y }) + } + } + + if halo_required { + Ok(vec![ TypeTerm::App(halo_args) ]) + } else { + Ok(vec![]) } - Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs }) } @@ -118,7 +162,9 @@ impl UnificationProblem { tt.apply_substitution(&|v| self.σ.get(v).cloned()); new_σ.insert(v.clone(), tt); } - self.σ = new_σ; + + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); Ok(()) } @@ -160,16 +206,63 @@ impl UnificationProblem { Ok(self.σ) } + pub fn solve_subtype(mut self) -> Result<(TypeTerm, HashMap), UnificationError> { - pub fn solve_subtype(mut self) -> Result, UnificationError> { - while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() { + pub fn insert_halo_at( + t: &mut TypeTerm, + mut addr: Vec, + halo_type: TypeTerm + ) { + match t { + TypeTerm::Ladder(rungs) => { + if let Some(idx) = addr.pop() { + for i in rungs.len()..idx+1 { + rungs.push(TypeTerm::unit()); + } + insert_halo_at( &mut rungs[idx], addr, halo_type ); + } else { + rungs.push(halo_type); + } + }, + + TypeTerm::App(args) => { + if let Some(idx) = addr.pop() { + insert_halo_at( &mut args[idx], addr, halo_type ); + } else { + *t = TypeTerm::Ladder(vec![ + halo_type, + t.clone() + ]); + } + } + + atomic => { + + } + } + } + + //let mut halo_type = TypeTerm::unit(); + let mut halo_rungs = Vec::new(); + + while let Some( (mut lhs, mut rhs, mut addr) ) = self.eqs.pop() { lhs.apply_substitution(&|v| self.σ.get(v).cloned()); rhs.apply_substitution(&|v| self.σ.get(v).cloned()); - eprintln!("eval subtype LHS={:?} || RHS={:?}", lhs, rhs); - self.eval_subtype(lhs, rhs, addr)?; + //eprintln!("eval subtype\n\tLHS={:?}\n\tRHS={:?}\n", lhs, rhs); + let mut new_halo = self.eval_subtype(lhs, rhs, addr.clone())?; + if new_halo.len() > 0 { + //insert_halo_at( &mut halo_type, addr, TypeTerm::Ladder(new_halo) ); + if addr.len() == 0 { + halo_rungs.push(TypeTerm::Ladder(new_halo)) + } + } } - Ok(self.σ) - } + + let mut halo_type = TypeTerm::Ladder(halo_rungs); + halo_type = halo_type.normalize(); + halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); + + Ok((halo_type.param_normalize(), self.σ)) } pub fn unify( @@ -180,4 +273,12 @@ pub fn unify( unification.solve() } +pub fn subtype_unify( + t1: &TypeTerm, + t2: &TypeTerm +) -> Result<(TypeTerm, HashMap), UnificationError> { + let mut unification = UnificationProblem::new(vec![ (t1.clone(), t2.clone()) ]); + unification.solve_subtype() +} + //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\