From e17a1a9462f8757fe4c5e4127becaf997a078e28 Mon Sep 17 00:00:00 2001
From: Michael Sippel <micha@fragmental.art>
Date: Sun, 9 Feb 2025 16:58:58 +0100
Subject: [PATCH] add subtype unification

---
 src/test/unification.rs |  57 +++++++++++++++++++++++
 src/unification.rs      | 101 +++++++++++++++++++++++++++++++++++++---
 2 files changed, 151 insertions(+), 7 deletions(-)

diff --git a/src/test/unification.rs b/src/test/unification.rs
index 34d355d..239b384 100644
--- a/src/test/unification.rs
+++ b/src/test/unification.rs
@@ -116,3 +116,60 @@ fn test_unification() {
     );
 }
 
+
+#[test]
+fn test_subtype_unification() {
+    let mut dict = TypeDict::new();
+
+    dict.add_varname(String::from("T"));
+    dict.add_varname(String::from("U"));
+    dict.add_varname(String::from("V"));
+    dict.add_varname(String::from("W"));
+
+    assert_eq!(
+        UnificationProblem::new(vec![
+            (dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(),
+                dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()),
+        ]).solve_subtype(),
+        Ok(
+            vec![
+                // T
+                (TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap())
+            ].into_iter().collect()
+        )
+    );
+
+    assert_eq!(
+        UnificationProblem::new(vec![
+            (dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()),
+            (dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()),
+        ]).solve_subtype(),
+        Ok(
+            vec![
+                // T
+                (TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()),
+
+                // U
+                (TypeID::Var(1), dict.parse("<Seq Char>").unwrap())
+            ].into_iter().collect()
+        )
+    );
+
+    assert_eq!(
+        UnificationProblem::new(vec![
+            (dict.parse("<Seq T>").unwrap(),
+                dict.parse("<Seq W~<Seq Char>>").unwrap()),
+            (dict.parse("<Seq ℕ~<PosInt 10 BigEndian>>").unwrap(),
+                dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()),
+        ]).solve_subtype(),
+        Ok(
+            vec![
+                // W
+                (TypeID::Var(3), dict.parse("ℕ~<PosInt 10 BigEndian>").unwrap()),
+
+                // T
+                (TypeID::Var(0), dict.parse("ℕ~<PosInt 10 BigEndian>~<Seq Char>").unwrap())
+            ].into_iter().collect()
+        )
+    );
+}
diff --git a/src/unification.rs b/src/unification.rs
index ac7ec19..443a9a2 100644
--- a/src/unification.rs
+++ b/src/unification.rs
@@ -25,6 +25,86 @@ impl UnificationProblem {
         }
     }
 
+    pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<(), UnificationError> {
+        match (lhs.clone(), rhs.clone()) {
+            (TypeTerm::TypeID(TypeID::Var(varid)), t) |
+            (t, TypeTerm::TypeID(TypeID::Var(varid))) => {
+                self.σ.insert(TypeID::Var(varid), t.clone());
+
+                // update all values in substitution
+                let mut new_σ = HashMap::new();
+                for (v, tt) in self.σ.iter() {
+                    let mut tt = tt.clone().normalize();
+                    tt.apply_substitution(&|v| self.σ.get(v).cloned());
+                    eprintln!("update σ : {:?} --> {:?}", v, tt);
+                    new_σ.insert(v.clone(), tt);
+                }
+                self.σ = new_σ;
+
+                Ok(())
+            }
+
+            (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
+                if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
+            }
+            (TypeTerm::Num(n1), TypeTerm::Num(n2)) => {
+                if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
+            }
+            (TypeTerm::Char(c1), TypeTerm::Char(c2)) => {
+                if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
+            }
+
+            (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => {
+                eprintln!("unification: check two ladders");
+                for i in 0..a1.len() {
+                    if let Some((_, _)) = a1[i].is_semantic_subtype_of( &a2[0] ) {
+                        for j in 0..a2.len() {
+                            if i+j < a1.len() {
+                                let mut new_addr = addr.clone();
+                                new_addr.push(i+j);
+                                self.eqs.push((a1[i+j].clone(), a2[j].clone(), new_addr))
+                            }
+                        }
+                        return Ok(())
+                    }
+                }
+
+                Err(UnificationError{ addr, t1: lhs, t2: rhs })
+            },
+
+            (t, TypeTerm::Ladder(a1)) => {
+                if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) {
+                    Ok(())
+                } else {
+                    Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t })
+                }
+            }
+
+            (TypeTerm::Ladder(a1), t) => {
+                if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) {
+                    Ok(())
+                } else {
+                    Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t })
+                }
+            }
+
+            (TypeTerm::App(a1), TypeTerm::App(a2)) => {
+                if a1.len() == a2.len() {
+                    for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() {
+                        let mut new_addr = addr.clone();
+                        new_addr.push(i);
+                        self.eqs.push((x, y, new_addr));
+                    }
+                    Ok(())
+                } else {
+                    Err(UnificationError{ addr, t1: lhs, t2: rhs })
+                }
+            }
+
+            _ => Err(UnificationError{ addr, t1: lhs, t2: rhs})
+        }
+    }
+
     pub fn eval_equation(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<(), UnificationError> {
         match (lhs.clone(), rhs.clone()) {
             (TypeTerm::TypeID(TypeID::Var(varid)), t) |
@@ -72,14 +152,22 @@ impl UnificationProblem {
     }
 
     pub fn solve(mut self) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> {
-        while self.eqs.len() > 0 {
-            while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() {
-                lhs.apply_substitution(&|v| self.σ.get(v).cloned());
-                rhs.apply_substitution(&|v| self.σ.get(v).cloned());
-                self.eval_equation(lhs, rhs, addr)?;
-            }
+        while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() {
+            lhs.apply_substitution(&|v| self.σ.get(v).cloned());
+            rhs.apply_substitution(&|v| self.σ.get(v).cloned());
+            self.eval_equation(lhs, rhs, addr)?;
         }
+        Ok(self.σ)
+    }
 
+
+    pub fn solve_subtype(mut self) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> {
+        while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() {
+            lhs.apply_substitution(&|v| self.σ.get(v).cloned());
+            rhs.apply_substitution(&|v| self.σ.get(v).cloned());
+            eprintln!("eval subtype LHS={:?} || RHS={:?}", lhs, rhs);
+            self.eval_subtype(lhs, rhs, addr)?;
+        }
         Ok(self.σ)
     }
 }
@@ -93,4 +181,3 @@ pub fn unify(
 }
 
 //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
-