work on unification

- add more unification tests
- rewrite subtype unification of ladders to work from bottom up
This commit is contained in:
Michael Sippel 2025-03-15 11:33:48 +01:00
parent dc6626833d
commit 3eaca0dc37
Signed by: senvas
GPG key ID: F96CF119C34B64A6
2 changed files with 224 additions and 113 deletions

View file

@ -132,7 +132,61 @@ fn test_unification() {
}
#[test]
fn test_subtype_unification() {
fn test_subtype_unification1() {
let mut dict = BimapTypeDict::new();
dict.add_varname(String::from("T"));
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B").unwrap(),
dict.parse("B").unwrap()),
]).solve(),
Ok((
vec![ dict.parse("A").unwrap() ],
vec![].into_iter().collect()
))
);
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B ~ C ~ D").unwrap(),
dict.parse("C ~ D").unwrap()),
]).solve(),
Ok((
vec![ dict.parse("A ~ B").unwrap() ],
vec![].into_iter().collect()
))
);
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B ~ C ~ D").unwrap(),
dict.parse("T ~ D").unwrap()),
]).solve(),
Ok((
vec![ TypeTerm::unit() ],
vec![
(dict.get_typeid(&"T".into()).unwrap(), dict.parse("A ~ B ~ C").unwrap())
].into_iter().collect()
))
);
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B ~ C ~ D").unwrap(),
dict.parse("B ~ T ~ D").unwrap()),
]).solve(),
Ok((
vec![ dict.parse("A").unwrap() ],
vec![
(dict.get_typeid(&"T".into()).unwrap(), dict.parse("C").unwrap())
].into_iter().collect()
))
);
}
#[test]
fn test_subtype_unification2() {
let mut dict = BimapTypeDict::new();
dict.add_varname(String::from("T"));
@ -232,9 +286,9 @@ fn test_trait_not_subtype() {
&dict.parse("A ~ B ~ C").expect("")
),
Err(UnificationError {
addr: vec![],
t1: dict.parse("A ~ B").expect(""),
t2: dict.parse("A ~ B ~ C").expect("")
addr: vec![1],
t1: dict.parse("B").expect(""),
t2: dict.parse("C").expect("")
})
);
@ -244,9 +298,9 @@ fn test_trait_not_subtype() {
&dict.parse("<Seq~List~Vec Char~ReprTree>").expect("")
),
Err(UnificationError {
addr: vec![1],
t1: dict.parse("<Digit 10> ~ Char").expect(""),
t2: dict.parse("Char ~ ReprTree").expect("")
addr: vec![1,1],
t1: dict.parse("Char").expect(""),
t2: dict.parse("ReprTree").expect("")
})
);
}

View file

