unification: reject non-identity loops & add test cases

This commit is contained in:
Michael Sippel 2025-02-15 17:21:12 +01:00
parent 03c2756ede
commit 4c1db87565
Signed by: senvas
GPG key ID: F96CF119C34B64A6
3 changed files with 61 additions and 15 deletions

View file

@ -76,6 +76,22 @@ impl TypeTerm {
self.arg(TypeTerm::Char(c))
}
pub fn contains_var(&self, var_id: u64) -> bool {
match self {
TypeTerm::TypeID(TypeID::Var(v)) => (&var_id == v),
TypeTerm::App(args) |
TypeTerm::Ladder(args) => {
for a in args.iter() {
if a.contains_var(var_id) {
return true;
}
}
false
}
_ => false
}
}
/// recursively apply substitution to all subterms,
/// which will replace all occurences of variables which map
/// some type-term in `subst`

View file

@ -61,6 +61,19 @@ fn test_unification_error() {
t2: dict.parse("B").unwrap()
})
);
assert_eq!(
crate::unify(
&dict.parse("T").unwrap(),
&dict.parse("<Seq T>").unwrap()
),
Err(UnificationError {
addr: vec![],
t1: dict.parse("T").unwrap(),
t2: dict.parse("<Seq T>").unwrap()
})
);
}
#[test]
@ -119,7 +132,6 @@ fn test_unification() {
#[test]
fn test_subtype_unification() {
let mut dict = BimapTypeDict::new();
dict.add_varname(String::from("T"));
dict.add_varname(String::from("U"));
dict.add_varname(String::from("V"));
@ -130,12 +142,13 @@ fn test_subtype_unification() {
(dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(),
dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()),
]).solve_subtype(),
Ok(
Ok((
dict.parse("<Seq <Digit 10>>").unwrap(),
vec![
// T
(TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap())
].into_iter().collect()
)
))
);
assert_eq!(
@ -143,7 +156,8 @@ fn test_subtype_unification() {
(dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()),
(dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()),
]).solve_subtype(),
Ok(
Ok((
TypeTerm::unit(),
vec![
// T
(TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()),
@ -151,7 +165,7 @@ fn test_subtype_unification() {
// U
(TypeID::Var(1), dict.parse("<Seq Char>").unwrap())
].into_iter().collect()
)
))
);
assert_eq!(
@ -161,7 +175,10 @@ fn test_subtype_unification() {
(dict.parse("<Seq ~<PosInt 10 BigEndian>>").unwrap(),
dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()),
]).solve_subtype(),
Ok(
Ok((
dict.parse("
<Seq~<LengthPrefix x86.UInt64> ~<PosInt 10 BigEndian>>
").unwrap(),
vec![
// W
(TypeID::Var(3), dict.parse("~<PosInt 10 BigEndian>").unwrap()),
@ -169,6 +186,6 @@ fn test_subtype_unification() {
// T
(TypeID::Var(0), dict.parse("~<PosInt 10 BigEndian>~<Seq Char>").unwrap())
].into_iter().collect()
)
))
);
}

View file

@ -42,9 +42,16 @@ impl UnificationProblem {
match (lhs.clone(), rhs.clone()) {
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
self.σ.insert(TypeID::Var(varid), t.clone());
self.reapply_subst();
Ok(vec![])
if ! t.contains_var( varid ) {
self.σ.insert(TypeID::Var(varid), t.clone());
self.reapply_subst();
Ok(vec![])
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
Ok(vec![])
} else {
Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
}
}
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
@ -143,9 +150,15 @@ impl UnificationProblem {
match (lhs.clone(), rhs.clone()) {
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
self.σ.insert(TypeID::Var(varid), t.clone());
self.reapply_subst();
Ok(())
if ! t.contains_var( varid ) {
self.σ.insert(TypeID::Var(varid), t.clone());
self.reapply_subst();
Ok(())
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
Ok(())
} else {
Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
}
}
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
@ -161,7 +174,7 @@ impl UnificationProblem {
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) |
(TypeTerm::App(a1), TypeTerm::App(a2)) => {
if a1.len() == a2.len() {
for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() {
for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() {
let mut new_addr = addr.clone();
new_addr.push(i);
self.eqs.push((x, y, new_addr));
@ -240,7 +253,7 @@ impl UnificationProblem {
let mut halo_type = TypeTerm::Ladder(halo_rungs);
halo_type = halo_type.normalize();
halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
Ok((halo_type, self.σ))
Ok((halo_type.param_normalize(), self.σ))
}
}