Compare commits

...

4 commits

Author SHA1 Message Date
3eaca0dc37
work on unification
- add more unification tests
- rewrite subtype unification of ladders to work from bottom up
2025-03-15 11:33:48 +01:00
dc6626833d
add failing unification testcase 2025-03-14 17:46:59 +01:00
fe73c47504
find_morphism_path(): param-normalize halo 2025-03-14 17:44:27 +01:00
2c288dbff3
fix tests
- in subtype unification: correctly propagate error
- in case of subtype between two ladders, check that the matching sub-ladders end at the same bottom rung (to exclude trait-types from sub-types)
2025-03-12 16:38:39 +01:00
4 changed files with 272 additions and 165 deletions

View file

@ -21,8 +21,8 @@ pub struct MorphismType {
impl MorphismType {
pub fn normalize(self) -> Self {
MorphismType {
src_type: self.src_type.normalize().param_normalize(),
dst_type: self.dst_type.normalize().param_normalize(),
src_type: self.src_type.strip().param_normalize(),
dst_type: self.dst_type.strip().param_normalize(),
/*
subtype_bounds: Vec::new(),
trait_bounds: Vec::new()
@ -149,7 +149,7 @@ impl<M: Morphism + Clone> MorphismBase<M> {
if let Some( map_morph ) = item_morph_inst.m.map_morphism( seq_type.clone() ) {
dst_types.push(
MorphismInstance {
halo: TypeTerm::Ladder(dst_halo_ladder).normalize(),
halo: TypeTerm::Ladder(dst_halo_ladder).strip().param_normalize(),
m: map_morph,
σ: item_morph_inst.σ
}
@ -221,7 +221,7 @@ impl<M: Morphism + Clone> MorphismBase<M> {
n.halo = n.halo.clone().apply_substitution(
&|k| σ.get(k).cloned()
).clone().strip();
).clone().strip().param_normalize();
n.σ = new_σ;
}
@ -256,7 +256,7 @@ impl<M: Morphism + Clone> MorphismBase<M> {
n.halo = n.halo.clone().apply_substitution(
&|k| next_morph_inst.σ.get(k).cloned()
).clone().strip();
).clone().strip().param_normalize();
n.σ = new_σ;
}

View file

@ -330,51 +330,6 @@ fn test_morphism_path_posint() {
]
)
);
/*
assert_eq!(
base.find_morphism_path(MorphismType {
src_type: dict.parse("Symbol ~ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap(),
dst_type: dict.parse("Symbol ~ ~ <PosInt 16 BigEndian> ~ <Seq <Digit 16> ~ Char>").unwrap()
}),
Some(
vec![
dict.parse("Symbol ~ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap().normalize(),
dict.parse("Symbol ~ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ _2^64 ~ machine.UInt64>").unwrap().normalize(),
dict.parse("Symbol ~ ~ <PosInt 10 LittleEndian> ~ <Seq <Digit 10> ~ _2^64 ~ machine.UInt64>").unwrap().normalize(),
dict.parse("Symbol ~ ~ <PosInt 16 LittleEndian> ~ <Seq <Digit 16> ~ _2^64 ~ machine.UInt64>").unwrap().normalize(),
dict.parse("Symbol ~ ~ <PosInt 16 BigEndian> ~ <Seq <Digit 16> ~ _2^64 ~ machine.UInt64>").unwrap().normalize(),
dict.parse("Symbol ~ ~ <PosInt 16 BigEndian> ~ <Seq <Digit 16> ~ Char>").unwrap().normalize(),
]
)
);
*/
/*
assert_eq!(
base.find_morphism_with_subtyping(
&MorphismType {
src_type: dict.parse("Symbol ~ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap(),
dst_type: dict.parse("Symbol ~ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ _2^64 ~ machine.UInt64>").unwrap()
}
),
Some((
DummyMorphism(MorphismType{
src_type: dict.parse("<Seq <Digit Radix> ~ Char>").unwrap(),
dst_type: dict.parse("<Seq <Digit Radix> ~ _2^64 ~ machine.UInt64>").unwrap()
}),
dict.parse("Symbol ~ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10>>").unwrap(),
vec![
(dict.get_typeid(&"Radix".into()).unwrap(),
dict.parse("10").unwrap())
].into_iter().collect::<std::collections::HashMap<TypeID, TypeTerm>>()
))
);
*/
}
#[test]
@ -421,8 +376,20 @@ fn test_morphism_path_listedit()
);
base.add_morphism(
DummyMorphism(MorphismType{
src_type: dict.parse("<List Char>").unwrap(),
dst_type: dict.parse("<List Char~ReprTree>").unwrap()
src_type: dict.parse("Char").unwrap(),
dst_type: dict.parse("Char ~ ReprTree").unwrap()
})
);
base.add_morphism(
DummyMorphism(MorphismType{
src_type: dict.parse("Char ~ ReprTree").unwrap(),
dst_type: dict.parse("Char").unwrap()
})
);
base.add_morphism(
DummyMorphism(MorphismType{
src_type: dict.parse("<List~Vec Char>").unwrap(),
dst_type: dict.parse("<List Char>").unwrap()
})
);
base.add_morphism(
@ -483,12 +450,12 @@ fn test_morphism_path_listedit()
},
MorphismInstance {
m: DummyMorphism(MorphismType{
src_type: dict.parse("<List~Vec ReprTree>").unwrap(),
dst_type: dict.parse("<List ReprTree> ~ EditTree").unwrap()
src_type: dict.parse("<List~Vec Char~ReprTree>").unwrap(),
dst_type: dict.parse("<List Char> ~ EditTree").unwrap()
}),
halo: dict.parse("<Seq~List <Digit 10>~Char>").unwrap(),
halo: dict.parse("<Seq~List <Digit 10>>").unwrap(),
σ: HashMap::new()
}
},
])
);
}

