unification: reject non-identity loops & add test cases
This commit is contained in:
parent
f05ef07589
commit
b502b62479
3 changed files with 58 additions and 32 deletions
16
src/term.rs
16
src/term.rs
|
@ -79,6 +79,22 @@ impl TypeTerm {
|
|||
self.arg(TypeTerm::Char(c))
|
||||
}
|
||||
|
||||
pub fn contains_var(&self, var_id: u64) -> bool {
|
||||
match self {
|
||||
TypeTerm::TypeID(TypeID::Var(v)) => (&var_id == v),
|
||||
TypeTerm::App(args) |
|
||||
TypeTerm::Ladder(args) => {
|
||||
for a in args.iter() {
|
||||
if a.contains_var(var_id) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
|
||||
/// recursively apply substitution to all subterms,
|
||||
/// which will replace all occurences of variables which map
|
||||
/// some type-term in `subst`
|
||||
|
|
|
@ -61,6 +61,19 @@ fn test_unification_error() {
|
|||
t2: dict.parse("B").unwrap()
|
||||
})
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
crate::unify(
|
||||
&dict.parse("T").unwrap(),
|
||||
&dict.parse("<Seq T>").unwrap()
|
||||
),
|
||||
|
||||
Err(UnificationError {
|
||||
addr: vec![],
|
||||
t1: dict.parse("T").unwrap(),
|
||||
t2: dict.parse("<Seq T>").unwrap()
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -131,12 +144,13 @@ fn test_subtype_unification() {
|
|||
(dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(),
|
||||
dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()),
|
||||
]).solve_subtype(),
|
||||
Ok(
|
||||
Ok((
|
||||
dict.parse("<Seq <Digit 10>>").unwrap(),
|
||||
vec![
|
||||
// T
|
||||
(TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap())
|
||||
].into_iter().collect()
|
||||
)
|
||||
))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
|
@ -144,7 +158,8 @@ fn test_subtype_unification() {
|
|||
(dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()),
|
||||
(dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()),
|
||||
]).solve_subtype(),
|
||||
Ok(
|
||||
Ok((
|
||||
TypeTerm::unit(),
|
||||
vec![
|
||||
// T
|
||||
(TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()),
|
||||
|
@ -152,7 +167,7 @@ fn test_subtype_unification() {
|
|||
// U
|
||||
(TypeID::Var(1), dict.parse("<Seq Char>").unwrap())
|
||||
].into_iter().collect()
|
||||
)
|
||||
))
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
|
@ -162,7 +177,10 @@ fn test_subtype_unification() {
|
|||
(dict.parse("<Seq ℕ~<PosInt 10 BigEndian>>").unwrap(),
|
||||
dict.parse("<Seq~<LengthPrefix x86.UInt64> W>").unwrap()),
|
||||
]).solve_subtype(),
|
||||
Ok(
|
||||
Ok((
|
||||
dict.parse("
|
||||
<Seq~<LengthPrefix x86.UInt64> ℕ~<PosInt 10 BigEndian>>
|
||||
").unwrap(),
|
||||
vec![
|
||||
// W
|
||||
(TypeID::Var(3), dict.parse("ℕ~<PosInt 10 BigEndian>").unwrap()),
|
||||
|
@ -170,6 +188,6 @@ fn test_subtype_unification() {
|
|||
// T
|
||||
(TypeID::Var(0), dict.parse("ℕ~<PosInt 10 BigEndian>~<Seq Char>").unwrap())
|
||||
].into_iter().collect()
|
||||
)
|
||||
))
|
||||
);
|
||||
}
|
||||
|
|
|
@ -42,19 +42,15 @@ impl UnificationProblem {
|
|||
match (lhs.clone(), rhs.clone()) {
|
||||
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
|
||||
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
|
||||
self.σ.insert(TypeID::Var(varid), t.clone());
|
||||
|
||||
// update all values in substitution
|
||||
let mut new_σ = HashMap::new();
|
||||
for (v, tt) in self.σ.iter() {
|
||||
let mut tt = tt.clone().normalize();
|
||||
tt.apply_substitution(&|v| self.σ.get(v).cloned());
|
||||
eprintln!("update σ : {:?} --> {:?}", v, tt);
|
||||
new_σ.insert(v.clone(), tt);
|
||||
if ! t.contains_var( varid ) {
|
||||
self.σ.insert(TypeID::Var(varid), t.clone());
|
||||
self.reapply_subst();
|
||||
Ok(vec![])
|
||||
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
|
||||
Ok(vec![])
|
||||
} else {
|
||||
Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
|
||||
}
|
||||
self.σ = new_σ;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
|
||||
|
@ -153,20 +149,15 @@ impl UnificationProblem {
|
|||
match (lhs.clone(), rhs.clone()) {
|
||||
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
|
||||
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
|
||||
self.σ.insert(TypeID::Var(varid), t.clone());
|
||||
|
||||
// update all values in substitution
|
||||
let mut new_σ = HashMap::new();
|
||||
for (v, tt) in self.σ.iter() {
|
||||
let mut tt = tt.clone();
|
||||
tt.apply_substitution(&|v| self.σ.get(v).cloned());
|
||||
new_σ.insert(v.clone(), tt);
|
||||
if ! t.contains_var( varid ) {
|
||||
self.σ.insert(TypeID::Var(varid), t.clone());
|
||||
self.reapply_subst();
|
||||
Ok(())
|
||||
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
|
||||
}
|
||||
|
||||
self.σ.insert(TypeID::Var(varid), t.clone());
|
||||
self.reapply_subst();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
|
||||
|
@ -182,7 +173,7 @@ impl UnificationProblem {
|
|||
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) |
|
||||
(TypeTerm::App(a1), TypeTerm::App(a2)) => {
|
||||
if a1.len() == a2.len() {
|
||||
for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() {
|
||||
for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() {
|
||||
let mut new_addr = addr.clone();
|
||||
new_addr.push(i);
|
||||
self.eqs.push((x, y, new_addr));
|
||||
|
@ -263,6 +254,7 @@ impl UnificationProblem {
|
|||
halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
|
||||
|
||||
Ok((halo_type.param_normalize(), self.σ))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unify(
|
||||
|
|
Loading…
Reference in a new issue