diff --git a/src/term.rs b/src/term.rs index 29c7d27..240996f 100644 --- a/src/term.rs +++ b/src/term.rs @@ -79,6 +79,22 @@ impl TypeTerm { self.arg(TypeTerm::Char(c)) } + pub fn contains_var(&self, var_id: u64) -> bool { + match self { + TypeTerm::TypeID(TypeID::Var(v)) => (&var_id == v), + TypeTerm::App(args) | + TypeTerm::Ladder(args) => { + for a in args.iter() { + if a.contains_var(var_id) { + return true; + } + } + false + } + _ => false + } + } + /// recursively apply substitution to all subterms, /// which will replace all occurences of variables which map /// some type-term in `subst` diff --git a/src/test/unification.rs b/src/test/unification.rs index 239b384..6c55a80 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -61,6 +61,19 @@ fn test_unification_error() { t2: dict.parse("B").unwrap() }) ); + + assert_eq!( + crate::unify( + &dict.parse("T").unwrap(), + &dict.parse("<Seq T>").unwrap() + ), + + Err(UnificationError { + addr: vec![], + t1: dict.parse("T").unwrap(), + t2: dict.parse("<Seq T>").unwrap() + }) + ); } #[test] @@ -131,12 +144,13 @@ fn test_subtype_unification() { (dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse("<Seq <Digit 10>>").unwrap(), vec![ // T (TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -144,7 +158,8 @@ fn test_subtype_unification() { (dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()), (dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + TypeTerm::unit(), vec![ // T (TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()), @@ -152,7 +167,7 @@ fn test_subtype_unification() { // U (TypeID::Var(1), dict.parse("<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -162,7 +177,10 @@ fn test_subtype_unification() { (dict.parse("<Seq ℕ~<PosInt 10 BigEndian>>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse(" + <Seq~<LengthPrefix x86.UInt64> ℕ~<PosInt 10 BigEndian>> + ").unwrap(), vec![ // W (TypeID::Var(3), dict.parse("ℕ~<PosInt 10 BigEndian>").unwrap()), @@ -170,6 +188,6 @@ fn test_subtype_unification() { // T (TypeID::Var(0), dict.parse("ℕ~<PosInt 10 BigEndian>~<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); } diff --git a/src/unification.rs b/src/unification.rs index 443a9a2..e605af4 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -25,64 +25,85 @@ impl UnificationProblem { } } - pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<(), UnificationError> { + pub fn reapply_subst(&mut self) { + // update all values in substitution + let mut new_σ = HashMap::new(); + for (v, tt) in self.σ.iter() { + let mut tt = tt.clone().normalize(); + tt.apply_substitution(&|v| self.σ.get(v).cloned()); + tt = tt.normalize(); + //eprintln!("update σ : {:?} --> {:?}", v, tt); + new_σ.insert(v.clone(), tt); + } + self.σ = new_σ; + } + + pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<Vec<TypeTerm>, UnificationError> { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - - // update all values in substitution - let mut new_σ = HashMap::new(); - for (v, tt) in self.σ.iter() { - let mut tt = tt.clone().normalize(); - tt.apply_substitution(&|v| self.σ.get(v).cloned()); - eprintln!("update σ : {:?} --> {:?}", v, tt); - new_σ.insert(v.clone(), tt); + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(vec![]) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(vec![]) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) } - self.σ = new_σ; - - Ok(()) } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { - if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if a1 == a2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } } (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { - if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if n1 == n2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } } (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { - if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if c1 == c2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } } (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => { - eprintln!("unification: check two ladders"); + let mut halo = Vec::new(); for i in 0..a1.len() { - if let Some((_, _)) = a1[i].is_semantic_subtype_of( &a2[0] ) { + if let Ok((r_halo, σ)) = subtype_unify( &a1[i], &a2[0] ) { + //eprintln!("unified ladders at {}, r_halo = {:?}", i, r_halo); + for (k,v) in σ.iter() { + self.σ.insert(k.clone(),v.clone()); + } + for j in 0..a2.len() { if i+j < a1.len() { let mut new_addr = addr.clone(); new_addr.push(i+j); - self.eqs.push((a1[i+j].clone(), a2[j].clone(), new_addr)) + self.eqs.push((a1[i+j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(), + a2[j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(), + new_addr)); } } - return Ok(()) + return Ok(halo) + } else { + halo.push(a1[i].clone()); + //eprintln!("could not unify ladders"); } } Err(UnificationError{ addr, t1: lhs, t2: rhs }) }, - (t, TypeTerm::Ladder(a1)) => { - if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) { - Ok(()) + (t, TypeTerm::Ladder(mut a1)) => { + if let Ok(mut halo) = self.eval_subtype( t.clone(), a1.first().unwrap().clone(), addr.clone() ) { + a1.append(&mut halo); + Ok(a1) } else { - Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t }) + Err(UnificationError{ addr, t1: t, t2: TypeTerm::Ladder(a1) }) } } - (TypeTerm::Ladder(a1), t) => { - if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) { - Ok(()) + (TypeTerm::Ladder(mut a1), t) => { + if let Ok(mut halo) = self.eval_subtype( a1.pop().unwrap(), t.clone(), addr.clone() ) { + a1.append(&mut halo); + Ok(a1) } else { Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t }) } @@ -90,12 +111,31 @@ impl UnificationProblem { (TypeTerm::App(a1), TypeTerm::App(a2)) => { if a1.len() == a2.len() { + let mut halo_args = Vec::new(); + let mut halo_required = false; + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { let mut new_addr = addr.clone(); new_addr.push(i); - self.eqs.push((x, y, new_addr)); + //self.eqs.push((x, y, new_addr)); + + if let Ok(halo) = self.eval_subtype( x.clone(), y.clone(), new_addr ) { + if halo.len() == 0 { + halo_args.push(y.get_lnf_vec().first().unwrap().clone()); + } else { + halo_args.push(TypeTerm::Ladder(halo)); + halo_required = true; + } + } else { + return Err(UnificationError{ addr, t1: x, t2: y }) + } + } + + if halo_required { + Ok(vec![ TypeTerm::App(halo_args) ]) + } else { + Ok(vec![]) } - Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs }) } @@ -109,18 +149,15 @@ impl UnificationProblem { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - - // update all values in substitution - let mut new_σ = HashMap::new(); - for (v, tt) in self.σ.iter() { - let mut tt = tt.clone(); - tt.apply_substitution(&|v| self.σ.get(v).cloned()); - new_σ.insert(v.clone(), tt); + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(()) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(()) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) } - self.σ = new_σ; - - Ok(()) } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { @@ -136,7 +173,7 @@ impl UnificationProblem { (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | (TypeTerm::App(a1), TypeTerm::App(a2)) => { if a1.len() == a2.len() { - for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() { let mut new_addr = addr.clone(); new_addr.push(i); self.eqs.push((x, y, new_addr)); @@ -160,15 +197,63 @@ impl UnificationProblem { Ok(self.σ) } + pub fn solve_subtype(mut self) -> Result<(TypeTerm, HashMap<TypeID, TypeTerm>), UnificationError> { - pub fn solve_subtype(mut self) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> { - while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() { + pub fn insert_halo_at( + t: &mut TypeTerm, + mut addr: Vec<usize>, + halo_type: TypeTerm + ) { + match t { + TypeTerm::Ladder(rungs) => { + if let Some(idx) = addr.pop() { + for i in rungs.len()..idx+1 { + rungs.push(TypeTerm::unit()); + } + insert_halo_at( &mut rungs[idx], addr, halo_type ); + } else { + rungs.push(halo_type); + } + }, + + TypeTerm::App(args) => { + if let Some(idx) = addr.pop() { + insert_halo_at( &mut args[idx], addr, halo_type ); + } else { + *t = TypeTerm::Ladder(vec![ + halo_type, + t.clone() + ]); + } + } + + atomic => { + + } + } + } + + //let mut halo_type = TypeTerm::unit(); + let mut halo_rungs = Vec::new(); + + while let Some( (mut lhs, mut rhs, mut addr) ) = self.eqs.pop() { lhs.apply_substitution(&|v| self.σ.get(v).cloned()); rhs.apply_substitution(&|v| self.σ.get(v).cloned()); - eprintln!("eval subtype LHS={:?} || RHS={:?}", lhs, rhs); - self.eval_subtype(lhs, rhs, addr)?; + //eprintln!("eval subtype\n\tLHS={:?}\n\tRHS={:?}\n", lhs, rhs); + let mut new_halo = self.eval_subtype(lhs, rhs, addr.clone())?; + if new_halo.len() > 0 { + //insert_halo_at( &mut halo_type, addr, TypeTerm::Ladder(new_halo) ); + if addr.len() == 0 { + halo_rungs.push(TypeTerm::Ladder(new_halo)) + } + } } - Ok(self.σ) + + let mut halo_type = TypeTerm::Ladder(halo_rungs); + halo_type = halo_type.normalize(); + halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); + + Ok((halo_type.param_normalize(), self.σ)) } } @@ -180,4 +265,12 @@ pub fn unify( unification.solve() } +pub fn subtype_unify( + t1: &TypeTerm, + t2: &TypeTerm +) -> Result<(TypeTerm, HashMap<TypeID, TypeTerm>), UnificationError> { + let mut unification = UnificationProblem::new(vec![ (t1.clone(), t2.clone()) ]); + unification.solve_subtype() +} + //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\