wip subtype unification

This commit is contained in:
Michael Sippel 2025-03-06 14:01:57 +01:00
parent a6a6677920
commit 893d09255d
Signed by: senvas
GPG key ID: F96CF119C34B64A6
5 changed files with 523 additions and 208 deletions

View file

@ -46,6 +46,7 @@ pub fn common_halo(
}
if halo_rungs.len() == 0 {
None
} else {
Some(TypeTerm::Ladder(halo_rungs).normalize())

View file

@ -11,13 +11,22 @@ use {
pub struct MorphismType {
pub src_type: TypeTerm,
pub dst_type: TypeTerm,
/*
pub subtype_bounds: Vec< (TypeTerm, TypeTerm) >,
pub trait_bounds: Vec< (TypeTerm, TypeTerm) >,
pub equal_bounds: Vec< (TypeTerm, TypeTerm) >,
*/
}
impl MorphismType {
pub fn normalize(self) -> Self {
MorphismType {
src_type: self.src_type.normalize().param_normalize(),
dst_type: self.dst_type.normalize().param_normalize()
dst_type: self.dst_type.normalize().param_normalize(),
/*
subtype_bounds: Vec::new(),
trait_bounds: Vec::new()
*/
}
}
}
@ -55,7 +64,11 @@ impl<M: Morphism + Clone> MorphismInstance<M> {
self.halo.clone(),
self.m.get_type().dst_type.clone()
]).apply_substitution(&|k| self.σ.get(k).cloned())
.clone()
.clone(),
/*
trait_bounds: Vec::new(),
subtype_bounds: Vec::new()
*/
}.normalize()
}
}
@ -267,7 +280,8 @@ impl<M: Morphism + Clone> MorphismBase<M> {
let dst_type = TypeTerm::Ladder(vec![
halo.clone(),
morph_type.dst_type.clone()
]);
]).normalize().param_normalize();
eprintln!("-----------> {} <= {}",
dict.unparse(&dst_type), dict.unparse(&ty.dst_type)
);
@ -284,7 +298,7 @@ impl<M: Morphism + Clone> MorphismBase<M> {
eprintln!("match. halo2 = {}", dict.unparse(&halo2));
return Some(MorphismInstance {
m: m.clone(),
halo: halo2,
halo: halo,
σ,
});
}
@ -295,13 +309,13 @@ impl<M: Morphism + Clone> MorphismBase<M> {
pub fn find_map_morphism(&self, ty: &MorphismType, dict: &mut impl TypeDict) -> Option< MorphismInstance<M> > {
for seq_type in self.seq_types.iter() {
if let Ok((halo, σ)) = UnificationProblem::new(vec![
if let Ok((halos, σ)) = UnificationProblem::new_sub(vec![
(ty.src_type.clone().param_normalize(),
TypeTerm::App(vec![ seq_type.clone(), TypeTerm::TypeID(TypeID::Var(100)) ])),
(TypeTerm::App(vec![ seq_type.clone(), TypeTerm::TypeID(TypeID::Var(101)) ]),
ty.dst_type.clone().param_normalize()),
]).solve_subtype() {
]).solve() {
// TODO: use real fresh variable names
let item_morph_type = MorphismType {
src_type: σ.get(&TypeID::Var(100)).unwrap().clone(),
@ -314,7 +328,7 @@ impl<M: Morphism + Clone> MorphismBase<M> {
return Some( MorphismInstance {
m: list_morph,
σ,
halo
halo: halos[0].clone()
} );
}
}

View file

@ -121,6 +121,38 @@ impl TypeTerm {
self
}
/* strip away empty ladders
* & unwrap singletons
*/
pub fn strip(self) -> Self {
match self {
TypeTerm::Ladder(rungs) => {
let mut rungs :Vec<_> = rungs.into_iter()
.filter_map(|mut r| {
r = r.strip();
if r != TypeTerm::unit() { Some(r) }
else { None }
}).collect();
if rungs.len() == 1 {
rungs.pop().unwrap()
} else {
TypeTerm::Ladder(rungs)
}
},
TypeTerm::App(args) => {
let mut args :Vec<_> = args.into_iter().map(|arg| arg.strip()).collect();
if args.len() == 0 {
TypeTerm::unit()
} else if args.len() == 1 {
args.pop().unwrap()
} else {
TypeTerm::App(args)
}
}
atom => atom
}
}
}
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\

View file

@ -97,11 +97,12 @@ fn test_unification() {
dict.add_varname(String::from("W"));
assert_eq!(
UnificationProblem::new(vec![
UnificationProblem::new_eq(vec![
(dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()),
(dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()),
]).solve(),
Ok(
Ok((
vec![],
vec![
// T
(TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()),
@ -109,15 +110,16 @@ fn test_unification() {
// U
(TypeID::Var(1), dict.parse("<Seq Char>").unwrap())
].into_iter().collect()
)
))
);
assert_eq!(
UnificationProblem::new(vec![
UnificationProblem::new_eq(vec![
(dict.parse("<Seq T>").unwrap(), dict.parse("<Seq W~<Seq Char>>").unwrap()),
(dict.parse("<Seq >").unwrap(), dict.parse("<Seq W>").unwrap()),
]).solve(),
Ok(
Ok((
vec![],
vec![
// W
(TypeID::Var(3), dict.parse("").unwrap()),
@ -125,7 +127,7 @@ fn test_unification() {
// T
(TypeID::Var(0), dict.parse("~<Seq Char>").unwrap())
].into_iter().collect()
)
))
);
}
@ -139,12 +141,14 @@ fn test_subtype_unification() {
dict.add_varname(String::from("W"));
assert_eq!(
UnificationProblem::new(vec![
UnificationProblem::new_sub(vec![
(dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(),
dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()),
]).solve_subtype(),
]).solve(),
Ok((
dict.parse("<Seq <Digit 10>>").unwrap(),
vec![
dict.parse("<Seq <Digit 10>>").unwrap()
],
vec![
// T
(TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap())
@ -153,12 +157,15 @@ fn test_subtype_unification() {
);
assert_eq!(
UnificationProblem::new(vec![
UnificationProblem::new_sub(vec![
(dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()),
(dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()),
]).solve_subtype(),
]).solve(),
Ok((
TypeTerm::unit(),
vec![
TypeTerm::unit(),
TypeTerm::unit(),
],
vec![
// T
(TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()),
@ -170,16 +177,17 @@ fn test_subtype_unification() {
);
assert_eq!(
UnificationProblem::new(vec![
UnificationProblem::new_sub(vec![
(dict.parse("<Seq T>").unwrap(),
dict.parse("<Seq W~<Seq Char>>").unwrap()),
(dict.parse("<Seq~<LengthPrefix x86.UInt64> ~<PosInt 10 BigEndian>>").unwrap(),
dict.parse("<<LengthPrefix x86.UInt64> W>").unwrap()),
]).solve_subtype(),
]).solve(),
Ok((
dict.parse("
<Seq ~<PosInt 10 BigEndian>>
").unwrap(),
vec![
TypeTerm::unit(),
dict.parse("<Seq >").unwrap(),
],
vec![
// W
(TypeID::Var(3), dict.parse("~<PosInt 10 BigEndian>").unwrap()),
@ -189,4 +197,81 @@ fn test_subtype_unification() {
].into_iter().collect()
))
);
assert_eq!(
subtype_unify(
&dict.parse("<Seq~List~Vec <Digit 16>~Char>").expect(""),
&dict.parse("<List~Vec Char>").expect("")
),
Ok((
dict.parse("<Seq~List <Digit 16>>").expect(""),
vec![].into_iter().collect()
))
);
assert_eq!(
subtype_unify(
&dict.parse(" ~ <PosInt 16 BigEndian> ~ <Seq~List~Vec <Digit 16>~Char>").expect(""),
&dict.parse("<List~Vec Char>").expect("")
),
Ok((
dict.parse(" ~ <PosInt 16 BigEndian> ~ <Seq~List <Digit 16>>").expect(""),
vec![].into_iter().collect()
))
);
}
#[test]
pub fn test_subtype_delim() {
let mut dict = BimapTypeDict::new();
dict.add_varname(String::from("T"));
dict.add_varname(String::from("Delim"));
assert_eq!(
UnificationProblem::new_sub(vec![
(
//given type
dict.parse("
< Seq <Seq <Digit 10>~Char~Ascii~UInt8> >
~ < ValueSep ':' Char~Ascii~UInt8 >
~ < Seq~<LengthPrefix UInt64> Char~Ascii~UInt8 >
").expect(""),
//expected type
dict.parse("
< Seq <Seq T> >
~ < ValueSep Delim T >
~ < Seq~<LengthPrefix UInt64> T >
").expect("")
),
// subtype bounds
(
dict.parse("T").expect(""),
dict.parse("UInt8").expect("")
),
/* todo
(
dict.parse("<TypeOf Delim>").expect(""),
dict.parse("T").expect("")
),
*/
]).solve(),
Ok((
// halo types for each rhs in the sub-equations
vec![
dict.parse("<Seq <Seq <Digit 10>>>").expect(""),
dict.parse("Char~Ascii").expect(""),
],
// variable substitution
vec![
(dict.get_typeid(&"T".into()).unwrap(), dict.parse("Char~Ascii~UInt8").expect("")),
(dict.get_typeid(&"Delim".into()).unwrap(), TypeTerm::Char(':')),
].into_iter().collect()
))
);
}

View file

@ -1,6 +1,5 @@
use {
std::collections::HashMap,
crate::{term::*, dict::*}
crate::{dict::*, term::*}, std::{collections::HashMap, env::consts::ARCH}
};
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
@ -12,21 +11,71 @@ pub struct UnificationError {
pub t2: TypeTerm
}
#[derive(Clone, Debug)]
pub struct UnificationPair {
addr: Vec<usize>,
halo: TypeTerm,
lhs: TypeTerm,
rhs: TypeTerm,
}
#[derive(Debug)]
pub struct UnificationProblem {
eqs: Vec<(TypeTerm, TypeTerm, Vec<usize>)>,
σ: HashMap<TypeID, TypeTerm>
σ: HashMap<TypeID, TypeTerm>,
upper_bounds: HashMap< u64, TypeTerm >,
lower_bounds: HashMap< u64, TypeTerm >,
equal_pairs: Vec<UnificationPair>,
subtype_pairs: Vec<UnificationPair>,
trait_pairs: Vec<UnificationPair>
}
impl UnificationProblem {
pub fn new(eqs: Vec<(TypeTerm, TypeTerm)>) -> Self {
pub fn new(
equal_pairs: Vec<(TypeTerm, TypeTerm)>,
subtype_pairs: Vec<(TypeTerm, TypeTerm)>,
trait_pairs: Vec<(TypeTerm, TypeTerm)>
) -> Self {
UnificationProblem {
eqs: eqs.iter().map(|(lhs,rhs)| (lhs.clone(),rhs.clone(),vec![])).collect(),
σ: HashMap::new()
σ: HashMap::new(),
equal_pairs: equal_pairs.into_iter().map(|(lhs,rhs)|
UnificationPair{
lhs,rhs,
halo: TypeTerm::unit(),
addr: Vec::new()
}).collect(),
subtype_pairs: subtype_pairs.into_iter().map(|(lhs,rhs)|
UnificationPair{
lhs,rhs,
halo: TypeTerm::unit(),
addr: Vec::new()
}).collect(),
trait_pairs: trait_pairs.into_iter().map(|(lhs,rhs)|
UnificationPair{
lhs,rhs,
halo: TypeTerm::unit(),
addr: Vec::new()
}).collect(),
upper_bounds: HashMap::new(),
lower_bounds: HashMap::new(),
}
}
pub fn new_eq(eqs: Vec<(TypeTerm, TypeTerm)>) -> Self {
UnificationProblem::new( eqs, Vec::new(), Vec::new() )
}
pub fn new_sub(subs: Vec<(TypeTerm, TypeTerm)>) -> Self {
UnificationProblem::new( Vec::new(), subs, Vec::new() )
}
/// update all values in substitution
pub fn reapply_subst(&mut self) {
// update all values in substitution
let mut new_σ = HashMap::new();
for (v, tt) in self.σ.iter() {
let mut tt = tt.clone().normalize();
@ -38,225 +87,359 @@ impl UnificationProblem {
self.σ = new_σ;
}
pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<Vec<TypeTerm>, UnificationError> {
match (lhs.clone(), rhs.clone()) {
pub fn eval_equation(&mut self, unification_pair: UnificationPair) -> Result<(), UnificationError> {
match (&unification_pair.lhs, &unification_pair.rhs) {
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
if ! t.contains_var( varid ) {
self.σ.insert(TypeID::Var(varid), t.clone());
if ! t.contains_var( *varid ) {
self.σ.insert(TypeID::Var(*varid), t.clone());
self.reapply_subst();
Ok(vec![])
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
Ok(vec![])
Ok(())
} else if t == &TypeTerm::TypeID(TypeID::Var(*varid)) {
Ok(())
} else {
Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(*varid)), t2: t.clone() })
}
}
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
if a1 == a2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) }
}
(TypeTerm::Num(n1), TypeTerm::Num(n2)) => {
if n1 == n2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) }
}
(TypeTerm::Char(c1), TypeTerm::Char(c2)) => {
if c1 == c2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
}
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => {
let mut halo = Vec::new();
for i in 0..a1.len() {
if let Ok((r_halo, σ)) = subtype_unify( &a1[i], &a2[0] ) {
//eprintln!("unified ladders at {}, r_halo = {:?}", i, r_halo);
for (k,v) in σ.iter() {
self.σ.insert(k.clone(),v.clone());
}
for j in 0..a2.len() {
if i+j < a1.len() {
let mut new_addr = addr.clone();
new_addr.push(i+j);
self.eqs.push((a1[i+j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(),
a2[j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(),
new_addr));
}
}
return Ok(halo)
} else {
halo.push(a1[i].clone());
//eprintln!("could not unify ladders");
}
}
Err(UnificationError{ addr, t1: lhs, t2: rhs })
},
(t, TypeTerm::Ladder(mut a1)) => {
/*
if let Ok(mut halo) = self.eval_subtype( t.clone(), a1.first().unwrap().clone(), addr.clone() )
a1.append(&mut halo);
Ok(a1)
} else {
*/
Err(UnificationError{ addr, t1: t, t2: TypeTerm::Ladder(a1) })
//}
}
(TypeTerm::Ladder(mut a1), t) => {
if let Ok(mut halo) = self.eval_subtype( a1.pop().unwrap(), t.clone(), addr.clone() ) {
a1.append(&mut halo);
Ok(a1)
} else {
Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t })
}
}
(TypeTerm::App(a1), TypeTerm::App(a2)) => {
if a1.len() == a2.len() {
let mut halo_args = Vec::new();
let mut halo_required = false;
for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() {
let mut new_addr = addr.clone();
new_addr.push(i);
//self.eqs.push((x, y, new_addr));
if let Ok(halo) = self.eval_subtype( x.clone(), y.clone(), new_addr ) {
if halo.len() == 0 {
halo_args.push(y.get_lnf_vec().first().unwrap().clone());
} else {
halo_args.push(TypeTerm::Ladder(halo));
halo_required = true;
}
} else {
return Err(UnificationError{ addr, t1: x, t2: y })
}
}
if halo_required {
Ok(vec![ TypeTerm::App(halo_args) ])
} else {
Ok(vec![])
}
} else {
Err(UnificationError{ addr, t1: lhs, t2: rhs })
}
}
_ => Err(UnificationError{ addr, t1: lhs, t2: rhs})
}
}
pub fn eval_equation(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<(), UnificationError> {
match (lhs.clone(), rhs.clone()) {
(TypeTerm::TypeID(TypeID::Var(varid)), t) |
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
if ! t.contains_var( varid ) {
self.σ.insert(TypeID::Var(varid), t.clone());
self.reapply_subst();
Ok(())
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
Ok(())
} else {
Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
}
}
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
}
(TypeTerm::Num(n1), TypeTerm::Num(n2)) => {
if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
}
(TypeTerm::Char(c1), TypeTerm::Char(c2)) => {
if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) }
if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) }
}
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) |
(TypeTerm::App(a1), TypeTerm::App(a2)) => {
if a1.len() == a2.len() {
for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() {
let mut new_addr = addr.clone();
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i);
self.eqs.push((x, y, new_addr));
self.equal_pairs.push(
UnificationPair {
lhs: x,
rhs: y,
halo: TypeTerm::unit(),
addr: new_addr
});
}
Ok(())
} else {
Err(UnificationError{ addr, t1: lhs, t2: rhs })
Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs })
}
}
_ => Err(UnificationError{ addr, t1: lhs, t2: rhs})
_ => Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs })
}
}
pub fn solve(mut self) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> {
while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() {
lhs.apply_substitution(&|v| self.σ.get(v).cloned());
rhs.apply_substitution(&|v| self.σ.get(v).cloned());
self.eval_equation(lhs, rhs, addr)?;
}
Ok(self.σ)
}
pub fn eval_subtype(&mut self, unification_pair: UnificationPair) -> Result<
// ok: halo type
TypeTerm,
// error
UnificationError
> {
match (unification_pair.lhs.clone(), unification_pair.rhs.clone()) {
pub fn solve_subtype(mut self) -> Result<(TypeTerm, HashMap<TypeID, TypeTerm>), UnificationError> {
/*
Variables
*/
pub fn insert_halo_at(
t: &mut TypeTerm,
mut addr: Vec<usize>,
halo_type: TypeTerm
) {
match t {
TypeTerm::Ladder(rungs) => {
if let Some(idx) = addr.pop() {
for i in rungs.len()..idx+1 {
rungs.push(TypeTerm::unit());
(TypeTerm::TypeID(TypeID::Var(varid)), t) => {
eprintln!("variable <= t");
if let Some(upper_bound) = self.upper_bounds.get(&varid).cloned() {
if let Ok(_halo) = self.eval_subtype(
UnificationPair {
lhs: t.clone(),
rhs: upper_bound,
halo: TypeTerm::unit(),
addr: vec![]
}
insert_halo_at( &mut rungs[idx], addr, halo_type );
} else {
rungs.push(halo_type);
) {
// found a lower upper bound
self.upper_bounds.insert(varid, t);
}
},
} else {
self.upper_bounds.insert(varid, t);
}
Ok(TypeTerm::unit())
}
TypeTerm::App(args) => {
if let Some(idx) = addr.pop() {
insert_halo_at( &mut args[idx], addr, halo_type );
(t, TypeTerm::TypeID(TypeID::Var(varid))) => {
eprintln!("t <= variable");
if ! t.contains_var( varid ) {
// let x = self.σ.get(&TypeID::Var(varid)).cloned();
if let Some(lower_bound) = self.lower_bounds.get(&varid).cloned() {
eprintln!("var already exists. check max. type");
if let Ok(halo) = self.eval_subtype(
UnificationPair {
lhs: lower_bound.clone(),
rhs: t.clone(),
halo: TypeTerm::unit(),
addr: vec![]
}
) {
eprintln!("found more general lower bound");
eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
// generalize variable type to supertype
self.lower_bounds.insert(varid, t.clone());
} else if let Ok(halo) = self.eval_subtype(
UnificationPair{
lhs: t.clone(),
rhs: lower_bound.clone(),
halo: TypeTerm::unit(),
addr: vec![]
}
) {
eprintln!("OK, is already larger type");
} else {
eprintln!("violated subtype restriction");
return Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t });
}
} else {
*t = TypeTerm::Ladder(vec![
halo_type,
t.clone()
]);
eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone());
self.lower_bounds.insert(varid, t.clone());
}
self.reapply_subst();
Ok(TypeTerm::unit())
} else if t == TypeTerm::TypeID(TypeID::Var(varid)) {
Ok(TypeTerm::unit())
} else {
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t })
}
}
/*
Atoms
*/
(TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => {
if a1 == a2 { Ok(TypeTerm::unit()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs}) }
}
(TypeTerm::Num(n1), TypeTerm::Num(n2)) => {
if n1 == n2 { Ok(TypeTerm::unit()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) }
}
(TypeTerm::Char(c1), TypeTerm::Char(c2)) => {
if c1 == c2 { Ok(TypeTerm::unit()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) }
}
/*
Ladders
*/
(TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => {
let mut halo = Vec::new();
for i in 0..a1.len() {
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i);
if let Ok(r_halo) = self.eval_subtype( UnificationPair {
lhs: a1[i].clone(),
rhs: a2[0].clone(),
halo: TypeTerm::unit(),
addr: new_addr
}) {
eprintln!("unified ladders at {}, r_halo = {:?}", i, r_halo);
for j in 0..a2.len() {
if i+j < a1.len() {
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i+j);
let lhs = a1[i+j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
let rhs = a2[j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
if let Ok(rung_halo) = self.eval_subtype(
UnificationPair {
lhs: lhs.clone(), rhs: rhs.clone(),
addr: new_addr.clone(),
halo: TypeTerm::unit()
}
) {
if rung_halo != TypeTerm::unit() {
halo.push(rung_halo);
}
} else {
return Err(UnificationError{ addr: new_addr, t1: lhs, t2: rhs })
}
}
}
return Ok(
if halo.len() == 1 {
halo.pop().unwrap()
} else {
TypeTerm::Ladder(halo)
});
} else {
halo.push(a1[i].clone());
//eprintln!("could not unify ladders");
}
}
atomic => {
Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs })
},
(t, TypeTerm::Ladder(a1)) => {
Err(UnificationError{ addr: unification_pair.addr, t1: t, t2: TypeTerm::Ladder(a1) })
}
(TypeTerm::Ladder(mut a1), t) => {
let mut new_addr = unification_pair.addr.clone();
new_addr.push( a1.len() -1 );
if let Ok(halo) = self.eval_subtype(
UnificationPair {
lhs: a1.pop().unwrap(),
rhs: t.clone(),
halo: TypeTerm::unit(),
addr: new_addr
}
) {
a1.push(halo);
if a1.len() == 1 {
Ok(a1.pop().unwrap())
} else {
Ok(TypeTerm::Ladder(a1))
}
} else {
Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::Ladder(a1), t2: t })
}
}
}
//let mut halo_type = TypeTerm::unit();
let mut halo_rungs = Vec::new();
while let Some( (mut lhs, mut rhs, mut addr) ) = self.eqs.pop() {
lhs.apply_substitution(&|v| self.σ.get(v).cloned());
rhs.apply_substitution(&|v| self.σ.get(v).cloned());
//eprintln!("eval subtype\n\tLHS={:?}\n\tRHS={:?}\n", lhs, rhs);
let mut new_halo = self.eval_subtype(lhs, rhs, addr.clone())?;
if new_halo.len() > 0 {
//insert_halo_at( &mut halo_type, addr, TypeTerm::Ladder(new_halo) );
if addr.len() == 0 {
halo_rungs.push(TypeTerm::Ladder(new_halo))
/*
Application
*/
(TypeTerm::App(a1), TypeTerm::App(a2)) => {
eprintln!("sub unify {:?}, {:?}", a1, a2);
if a1.len() == a2.len() {
let mut halo_args = Vec::new();
let mut n_halos_required = 0;
for (i, (mut x, mut y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() {
let mut new_addr = unification_pair.addr.clone();
new_addr.push(i);
x = x.strip();
eprintln!("before strip: {:?}", y);
y = y.strip();
eprintln!("after strip: {:?}", y);
eprintln!("APP<> eval {:?} \n ?<=? {:?} ", x, y);
if let Ok(halo) = self.eval_subtype(
UnificationPair {
lhs: x.clone(),
rhs: y.clone(),
halo: TypeTerm::unit(),
addr: new_addr,
}
) {
if halo == TypeTerm::unit() {
let mut y = y.clone();
y.apply_substitution(&|k| self.σ.get(k).cloned());
y = y.strip();
let mut top = y.get_lnf_vec().first().unwrap().clone();
halo_args.push(top.clone());
eprintln!("add top {:?}", top);
} else {
eprintln!("add halo {:?}", halo);
if n_halos_required > 0 {
let x = &mut halo_args[n_halos_required-1];
if let TypeTerm::Ladder(argrs) = x {
let mut a = a2[n_halos_required-1].clone();
a.apply_substitution(&|k| self.σ.get(k).cloned());
a = a.get_lnf_vec().first().unwrap().clone();
argrs.push(a);
} else {
*x = TypeTerm::Ladder(vec![
x.clone(),
a2[n_halos_required-1].clone().get_lnf_vec().first().unwrap().clone()
]);
x.apply_substitution(&|k| self.σ.get(k).cloned());
}
}
halo_args.push(halo);
n_halos_required += 1;
}
} else {
return Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs });
}
}
if n_halos_required > 0 {
eprintln!("halo args : {:?}", halo_args);
Ok(TypeTerm::App(halo_args))
} else {
Ok(TypeTerm::unit())
}
} else {
Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs })
}
}
_ => Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs })
}
}
pub fn solve(mut self) -> Result<(Vec<TypeTerm>, HashMap<TypeID, TypeTerm>), UnificationError> {
// solve equations
while let Some( mut equal_pair ) = self.equal_pairs.pop() {
equal_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned());
equal_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned());
self.eval_equation(equal_pair)?;
}
let mut halo_type = TypeTerm::Ladder(halo_rungs);
halo_type = halo_type.normalize();
halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone();
// solve subtypes
eprintln!("------ SOLVE SUBTYPES ---- ");
for mut subtype_pair in self.subtype_pairs.clone().into_iter() {
subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned());
subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned());
let _halo = self.eval_subtype( subtype_pair.clone() )?.strip();
}
Ok((halo_type.param_normalize(), self.σ))
// add variables from subtype bounds
for (var_id, t) in self.upper_bounds.iter() {
eprintln!("VAR {} upper bound {:?}", var_id, t);
self.σ.insert(TypeID::Var(*var_id), t.clone().strip());
}
for (var_id, t) in self.lower_bounds.iter() {
eprintln!("VAR {} lower bound {:?}", var_id, t);
self.σ.insert(TypeID::Var(*var_id), t.clone().strip());
}
self.reapply_subst();
eprintln!("------ MAKE HALOS -----");
let mut halo_types = Vec::new();
for mut subtype_pair in self.subtype_pairs.clone().into_iter() {
subtype_pair.lhs = subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip();
subtype_pair.rhs = subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip();
let halo = self.eval_subtype( subtype_pair.clone() )?.strip();
halo_types.push(halo);
}
// solve traits
while let Some( trait_pair ) = self.trait_pairs.pop() {
unimplemented!();
}
Ok((halo_types, self.σ))
}
}
@ -264,16 +447,16 @@ pub fn unify(
t1: &TypeTerm,
t2: &TypeTerm
) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> {
let mut unification = UnificationProblem::new(vec![ (t1.clone(), t2.clone()) ]);
unification.solve()
let unification = UnificationProblem::new_eq(vec![ (t1.clone(), t2.clone()) ]);
Ok(unification.solve()?.1)
}
pub fn subtype_unify(
t1: &TypeTerm,
t2: &TypeTerm
) -> Result<(TypeTerm, HashMap<TypeID, TypeTerm>), UnificationError> {
let mut unification = UnificationProblem::new(vec![ (t1.clone(), t2.clone()) ]);
unification.solve_subtype()
let unification = UnificationProblem::new_sub(vec![ (t1.clone(), t2.clone()) ]);
unification.solve().map( |(halos,σ)| ( halos.first().cloned().unwrap_or(TypeTerm::unit()), σ) )
}
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\