@ -137,85 +137,119 @@ impl UnificationProblem {
}
}
pub fn add_lower_subtype_bound(&mut self, v: u64, new_lower_bound: TypeTerm) -> Result<(),()> {
if new_lower_bound == TypeTerm::TypeID(TypeID::Var(v)) {
return Ok(());
}
if new_lower_bound.contains_var(v) {
// loop
return Err(());
}
if let Some(lower_bound) = self.lower_bounds.get(&v).cloned() {
// eprintln!("var already exists. check max. type");
if let Ok(halo) = self.eval_subtype(
UnificationPair {
lhs: lower_bound.clone(),
rhs: new_lower_bound.clone(),
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// eprintln!("found more general lower bound");
// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
// generalize variable type to supertype
self.lower_bounds.insert(v, new_lower_bound);
Ok(())
} else if let Ok(halo) = self.eval_subtype(
UnificationPair{
lhs: new_lower_bound,
rhs: lower_bound,
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// eprintln!("OK, is already larger type");
Ok(())
} else {
// eprintln!("violated subtype restriction");
Err(())
}
} else {
// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
self.lower_bounds.insert(v, new_lower_bound);
Ok(())
}
}
pub fn add_upper_subtype_bound(&mut self, v: u64, new_upper_bound: TypeTerm) -> Result<(),()> {
if new_upper_bound == TypeTerm::TypeID(TypeID::Var(v)) {
return Ok(());
}
if new_upper_bound.contains_var(v) {
// loop
return Err(());
}
if let Some(upper_bound) = self.upper_bounds.get(&v).cloned() {
if let Ok(_halo) = self.eval_subtype(
UnificationPair {
lhs: new_upper_bound.clone(),
rhs: upper_bound,
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// found a lower upper bound
self.upper_bounds.insert(v, new_upper_bound);
Ok(())
} else {
Err(())
}
} else {
self.upper_bounds.insert(v, new_upper_bound);
Ok(())
}
}
pub fn eval_subtype(&mut self, unification_pair: UnificationPair) -> Result<
// ok: halo type
TypeTerm,
// error
UnificationError
> {
eprintln!("eval_subtype {:?} <=? {:?}", unification_pair.lhs, unification_pair.rhs);
match (unification_pair.lhs.clone(), unification_pair.rhs.clone()) {
/*
Variables
*/
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
// eprintln!("t <= variable");
if ! t.contains_var( varid ) {
// let x = self.σ.get(&TypeID::Var(varid)).cloned();
if let Some(lower_bound) = self.lower_bounds.get(&varid).cloned() {
// eprintln!("var already exists. check max. type");
if let Ok(halo) = self.eval_subtype(
UnificationPair {
lhs: lower_bound.clone(),
rhs: t.clone(),
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// eprintln!("found more general lower bound");
// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
// generalize variable type to supertype
self.lower_bounds.insert(varid, t.clone());
} else if let Ok(halo) = self.eval_subtype(
UnificationPair{
lhs: t.clone(),
rhs: lower_bound.clone(),
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// eprintln!("OK, is already larger type");
} else {
// eprintln!("violated subtype restriction");
return Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t });
}
} else {
// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
self.lower_bounds.insert(varid, t.clone());
}
self.reapply_subst();
Ok(TypeTerm::unit())
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
(t, TypeTerm::TypeID(TypeID::Var(v))) => {
//eprintln!("t <= variable");
if self.add_lower_subtype_bound(v, t.clone()).is_ok() {
Ok(TypeTerm::unit())
} else {
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(v)), t2: t })
}
}
(TypeTerm::TypeID(TypeID::Var(varid)), t) => {
// eprintln!("variable <= t");
if let Some(upper_bound) = self.upper_bounds.get(&varid).cloned() {
if let Ok(_halo) = self.eval_subtype(
UnificationPair {
lhs: t.clone(),
rhs: upper_bound,
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// found a lower upper bound
self.upper_bounds.insert(varid, t);
}
(TypeTerm::TypeID(TypeID::Var(v)), t) => {
//eprintln!("variable <= t");
if self.add_upper_subtype_bound(v, t.clone()).is_ok() {
Ok(TypeTerm::unit())
} else {
self.upper_bounds.insert(varid, t);
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(v)), t2: t })
}
Ok(TypeTerm::unit())
}
/*
Atoms
*/
@ -236,68 +270,91 @@ impl UnificationProblem {
*/
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => {
let mut halo = Vec::new();
if a1.len() < a2.len() {
return Err(UnificationError {
addr: vec![],
t1: unification_pair.lhs,
t2: unification_pair.rhs
});
}
for i in 0..a1.len() {
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i);
if let Ok(r_halo) = self.eval_subtype( UnificationPair {
lhs: a1[i].clone(),
rhs: a2[0].clone(),
let mut l1_iter = a1.into_iter().enumerate().rev();
let mut l2_iter = a2.into_iter().rev();
halo: TypeTerm::unit(),
addr: new_addr
}) {
if a1.len() == i+a2.len() {
for j in 0..a2.len() {
if i+j < a1.len() {
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i+j);
let mut halo_ladder = Vec::new();
let lhs = a1[i+j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
let rhs = a2[j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
while let Some(rhs) = l2_iter.next() {
//eprintln!("take rhs = {:?}", rhs);
if let Some((i, lhs)) = l1_iter.next() {
//eprintln!("take lhs ({}) = {:?}", i, lhs);
let mut addr = unification_pair.addr.clone();
addr.push(i);
//eprintln!("addr = {:?}", addr);
if let Ok(rung_halo) = self.eval_subtype(
match (lhs.clone(), rhs.clone()) {
(t, TypeTerm::TypeID(TypeID::Var(v))) => {
if self.add_upper_subtype_bound(v,t.clone()).is_ok() {
let mut new_upper_bound_ladder = vec![ t ];
if let Some(next_rhs) = l2_iter.next() {
// TODO
} else {
// take everything
while let Some((i,t)) = l1_iter.next() {
new_upper_bound_ladder.push(t);
}
}
new_upper_bound_ladder.reverse();
if self.add_upper_subtype_bound(v, TypeTerm::Ladder(new_upper_bound_ladder)).is_ok() {
// ok
} else {
return Err(UnificationError {
addr,
t1: lhs,
t2: rhs
});
}
} else {
return Err(UnificationError {
addr,
t1: lhs,
t2: rhs
});
}
}
(lhs, rhs) => {
if let Ok(ψ) = self.eval_subtype(
UnificationPair {
lhs: lhs.clone(), rhs: rhs.clone(),
addr: new_addr.clone(),
halo: TypeTerm::unit()
addr:addr.clone(), halo: TypeTerm::unit()
}
) {
halo.push(rung_halo);
// ok.
//eprintln!("rungs are subtypes. continue");
halo_ladder.push(ψ);
} else {
return Err(UnificationError{ addr: new_addr, t1: lhs, t2: rhs })
return Err(UnificationError {
addr,
t1: lhs,
t2: rhs
});
}
}
}
return Ok(
if halo.len() == 1 {
halo.pop().unwrap()
} else {
TypeTerm::Ladder(halo).strip()
});
} else {
/* after the first match, the remaining ladders dont have the same length
* thus it cannot be a subtype,
* at most it could be a trait type
*/
}
} else {
halo.push(a1[i].clone());
//eprintln!("could not unify ladders");
// not a subtype,
return Err(UnificationError {
addr: vec![],
t1: unification_pair.lhs,
t2: unification_pair.rhs
});
}
}
//eprintln!("left ladder fully consumed");
Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs })
for (i,t) in l1_iter {
halo_ladder.push(t);
}
halo_ladder.reverse();
Ok(TypeTerm::Ladder(halo_ladder).strip().param_normalize())
},
(t, TypeTerm::Ladder(a1)) => {
@ -362,9 +419,9 @@ impl UnificationProblem {
y = y.strip();
let mut top = y.get_lnf_vec().first().unwrap().clone();
halo_args.push(top.clone());
// eprintln!("add top {:?}", top);
//eprintln!("add top {:?}", top);
} else {
// eprintln!("add halo {:?}", halo);
//eprintln!("add halo {:?}", halo);
if n_halos_required > 0 {
let x = &mut halo_args[n_halos_required-1];
if let TypeTerm::Ladder(argrs) = x {
@ -391,7 +448,7 @@ impl UnificationProblem {
}
if n_halos_required > 0 {
// eprintln!("halo args : {:?}", halo_args);
//eprintln!("halo args : {:?}", halo_args);
Ok(TypeTerm::App(halo_args))
} else {
Ok(TypeTerm::unit())