From b502b62479b42806b8076c23881a4744fda42359 Mon Sep 17 00:00:00 2001 From: Michael Sippel Date: Sat, 15 Feb 2025 17:21:12 +0100 Subject: [PATCH] unification: reject non-identity loops & add test cases --- src/term.rs | 16 +++++++++++++++ src/test/unification.rs | 30 ++++++++++++++++++++++------ src/unification.rs | 44 +++++++++++++++++------------------------ 3 files changed, 58 insertions(+), 32 deletions(-) diff --git a/src/term.rs b/src/term.rs index 29c7d27..240996f 100644 --- a/src/term.rs +++ b/src/term.rs @@ -79,6 +79,22 @@ impl TypeTerm { self.arg(TypeTerm::Char(c)) } + pub fn contains_var(&self, var_id: u64) -> bool { + match self { + TypeTerm::TypeID(TypeID::Var(v)) => (&var_id == v), + TypeTerm::App(args) | + TypeTerm::Ladder(args) => { + for a in args.iter() { + if a.contains_var(var_id) { + return true; + } + } + false + } + _ => false + } + } + /// recursively apply substitution to all subterms, /// which will replace all occurences of variables which map /// some type-term in `subst` diff --git a/src/test/unification.rs b/src/test/unification.rs index 239b384..6c55a80 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -61,6 +61,19 @@ fn test_unification_error() { t2: dict.parse("B").unwrap() }) ); + + assert_eq!( + crate::unify( + &dict.parse("T").unwrap(), + &dict.parse("").unwrap() + ), + + Err(UnificationError { + addr: vec![], + t1: dict.parse("T").unwrap(), + t2: dict.parse("").unwrap() + }) + ); } #[test] @@ -131,12 +144,13 @@ fn test_subtype_unification() { (dict.parse(" ~ Char>").unwrap(), dict.parse(" Char ~ Ascii>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse(">").unwrap(), vec![ // T (TypeID::Var(0), dict.parse("").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -144,7 +158,8 @@ fn test_subtype_unification() { (dict.parse("U").unwrap(), dict.parse("").unwrap()), (dict.parse("T").unwrap(), dict.parse("").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + TypeTerm::unit(), vec![ // T (TypeID::Var(0), dict.parse(">").unwrap()), @@ -152,7 +167,7 @@ fn test_subtype_unification() { // U (TypeID::Var(1), dict.parse("").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( @@ -162,7 +177,10 @@ fn test_subtype_unification() { (dict.parse(">").unwrap(), dict.parse(" W>").unwrap()), ]).solve_subtype(), - Ok( + Ok(( + dict.parse(" + ℕ~> + ").unwrap(), vec![ // W (TypeID::Var(3), dict.parse("ℕ~").unwrap()), @@ -170,6 +188,6 @@ fn test_subtype_unification() { // T (TypeID::Var(0), dict.parse("ℕ~~").unwrap()) ].into_iter().collect() - ) + )) ); } diff --git a/src/unification.rs b/src/unification.rs index 1c0d2d9..e605af4 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -42,19 +42,15 @@ impl UnificationProblem { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - - // update all values in substitution - let mut new_σ = HashMap::new(); - for (v, tt) in self.σ.iter() { - let mut tt = tt.clone().normalize(); - tt.apply_substitution(&|v| self.σ.get(v).cloned()); - eprintln!("update σ : {:?} --> {:?}", v, tt); - new_σ.insert(v.clone(), tt); + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(vec![]) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(vec![]) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) } - self.σ = new_σ; - - Ok(()) } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { @@ -153,20 +149,15 @@ impl UnificationProblem { match (lhs.clone(), rhs.clone()) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - self.σ.insert(TypeID::Var(varid), t.clone()); - - // update all values in substitution - let mut new_σ = HashMap::new(); - for (v, tt) in self.σ.iter() { - let mut tt = tt.clone(); - tt.apply_substitution(&|v| self.σ.get(v).cloned()); - new_σ.insert(v.clone(), tt); + if ! t.contains_var( varid ) { + self.σ.insert(TypeID::Var(varid), t.clone()); + self.reapply_subst(); + Ok(()) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(()) + } else { + Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) } - - self.σ.insert(TypeID::Var(varid), t.clone()); - self.reapply_subst(); - - Ok(()) } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { @@ -182,7 +173,7 @@ impl UnificationProblem { (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | (TypeTerm::App(a1), TypeTerm::App(a2)) => { if a1.len() == a2.len() { - for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() { let mut new_addr = addr.clone(); new_addr.push(i); self.eqs.push((x, y, new_addr)); @@ -263,6 +254,7 @@ impl UnificationProblem { halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); Ok((halo_type.param_normalize(), self.σ)) + } } pub fn unify(