add subtype unification

This commit is contained in:
Michael Sippel 2025-02-09 16:58:58 +01:00
parent e53edd23b9
commit e17a1a9462
Signed by: senvas
GPG key ID: F96CF119C34B64A6
2 changed files with 151 additions and 7 deletions

View file

@ -116,3 +116,60 @@ fn test_unification() {
);
}
#[test]
fn test_subtype_unification() {
let mut dict = TypeDict::new();
dict.add_varname(String::from("T"));
dict.add_varname(String::from("U"));
dict.add_varname(String::from("V"));
dict.add_varname(String::from("W"));
assert_eq!(
UnificationProblem::new(vec![
(dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(),
dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()),
]).solve_subtype(),
Ok(
vec![
// T
(TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap())
].into_iter().collect()
)
);
assert_eq!(
UnificationProblem::new(vec![
(dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()),
(dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()),
]).solve_subtype(),
Ok(
vec![
// T
(TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()),
// U
(TypeID::Var(1), dict.parse("<Seq Char>").unwrap())
].into_iter().collect()
)
);
assert_eq!(
UnificationProblem::new(vec![
(dict.parse("<Seq T>").unwrap(),
dict.parse("<Seq W~<Seq Char>>").unwrap()),
(dict.parse("<Seq ~<PosInt 10 BigEndian>>").unwrap(),
dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()),
]).solve_subtype(),
Ok(
vec![
// W
(TypeID::Var(3), dict.parse("~<PosInt 10 BigEndian>").unwrap()),
// T
(TypeID::Var(0), dict.parse("~<PosInt 10 BigEndian>~<Seq Char>").unwrap())
].into_iter().collect()
)
);
}

View file

@ -25,6 +25,86 @@ impl UnificationProblem {
}
}
pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<(), UnificationError> {
match (lhs.clone(), rhs.clone()) {
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
self.σ.insert(TypeID::Var(varid), t.clone());
// update all values in substitution
let mut new_σ = HashMap::new();
for (v, tt) in self.σ.iter() {
let mut tt = tt.clone().normalize();
tt.apply_substitution(&|v| self.σ.get(v).cloned());
eprintln!("update σ : {:?} --> {:?}", v, tt);
new_σ.insert(v.clone(), tt);
}
self.σ = new_σ;
Ok(())
}
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
}
(TypeTerm::Num(n1), TypeTerm::Num(n2)) => {
if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
}
(TypeTerm::Char(c1), TypeTerm::Char(c2)) => {
if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
}
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => {
eprintln!("unification: check two ladders");
for i in 0..a1.len() {
if let Some((_, _)) = a1[i].is_semantic_subtype_of( &a2[0] ) {
for j in 0..a2.len() {
if i+j < a1.len() {
let mut new_addr = addr.clone();
new_addr.push(i+j);
self.eqs.push((a1[i+j].clone(), a2[j].clone(), new_addr))
}
}
return Ok(())
}
}
Err(UnificationError{ addr, t1: lhs, t2: rhs })
},
(t, TypeTerm::Ladder(a1)) => {
if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) {
Ok(())
} else {
Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t })
}
}
(TypeTerm::Ladder(a1), t) => {
if let Some((idx, τ)) = TypeTerm::Ladder(a1.clone()).is_semantic_subtype_of(&t) {
Ok(())
} else {
Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t })
}
}
(TypeTerm::App(a1), TypeTerm::App(a2)) => {
if a1.len() == a2.len() {
for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() {
let mut new_addr = addr.clone();
new_addr.push(i);
self.eqs.push((x, y, new_addr));
}
Ok(())
} else {
Err(UnificationError{ addr, t1: lhs, t2: rhs })
}
}
_ => Err(UnificationError{ addr, t1: lhs, t2: rhs})
}
}
pub fn eval_equation(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<(), UnificationError> {
match (lhs.clone(), rhs.clone()) {
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
@ -72,14 +152,22 @@ impl UnificationProblem {
}
pub fn solve(mut self) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> {
while self.eqs.len() > 0 {
while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() {
lhs.apply_substitution(&|v| self.σ.get(v).cloned());
rhs.apply_substitution(&|v| self.σ.get(v).cloned());
self.eval_equation(lhs, rhs, addr)?;
}
while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() {
lhs.apply_substitution(&|v| self.σ.get(v).cloned());
rhs.apply_substitution(&|v| self.σ.get(v).cloned());
self.eval_equation(lhs, rhs, addr)?;
}
Ok(self.σ)
}
pub fn solve_subtype(mut self) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> {
while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() {
lhs.apply_substitution(&|v| self.σ.get(v).cloned());
rhs.apply_substitution(&|v| self.σ.get(v).cloned());
eprintln!("eval subtype LHS={:?} || RHS={:?}", lhs, rhs);
self.eval_subtype(lhs, rhs, addr)?;
}
Ok(self.σ)
}
}
@ -93,4 +181,3 @@ pub fn unify(
}
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\