diff --git a/src/term.rs b/src/term.rs index 29c7d27..240996f 100644 --- a/src/term.rs +++ b/src/term.rs @@ -79,6 +79,22 @@ impl TypeTerm { self.arg(TypeTerm::Char(c)) } + pub fn contains_var(&self, var_id: u64) -> bool { + match self { + TypeTerm::TypeID(TypeID::Var(v)) => (&var_id == v), + TypeTerm::App(args) | + TypeTerm::Ladder(args) => { + for a in args.iter() { + if a.contains_var(var_id) { + return true; + } + } + false + } + _ => false + } + } + /// recursively apply substitution to all subterms, /// which will replace all occurences of variables which map /// some type-term in `subst` diff --git a/src/test/unification.rs b/src/test/unification.rs index 239b384..6c55a80 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -61,6 +61,19 @@ fn test_unification_error() { t2: dict.parse("B").unwrap() }) ); + + assert_eq!( + crate::unify( + &dict.parse("T").unwrap(), + &dict.parse("<Seq T>").unwrap() + ), + + Err(UnificationError { + addr: vec![], + t1: dict.parse("T").unwrap(), + t2: dict.parse("<Seq T>").unwrap() + }) + ); } #[test] @@ -131,12 +144,13 @@ fn test_subtype_unification() { (dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse("<Seq <Digit 10>>").unwrap(), vec![ // T (TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -144,7 +158,8 @@ fn test_subtype_unification() { (dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()), (dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + TypeTerm::unit(), vec![ // T (TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()), @@ -152,7 +167,7 @@ fn test_subtype_unification() { // U (TypeID::Var(1), dict.parse("<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -162,7 +177,10 @@ fn test_subtype_unification() { (dict.parse("<Seq ℕ~<PosInt 10 BigEndian>>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse(" + <Seq~<LengthPrefix x86.UInt64> ℕ~<PosInt 10 BigEndian>> + ").unwrap(), vec![ // W (TypeID::Var(3), dict.parse("ℕ~<PosInt 10 BigEndian>").unwrap()), @@ -170,6 +188,6 @@ fn test_subtype_unification() { // T (TypeID::Var(0), dict.parse("ℕ~<PosInt 10 BigEndian>~<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); } diff --git a/src/unification.rs b/src/unification.rs index 443a9a2..5f34815 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -29,19 +29,15 @@ impl UnificationProblem { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - - // update all values in substitution - let mut new_σ = HashMap::new(); - for (v, tt) in self.σ.iter() { - let mut tt = tt.clone().normalize(); - tt.apply_substitution(&|v| self.σ.get(v).cloned()); - eprintln!("update σ : {:?} --> {:?}", v, tt); - new_σ.insert(v.clone(), tt); + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(vec![]) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(vec![]) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) } - self.σ = new_σ; - - Ok(()) } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { @@ -109,18 +105,15 @@ impl UnificationProblem { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - - // update all values in substitution - let mut new_σ = HashMap::new(); - for (v, tt) in self.σ.iter() { - let mut tt = tt.clone(); - tt.apply_substitution(&|v| self.σ.get(v).cloned()); - new_σ.insert(v.clone(), tt); + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(()) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(()) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) } - self.σ = new_σ; - - Ok(()) } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { @@ -136,7 +129,7 @@ impl UnificationProblem { (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | (TypeTerm::App(a1), TypeTerm::App(a2)) => { if a1.len() == a2.len() { - for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() { let mut new_addr = addr.clone(); new_addr.push(i); self.eqs.push((x, y, new_addr)); @@ -168,7 +161,11 @@ impl UnificationProblem { eprintln!("eval subtype LHS={:?} || RHS={:?}", lhs, rhs); self.eval_subtype(lhs, rhs, addr)?; } - Ok(self.σ) + + let mut halo_type = TypeTerm::Ladder(halo_rungs); + halo_type = halo_type.normalize(); + halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); + Ok((halo_type.param_normalize(), self.σ)) } }