From 4c1db8756593ebcff2d93d491cab285baabea972 Mon Sep 17 00:00:00 2001 From: Michael Sippel <micha@fragmental.art> Date: Sat, 15 Feb 2025 17:21:12 +0100 Subject: [PATCH] unification: reject non-identity loops & add test cases --- src/term.rs | 16 ++++++++++++++++ src/test/unification.rs | 31 ++++++++++++++++++++++++------- src/unification.rs | 29 +++++++++++++++++++++-------- 3 files changed, 61 insertions(+), 15 deletions(-) diff --git a/src/term.rs b/src/term.rs index 2879ced..c93160b 100644 --- a/src/term.rs +++ b/src/term.rs @@ -76,6 +76,22 @@ impl TypeTerm { self.arg(TypeTerm::Char(c)) } + pub fn contains_var(&self, var_id: u64) -> bool { + match self { + TypeTerm::TypeID(TypeID::Var(v)) => (&var_id == v), + TypeTerm::App(args) | + TypeTerm::Ladder(args) => { + for a in args.iter() { + if a.contains_var(var_id) { + return true; + } + } + false + } + _ => false + } + } + /// recursively apply substitution to all subterms, /// which will replace all occurences of variables which map /// some type-term in `subst` diff --git a/src/test/unification.rs b/src/test/unification.rs index 8aaee3f..d2a68a2 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -61,6 +61,19 @@ fn test_unification_error() { t2: dict.parse("B").unwrap() }) ); + + assert_eq!( + crate::unify( + &dict.parse("T").unwrap(), + &dict.parse("<Seq T>").unwrap() + ), + + Err(UnificationError { + addr: vec![], + t1: dict.parse("T").unwrap(), + t2: dict.parse("<Seq T>").unwrap() + }) + ); } #[test] @@ -119,7 +132,6 @@ fn test_unification() { #[test] fn test_subtype_unification() { let mut dict = BimapTypeDict::new(); - dict.add_varname(String::from("T")); dict.add_varname(String::from("U")); dict.add_varname(String::from("V")); @@ -130,12 +142,13 @@ fn test_subtype_unification() { (dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse("<Seq <Digit 10>>").unwrap(), vec![ // T (TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -143,7 +156,8 @@ fn test_subtype_unification() { (dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()), (dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + TypeTerm::unit(), vec![ // T (TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()), @@ -151,7 +165,7 @@ fn test_subtype_unification() { // U (TypeID::Var(1), dict.parse("<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -161,7 +175,10 @@ fn test_subtype_unification() { (dict.parse("<Seq ℕ~<PosInt 10 BigEndian>>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse(" + <Seq~<LengthPrefix x86.UInt64> ℕ~<PosInt 10 BigEndian>> + ").unwrap(), vec![ // W (TypeID::Var(3), dict.parse("ℕ~<PosInt 10 BigEndian>").unwrap()), @@ -169,6 +186,6 @@ fn test_subtype_unification() { // T (TypeID::Var(0), dict.parse("ℕ~<PosInt 10 BigEndian>~<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); } diff --git a/src/unification.rs b/src/unification.rs index 82d7a37..fd4800d 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -42,9 +42,16 @@ impl UnificationProblem { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - self.reapply_subst(); - Ok(vec![]) + + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(vec![]) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(vec![]) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) + } } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { @@ -143,9 +150,15 @@ impl UnificationProblem { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - self.reapply_subst(); - Ok(()) + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(()) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(()) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) + } } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { @@ -161,7 +174,7 @@ impl UnificationProblem { (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | (TypeTerm::App(a1), TypeTerm::App(a2)) => { if a1.len() == a2.len() { - for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() { let mut new_addr = addr.clone(); new_addr.push(i); self.eqs.push((x, y, new_addr)); @@ -240,7 +253,7 @@ impl UnificationProblem { let mut halo_type = TypeTerm::Ladder(halo_rungs); halo_type = halo_type.normalize(); halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); - Ok((halo_type, self.σ)) + Ok((halo_type.param_normalize(), self.σ)) } }