From 4c1db8756593ebcff2d93d491cab285baabea972 Mon Sep 17 00:00:00 2001
From: Michael Sippel <micha@fragmental.art>
Date: Sat, 15 Feb 2025 17:21:12 +0100
Subject: [PATCH] unification: reject non-identity loops & add test cases

---
 src/term.rs             | 16 ++++++++++++++++
 src/test/unification.rs | 31 ++++++++++++++++++++++++-------
 src/unification.rs      | 29 +++++++++++++++++++++--------
 3 files changed, 61 insertions(+), 15 deletions(-)

diff --git a/src/term.rs b/src/term.rs
index 2879ced..c93160b 100644
--- a/src/term.rs
+++ b/src/term.rs
@@ -76,6 +76,22 @@ impl TypeTerm {
         self.arg(TypeTerm::Char(c))
     }
 
+    /// Returns `true` if the variable `var_id` occurs anywhere in this term.
+    ///
+    /// A variable term matches when its id equals `var_id`; `App` and
+    /// `Ladder` terms are searched recursively through all of their
+    /// subterms. Every other kind of term yields `false`.
+    pub fn contains_var(&self, var_id: u64) -> bool {
+        match self {
+            TypeTerm::TypeID(TypeID::Var(v)) => *v == var_id,
+            TypeTerm::App(args) | TypeTerm::Ladder(args) => {
+                // stops at the first subterm that mentions the variable
+                args.iter().any(|a| a.contains_var(var_id))
+            }
+            _ => false,
+        }
+    }
+
     /// recursively apply substitution to all subterms,
     /// which will replace all occurences of variables which map
     /// some type-term in `subst`
diff --git a/src/test/unification.rs b/src/test/unification.rs
index 8aaee3f..d2a68a2 100644
--- a/src/test/unification.rs
+++ b/src/test/unification.rs
@@ -61,6 +61,19 @@ fn test_unification_error() {
             t2: dict.parse("B").unwrap()
         })
     );
+
+    assert_eq!(
+        crate::unify(
+            &dict.parse("T").unwrap(),
+            &dict.parse("<Seq T>").unwrap()
+        ),
+
+        Err(UnificationError {
+            addr: vec![],
+            t1: dict.parse("T").unwrap(),
+            t2: dict.parse("<Seq T>").unwrap()
+        })
+    );
 }
 
 #[test]
@@ -119,7 +132,6 @@ fn test_unification() {
 #[test]
 fn test_subtype_unification() {
     let mut dict = BimapTypeDict::new();
-
     dict.add_varname(String::from("T"));
     dict.add_varname(String::from("U"));
     dict.add_varname(String::from("V"));
@@ -130,12 +142,13 @@ fn test_subtype_unification() {
             (dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(),
                 dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()),
         ]).solve_subtype(),
-        Ok(
+        Ok((
+            dict.parse("<Seq <Digit 10>>").unwrap(),
             vec![
                 // T
                 (TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap())
             ].into_iter().collect()
-        )
+        ))
     );
 
     assert_eq!(
@@ -143,7 +156,8 @@ fn test_subtype_unification() {
             (dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()),
             (dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()),
         ]).solve_subtype(),
-        Ok(
+        Ok((
+            TypeTerm::unit(),
             vec![
                 // T
                 (TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()),
@@ -151,7 +165,7 @@ fn test_subtype_unification() {
                 // U
                 (TypeID::Var(1), dict.parse("<Seq Char>").unwrap())
             ].into_iter().collect()
-        )
+        ))
     );
 
     assert_eq!(
@@ -161,7 +175,10 @@ fn test_subtype_unification() {
             (dict.parse("<Seq ℕ~<PosInt 10 BigEndian>>").unwrap(),
                 dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()),
         ]).solve_subtype(),
-        Ok(
+        Ok((
+            dict.parse("
+                <Seq~<LengthPrefix x86.UInt64> ℕ~<PosInt 10 BigEndian>>
+            ").unwrap(),
             vec![
                 // W
                 (TypeID::Var(3), dict.parse("ℕ~<PosInt 10 BigEndian>").unwrap()),
@@ -169,6 +186,6 @@ fn test_subtype_unification() {
                 // T
                 (TypeID::Var(0), dict.parse("ℕ~<PosInt 10 BigEndian>~<Seq Char>").unwrap())
             ].into_iter().collect()
-        )
+        ))
     );
 }
diff --git a/src/unification.rs b/src/unification.rs
index 82d7a37..fd4800d 100644
--- a/src/unification.rs
+++ b/src/unification.rs
@@ -42,9 +42,16 @@ impl UnificationProblem {
         match (lhs.clone(), rhs.clone()) {
             (TypeTerm::TypeID(TypeID::Var(varid)), t) |
             (t, TypeTerm::TypeID(TypeID::Var(varid))) => {
-                self.σ.insert(TypeID::Var(varid), t.clone());
-                self.reapply_subst();
-                Ok(vec![])
+
+                if ! t.contains_var( varid ) {
+                    self.σ.insert(TypeID::Var(varid), t.clone());
+                    self.reapply_subst();
+                    Ok(vec![])
+                } else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
+                    Ok(vec![])
+                } else {
+                    Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
+                }
             }
 
             (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
@@ -143,9 +150,15 @@ impl UnificationProblem {
         match (lhs.clone(), rhs.clone()) {
             (TypeTerm::TypeID(TypeID::Var(varid)), t) |
             (t, TypeTerm::TypeID(TypeID::Var(varid))) => {
-                self.σ.insert(TypeID::Var(varid), t.clone());
-                self.reapply_subst();
-                Ok(())
+                if ! t.contains_var( varid ) {
+                    self.σ.insert(TypeID::Var(varid), t.clone());
+                    self.reapply_subst();
+                    Ok(())
+                } else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
+                    Ok(())
+                } else {
+                    Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
+                }
             }
 
             (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
@@ -161,7 +174,7 @@ impl UnificationProblem {
             (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) |
             (TypeTerm::App(a1), TypeTerm::App(a2)) => {
                 if a1.len() == a2.len() {
-                    for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() {
+                    for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() {
                         let mut new_addr = addr.clone();
                         new_addr.push(i);
                         self.eqs.push((x, y, new_addr));
@@ -240,7 +253,7 @@ impl UnificationProblem {
         let mut halo_type = TypeTerm::Ladder(halo_rungs);
         halo_type = halo_type.normalize();
         halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
-        Ok((halo_type, self.σ))
+        Ok((halo_type.param_normalize(), self.σ))
     }
 }