View file

@ -132,7 +132,61 @@ fn test_unification() {
}
#[test]
fn test_subtype_unification() {
// Exercises `UnificationProblem::new_sub(..).solve()` on plain ladder types.
// Each case passes one (lhs, rhs) pair and compares the result against
// `Ok((halos, substitution))`; the first component is presumably the list of
// halo types, one per input pair (TODO confirm against `solve()`).
fn test_subtype_unification1() {
let mut dict = BimapTypeDict::new();
// `T` is registered as a variable name; `A`, `B`, `C`, `D` are not.
dict.add_varname(String::from("T"));
// Case 1: `A ~ B` against `B` — expects `A` as the result term and an
// empty substitution.
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B").unwrap(),
dict.parse("B").unwrap()),
]).solve(),
Ok((
vec![ dict.parse("A").unwrap() ],
vec![].into_iter().collect()
))
);
// Case 2: the right-hand side matches the two bottom rungs `C ~ D`;
// expects the remaining upper rungs `A ~ B`.
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B ~ C ~ D").unwrap(),
dict.parse("C ~ D").unwrap()),
]).solve(),
Ok((
vec![ dict.parse("A ~ B").unwrap() ],
vec![].into_iter().collect()
))
);
// Case 3: variable `T` sits on top of `D`; expects `T` to be bound to all
// rungs above `D` (`A ~ B ~ C`) and the unit type as the result term.
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B ~ C ~ D").unwrap(),
dict.parse("T ~ D").unwrap()),
]).solve(),
Ok((
vec![ TypeTerm::unit() ],
vec![
(dict.get_typeid(&"T".into()).unwrap(), dict.parse("A ~ B ~ C").unwrap())
].into_iter().collect()
))
);
// Case 4: variable `T` is enclosed between `B` and `D`; expects `T` to be
// bound to `C` only, with `A` as the result term.
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("A ~ B ~ C ~ D").unwrap(),
dict.parse("B ~ T ~ D").unwrap()),
]).solve(),
Ok((
vec![ dict.parse("A").unwrap() ],
vec![
(dict.get_typeid(&"T".into()).unwrap(), dict.parse("C").unwrap())
].into_iter().collect()
))
);
}
#[test]
fn test_subtype_unification2() {
let mut dict = BimapTypeDict::new();
dict.add_varname(String::from("T"));
@ -142,7 +196,7 @@ fn test_subtype_unification() {
assert_eq!(
UnificationProblem::new_sub(vec![
(dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(),
(dict.parse("<Seq~T <Digit 10> ~ Char ~ Ascii>").unwrap(),
dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()),
]).solve(),
Ok((
@ -232,9 +286,9 @@ fn test_trait_not_subtype() {
&dict.parse("A ~ B ~ C").expect("")
),
Err(UnificationError {
addr: vec![],
t1: dict.parse("A ~ B").expect(""),
t2: dict.parse("A ~ B ~ C").expect("")
addr: vec![1],
t1: dict.parse("B").expect(""),
t2: dict.parse("C").expect("")
})
);
@ -244,13 +298,33 @@ fn test_trait_not_subtype() {
&dict.parse("<Seq~List~Vec Char~ReprTree>").expect("")
),
Err(UnificationError {
addr: vec![1],
t1: dict.parse("<Digit 10> ~ Char").expect(""),
t2: dict.parse("Char ~ ReprTree").expect("")
addr: vec![1,1],
t1: dict.parse("Char").expect(""),
t2: dict.parse("ReprTree").expect("")
})
);
}
// Checks `subtype_unify` on a parameterized list type whose item type on the
// right-hand side is the variable `Item` sitting above `ReprTree`.
#[test]
fn test_reprtree_list_subtype() {
let mut dict = BimapTypeDict::new();
// Register `Item` as a variable name so the parser treats it as a variable.
dict.add_varname("Item".into());
// Expected outcome: the unit type as the first component (presumably the
// halo — TODO confirm) and `Item` bound to `<Digit 10>~Char`.
assert_eq!(
subtype_unify(
&dict.parse("<List~Vec <Digit 10>~Char~ReprTree>").expect(""),
&dict.parse("<List~Vec Item~ReprTree>").expect("")
),
Ok((
TypeTerm::unit(),
vec![
(dict.get_typeid(&"Item".into()).unwrap(), dict.parse("<Digit 10>~Char").unwrap())
].into_iter().collect()
))
);
}
#[test]
pub fn test_subtype_delim() {
let mut dict = BimapTypeDict::new();

View file

@ -137,85 +137,119 @@ impl UnificationProblem {
}
}
/// Registers `new_lower_bound` as a lower subtype bound of variable `v`.
///
/// * If the bound is the variable `v` itself, nothing is recorded.
/// * If the bound merely contains `v`, the constraint is cyclic and rejected.
/// * If `v` has no lower bound yet, the new bound is stored.
/// * If the stored bound is a subtype of the new one, the stored bound is
///   replaced by the new, more general one.
/// * If the new bound is a subtype of the stored one, the stored bound is kept.
/// * Otherwise the two bounds are unrelated and `Err(())` is returned.
///
/// The subtype checks run through `eval_subtype`, which takes `&mut self`
/// and may therefore update further bounds as a side effect.
pub fn add_lower_subtype_bound(&mut self, v: u64, new_lower_bound: TypeTerm) -> Result<(),()> {
    // `v <= v` holds trivially; there is nothing to remember.
    if new_lower_bound == TypeTerm::TypeID(TypeID::Var(v)) {
        return Ok(());
    }

    // The variable occurs inside its own bound: reject the loop.
    if new_lower_bound.contains_var(v) {
        return Err(());
    }

    let existing = self.lower_bounds.get(&v).cloned();
    let current = match existing {
        Some(bound) => bound,
        None => {
            // First bound seen for this variable.
            self.lower_bounds.insert(v, new_lower_bound);
            return Ok(());
        }
    };

    // Is the stored bound below the new one? Then generalize to the new one.
    let current_is_below_new = self.eval_subtype(
        UnificationPair {
            lhs: current.clone(),
            rhs: new_lower_bound.clone(),
            halo: TypeTerm::unit(),
            addr: vec![]
        }
    ).is_ok();
    if current_is_below_new {
        self.lower_bounds.insert(v, new_lower_bound);
        return Ok(());
    }

    // Otherwise the stored bound must already cover the new one.
    let new_is_below_current = self.eval_subtype(
        UnificationPair {
            lhs: new_lower_bound,
            rhs: current,
            halo: TypeTerm::unit(),
            addr: vec![]
        }
    ).is_ok();
    if new_is_below_current {
        Ok(())
    } else {
        // Neither bound is a subtype of the other.
        Err(())
    }
}
/// Registers `new_upper_bound` as an upper subtype bound of variable `v`.
///
/// * If the bound is the variable `v` itself, nothing is recorded.
/// * If the bound merely contains `v`, the constraint is cyclic and rejected.
/// * If `v` already has an upper bound, the new bound is accepted only when
///   `eval_subtype` confirms it is a subtype of the stored one; it then
///   replaces the stored bound. Otherwise `Err(())` is returned.
/// * If `v` has no upper bound yet, the new bound is stored.
///
/// The subtype check runs through `eval_subtype`, which takes `&mut self`
/// and may therefore update further bounds as a side effect.
pub fn add_upper_subtype_bound(&mut self, v: u64, new_upper_bound: TypeTerm) -> Result<(),()> {
    // `v <= v` holds trivially; there is nothing to remember.
    if new_upper_bound == TypeTerm::TypeID(TypeID::Var(v)) {
        return Ok(());
    }

    // The variable occurs inside its own bound: reject the loop.
    if new_upper_bound.contains_var(v) {
        return Err(());
    }

    let existing = self.upper_bounds.get(&v).cloned();
    if let Some(current) = existing {
        // Only a lower (tighter) upper bound may replace the stored one.
        let new_is_below_current = self.eval_subtype(
            UnificationPair {
                lhs: new_upper_bound.clone(),
                rhs: current,
                halo: TypeTerm::unit(),
                addr: vec![]
            }
        ).is_ok();
        if !new_is_below_current {
            return Err(());
        }
    }

    // Either no bound existed yet, or the new one is tighter: store it.
    self.upper_bounds.insert(v, new_upper_bound);
    Ok(())
}
pub fn eval_subtype(&mut self, unification_pair: UnificationPair) -> Result<
// ok: halo type
TypeTerm,
// error
UnificationError
> {
eprintln!("eval_subtype {:?} <=? {:?}", unification_pair.lhs, unification_pair.rhs);
match (unification_pair.lhs.clone(), unification_pair.rhs.clone()) {
/*
Variables
*/
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
// eprintln!("t <= variable");
if ! t.contains_var( varid ) {
// let x = self.σ.get(&TypeID::Var(varid)).cloned();
if let Some(lower_bound) = self.lower_bounds.get(&varid).cloned() {
// eprintln!("var already exists. check max. type");
if let Ok(halo) = self.eval_subtype(
UnificationPair {
lhs: lower_bound.clone(),
rhs: t.clone(),
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// eprintln!("found more general lower bound");
// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
// generalize variable type to supertype
self.lower_bounds.insert(varid, t.clone());
} else if let Ok(halo) = self.eval_subtype(
UnificationPair{
lhs: t.clone(),
rhs: lower_bound.clone(),
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// eprintln!("OK, is already larger type");
} else {
// eprintln!("violated subtype restriction");
return Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t });
}
} else {
// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
self.lower_bounds.insert(varid, t.clone());
}
self.reapply_subst();
Ok(TypeTerm::unit())
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
(t, TypeTerm::TypeID(TypeID::Var(v))) => {
//eprintln!("t <= variable");
if self.add_lower_subtype_bound(v, t.clone()).is_ok() {
Ok(TypeTerm::unit())
} else {
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(v)), t2: t })
}
}
(TypeTerm::TypeID(TypeID::Var(varid)), t) => {
// eprintln!("variable <= t");
if let Some(upper_bound) = self.upper_bounds.get(&varid).cloned() {
if let Ok(_halo) = self.eval_subtype(
UnificationPair {
lhs: t.clone(),
rhs: upper_bound,
halo: TypeTerm::unit(),
addr: vec![]
}
) {
// found a lower upper bound
self.upper_bounds.insert(varid, t);
}
(TypeTerm::TypeID(TypeID::Var(v)), t) => {
//eprintln!("variable <= t");
if self.add_upper_subtype_bound(v, t.clone()).is_ok() {
Ok(TypeTerm::unit())
} else {
self.upper_bounds.insert(varid, t);
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(v)), t2: t })
}
Ok(TypeTerm::unit())
}
/*
Atoms
*/
@ -236,60 +270,91 @@ impl UnificationProblem {
*/
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => {
let mut halo = Vec::new();
if a1.len() < a2.len() {
return Err(UnificationError {
addr: vec![],
t1: unification_pair.lhs,
t2: unification_pair.rhs
});
}
for i in 0..a1.len() {
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i);
if let Ok(r_halo) = self.eval_subtype( UnificationPair {
lhs: a1[i].clone(),
rhs: a2[0].clone(),
let mut l1_iter = a1.into_iter().enumerate().rev();
let mut l2_iter = a2.into_iter().rev();
halo: TypeTerm::unit(),
addr: new_addr
}) {
for j in 0..a2.len() {
if i+j < a1.len() {
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i+j);
let mut halo_ladder = Vec::new();
let lhs = a1[i+j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
let rhs = a2[j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
while let Some(rhs) = l2_iter.next() {
//eprintln!("take rhs = {:?}", rhs);
if let Some((i, lhs)) = l1_iter.next() {
//eprintln!("take lhs ({}) = {:?}", i, lhs);
let mut addr = unification_pair.addr.clone();
addr.push(i);
//eprintln!("addr = {:?}", addr);
if let Ok(rung_halo) = self.eval_subtype(
match (lhs.clone(), rhs.clone()) {
(t, TypeTerm::TypeID(TypeID::Var(v))) => {
if self.add_upper_subtype_bound(v,t.clone()).is_ok() {
let mut new_upper_bound_ladder = vec![ t ];
if let Some(next_rhs) = l2_iter.next() {
// TODO
} else {
// take everything
while let Some((i,t)) = l1_iter.next() {
new_upper_bound_ladder.push(t);
}
}
new_upper_bound_ladder.reverse();
if self.add_upper_subtype_bound(v, TypeTerm::Ladder(new_upper_bound_ladder)).is_ok() {
// ok
} else {
return Err(UnificationError {
addr,
t1: lhs,
t2: rhs
});
}
} else {
return Err(UnificationError {
addr,
t1: lhs,
t2: rhs
});
}
}
(lhs, rhs) => {
if let Ok(ψ) = self.eval_subtype(
UnificationPair {
lhs: lhs.clone(), rhs: rhs.clone(),
addr: new_addr.clone(),
halo: TypeTerm::unit()
addr:addr.clone(), halo: TypeTerm::unit()
}
) {
halo.push(rung_halo);
// ok.
//eprintln!("rungs are subtypes. continue");
halo_ladder.push(ψ);
} else {
return Err(UnificationError{ addr: new_addr, t1: lhs, t2: rhs })
return Err(UnificationError {
addr,
t1: lhs,
t2: rhs
});
}
}
}
return Ok(
if halo.len() == 1 {
halo.pop().unwrap()
} else {
TypeTerm::Ladder(halo).strip()
});
} else {
halo.push(a1[i].clone());
//eprintln!("could not unify ladders");
// not a subtype,
return Err(UnificationError {
addr: vec![],
t1: unification_pair.lhs,
t2: unification_pair.rhs
});
}
}
//eprintln!("left ladder fully consumed");
Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs })
for (i,t) in l1_iter {
halo_ladder.push(t);
}
halo_ladder.reverse();
Ok(TypeTerm::Ladder(halo_ladder).strip().param_normalize())
},
(t, TypeTerm::Ladder(a1)) => {
@ -339,7 +404,7 @@ impl UnificationProblem {
// eprintln!("after strip: {:?}", y);
// eprintln!("APP<> eval {:?} \n ?<=? {:?} ", x, y);
if let Ok(halo) = self.eval_subtype(
match self.eval_subtype(
UnificationPair {
lhs: x.clone(),
rhs: y.clone(),
@ -347,15 +412,16 @@ impl UnificationProblem {
addr: new_addr,
}
) {
Ok(halo) => {
if halo == TypeTerm::unit() {
let mut y = y.clone();
y.apply_substitution(&|k| self.σ.get(k).cloned());
y = y.strip();
let mut top = y.get_lnf_vec().first().unwrap().clone();
halo_args.push(top.clone());
// eprintln!("add top {:?}", top);
//eprintln!("add top {:?}", top);
} else {
// eprintln!("add halo {:?}", halo);
//eprintln!("add halo {:?}", halo);
if n_halos_required > 0 {
let x = &mut halo_args[n_halos_required-1];
if let TypeTerm::Ladder(argrs) = x {
@ -376,13 +442,13 @@ impl UnificationProblem {
halo_args.push(halo);
n_halos_required += 1;
}
} else {
return Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs });
},
Err(err) => { return Err(err); }
}
}
if n_halos_required > 0 {
// eprintln!("halo args : {:?}", halo_args);
//eprintln!("halo args : {:?}", halo_args);
Ok(TypeTerm::App(halo_args))
} else {
Ok(TypeTerm::unit())