diff --git a/src/lib.rs b/src/lib.rs index a06c3d1..47e67aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,6 +46,7 @@ pub fn common_halo( } if halo_rungs.len() == 0 { + None } else { Some(TypeTerm::Ladder(halo_rungs).normalize()) diff --git a/src/morphism.rs b/src/morphism.rs index 4a7d87e..4216b33 100644 --- a/src/morphism.rs +++ b/src/morphism.rs @@ -11,13 +11,22 @@ use { pub struct MorphismType { pub src_type: TypeTerm, pub dst_type: TypeTerm, +/* + pub subtype_bounds: Vec< (TypeTerm, TypeTerm) >, + pub trait_bounds: Vec< (TypeTerm, TypeTerm) >, + pub equal_bounds: Vec< (TypeTerm, TypeTerm) >, +*/ } impl MorphismType { pub fn normalize(self) -> Self { MorphismType { src_type: self.src_type.normalize().param_normalize(), - dst_type: self.dst_type.normalize().param_normalize() + dst_type: self.dst_type.normalize().param_normalize(), +/* + subtype_bounds: Vec::new(), + trait_bounds: Vec::new() + */ } } } @@ -55,7 +64,11 @@ impl<M: Morphism + Clone> MorphismInstance<M> { self.halo.clone(), self.m.get_type().dst_type.clone() ]).apply_substitution(&|k| self.σ.get(k).cloned()) - .clone() + .clone(), +/* + trait_bounds: Vec::new(), + subtype_bounds: Vec::new() + */ }.normalize() } } @@ -267,7 +280,8 @@ impl<M: Morphism + Clone> MorphismBase<M> { let dst_type = TypeTerm::Ladder(vec![ halo.clone(), morph_type.dst_type.clone() - ]); + ]).normalize().param_normalize(); + eprintln!("-----------> {} <= {}", dict.unparse(&dst_type), dict.unparse(&ty.dst_type) ); @@ -284,7 +298,7 @@ impl<M: Morphism + Clone> MorphismBase<M> { eprintln!("match. halo2 = {}", dict.unparse(&halo2)); return Some(MorphismInstance { m: m.clone(), - halo: halo2, + halo: halo, σ, }); } @@ -295,13 +309,13 @@ impl<M: Morphism + Clone> MorphismBase<M> { pub fn find_map_morphism(&self, ty: &MorphismType, dict: &mut impl TypeDict) -> Option< MorphismInstance<M> > { for seq_type in self.seq_types.iter() { - if let Ok((halo, σ)) = UnificationProblem::new(vec![ + if let Ok((halos, σ)) = UnificationProblem::new_sub(vec![ (ty.src_type.clone().param_normalize(), TypeTerm::App(vec![ seq_type.clone(), TypeTerm::TypeID(TypeID::Var(100)) ])), (TypeTerm::App(vec![ seq_type.clone(), TypeTerm::TypeID(TypeID::Var(101)) ]), ty.dst_type.clone().param_normalize()), - ]).solve_subtype() { + ]).solve() { // TODO: use real fresh variable names let item_morph_type = MorphismType { src_type: σ.get(&TypeID::Var(100)).unwrap().clone(), @@ -314,7 +328,7 @@ impl<M: Morphism + Clone> MorphismBase<M> { return Some( MorphismInstance { m: list_morph, σ, - halo + halo: halos[0].clone() } ); } } diff --git a/src/term.rs b/src/term.rs index c93160b..7f9354e 100644 --- a/src/term.rs +++ b/src/term.rs @@ -121,6 +121,38 @@ impl TypeTerm { self } + + /* strip away empty ladders + * & unwrap singletons + */ + pub fn strip(self) -> Self { + match self { + TypeTerm::Ladder(rungs) => { + let mut rungs :Vec<_> = rungs.into_iter() + .filter_map(|mut r| { + r = r.strip(); + if r != TypeTerm::unit() { Some(r) } + else { None } + }).collect(); + if rungs.len() == 1 { + rungs.pop().unwrap() + } else { + TypeTerm::Ladder(rungs) + } + }, + TypeTerm::App(args) => { + let mut args :Vec<_> = args.into_iter().map(|arg| arg.strip()).collect(); + if args.len() == 0 { + TypeTerm::unit() + } else if args.len() == 1 { + args.pop().unwrap() + } else { + TypeTerm::App(args) + } + } + atom => atom + } + } } //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ diff --git a/src/test/unification.rs b/src/test/unification.rs index 7811647..7a1acb3 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -97,11 +97,12 @@ fn test_unification() { dict.add_varname(String::from("W")); assert_eq!( - UnificationProblem::new(vec![ + UnificationProblem::new_eq(vec![ (dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()), (dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()), ]).solve(), - Ok( + Ok(( + vec![], vec![ // T (TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()), @@ -109,15 +110,16 @@ fn test_unification() { // U (TypeID::Var(1), dict.parse("<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); assert_eq!( - UnificationProblem::new(vec![ + UnificationProblem::new_eq(vec![ (dict.parse("<Seq T>").unwrap(), dict.parse("<Seq W~<Seq Char>>").unwrap()), (dict.parse("<Seq ℕ>").unwrap(), dict.parse("<Seq W>").unwrap()), ]).solve(), - Ok( + Ok(( + vec![], vec![ // W (TypeID::Var(3), dict.parse("ℕ").unwrap()), @@ -125,7 +127,7 @@ fn test_unification() { // T (TypeID::Var(0), dict.parse("ℕ~<Seq Char>").unwrap()) ].into_iter().collect() - ) + )) ); } @@ -139,12 +141,14 @@ fn test_subtype_unification() { dict.add_varname(String::from("W")); assert_eq!( - UnificationProblem::new(vec![ + UnificationProblem::new_sub(vec![ (dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()), - ]).solve_subtype(), + ]).solve(), Ok(( - dict.parse("<Seq <Digit 10>>").unwrap(), + vec![ + dict.parse("<Seq <Digit 10>>").unwrap() + ], vec![ // T (TypeID::Var(0), dict.parse("<LengthPrefix x86.UInt64>").unwrap()) @@ -153,12 +157,15 @@ fn test_subtype_unification() { ); assert_eq!( - UnificationProblem::new(vec![ + UnificationProblem::new_sub(vec![ (dict.parse("U").unwrap(), dict.parse("<Seq Char>").unwrap()), (dict.parse("T").unwrap(), dict.parse("<Seq U>").unwrap()), - ]).solve_subtype(), + ]).solve(), Ok(( - TypeTerm::unit(), + vec![ + TypeTerm::unit(), + TypeTerm::unit(), + ], vec![ // T (TypeID::Var(0), dict.parse("<Seq <Seq Char>>").unwrap()), @@ -170,16 +177,17 @@ fn test_subtype_unification() { ); assert_eq!( - UnificationProblem::new(vec![ + UnificationProblem::new_sub(vec![ (dict.parse("<Seq T>").unwrap(), dict.parse("<Seq W~<Seq Char>>").unwrap()), (dict.parse("<Seq~<LengthPrefix x86.UInt64> ℕ~<PosInt 10 BigEndian>>").unwrap(), dict.parse("<<LengthPrefix x86.UInt64> W>").unwrap()), - ]).solve_subtype(), + ]).solve(), Ok(( - dict.parse(" - <Seq ℕ~<PosInt 10 BigEndian>> - ").unwrap(), + vec![ + TypeTerm::unit(), + dict.parse("<Seq ℕ>").unwrap(), + ], vec![ // W (TypeID::Var(3), dict.parse("ℕ~<PosInt 10 BigEndian>").unwrap()), @@ -189,4 +197,81 @@ fn test_subtype_unification() { ].into_iter().collect() )) ); + + assert_eq!( + subtype_unify( + &dict.parse("<Seq~List~Vec <Digit 16>~Char>").expect(""), + &dict.parse("<List~Vec Char>").expect("") + ), + Ok(( + dict.parse("<Seq~List <Digit 16>>").expect(""), + vec![].into_iter().collect() + )) + ); + + assert_eq!( + subtype_unify( + &dict.parse("ℕ ~ <PosInt 16 BigEndian> ~ <Seq~List~Vec <Digit 16>~Char>").expect(""), + &dict.parse("<List~Vec Char>").expect("") + ), + Ok(( + dict.parse("ℕ ~ <PosInt 16 BigEndian> ~ <Seq~List <Digit 16>>").expect(""), + vec![].into_iter().collect() + )) + ); +} + + +#[test] +pub fn test_subtype_delim() { + let mut dict = BimapTypeDict::new(); + + dict.add_varname(String::from("T")); + dict.add_varname(String::from("Delim")); + + assert_eq!( + UnificationProblem::new_sub(vec![ + + ( + //given type + dict.parse(" + < Seq <Seq <Digit 10>~Char~Ascii~UInt8> > + ~ < ValueSep ':' Char~Ascii~UInt8 > + ~ < Seq~<LengthPrefix UInt64> Char~Ascii~UInt8 > + ").expect(""), + + //expected type + dict.parse(" + < Seq <Seq T> > + ~ < ValueSep Delim T > + ~ < Seq~<LengthPrefix UInt64> T > + ").expect("") + ), + + // subtype bounds + ( + dict.parse("T").expect(""), + dict.parse("UInt8").expect("") + ), + /* todo + ( + dict.parse("<TypeOf Delim>").expect(""), + dict.parse("T").expect("") + ), + */ + ]).solve(), + Ok(( + // halo types for each rhs in the sub-equations + vec![ + dict.parse("<Seq <Seq <Digit 10>>>").expect(""), + dict.parse("Char~Ascii").expect(""), + ], + + // variable substitution + vec![ + (dict.get_typeid(&"T".into()).unwrap(), dict.parse("Char~Ascii~UInt8").expect("")), + (dict.get_typeid(&"Delim".into()).unwrap(), TypeTerm::Char(':')), + ].into_iter().collect() + )) + ); } diff --git a/src/unification.rs b/src/unification.rs index 850d76c..03c3699 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -1,6 +1,5 @@ use { - std::collections::HashMap, - crate::{term::*, dict::*} + crate::{dict::*, term::*}, std::{collections::HashMap, env::consts::ARCH} }; //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ @@ -12,21 +11,71 @@ pub struct UnificationError { pub t2: TypeTerm } +#[derive(Clone, Debug)] +pub struct UnificationPair { + addr: Vec<usize>, + halo: TypeTerm, + + lhs: TypeTerm, + rhs: TypeTerm, +} + +#[derive(Debug)] pub struct UnificationProblem { - eqs: Vec<(TypeTerm, TypeTerm, Vec<usize>)>, - σ: HashMap<TypeID, TypeTerm> + σ: HashMap<TypeID, TypeTerm>, + upper_bounds: HashMap< u64, TypeTerm >, + lower_bounds: HashMap< u64, TypeTerm >, + equal_pairs: Vec<UnificationPair>, + subtype_pairs: Vec<UnificationPair>, + trait_pairs: Vec<UnificationPair> } impl UnificationProblem { - pub fn new(eqs: Vec<(TypeTerm, TypeTerm)>) -> Self { + pub fn new( + equal_pairs: Vec<(TypeTerm, TypeTerm)>, + subtype_pairs: Vec<(TypeTerm, TypeTerm)>, + trait_pairs: Vec<(TypeTerm, TypeTerm)> + ) -> Self { UnificationProblem { - eqs: eqs.iter().map(|(lhs,rhs)| (lhs.clone(),rhs.clone(),vec![])).collect(), - σ: HashMap::new() + σ: HashMap::new(), + + equal_pairs: equal_pairs.into_iter().map(|(lhs,rhs)| + UnificationPair{ + lhs,rhs, + halo: TypeTerm::unit(), + addr: Vec::new() + }).collect(), + + subtype_pairs: subtype_pairs.into_iter().map(|(lhs,rhs)| + UnificationPair{ + lhs,rhs, + halo: TypeTerm::unit(), + addr: Vec::new() + }).collect(), + + trait_pairs: trait_pairs.into_iter().map(|(lhs,rhs)| + UnificationPair{ + lhs,rhs, + halo: TypeTerm::unit(), + addr: Vec::new() + }).collect(), + + upper_bounds: HashMap::new(), + lower_bounds: HashMap::new(), } } + pub fn new_eq(eqs: Vec<(TypeTerm, TypeTerm)>) -> Self { + UnificationProblem::new( eqs, Vec::new(), Vec::new() ) + } + + pub fn new_sub(subs: Vec<(TypeTerm, TypeTerm)>) -> Self { + UnificationProblem::new( Vec::new(), subs, Vec::new() ) + } + + + /// update all values in substitution pub fn reapply_subst(&mut self) { - // update all values in substitution let mut new_σ = HashMap::new(); for (v, tt) in self.σ.iter() { let mut tt = tt.clone().normalize(); @@ -38,225 +87,359 @@ impl UnificationProblem { self.σ = new_σ; } - pub fn eval_subtype(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<Vec<TypeTerm>, UnificationError> { - match (lhs.clone(), rhs.clone()) { + + pub fn eval_equation(&mut self, unification_pair: UnificationPair) -> Result<(), UnificationError> { + match (&unification_pair.lhs, &unification_pair.rhs) { (TypeTerm::TypeID(TypeID::Var(varid)), t) | (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - if ! t.contains_var( varid ) { - self.σ.insert(TypeID::Var(varid), t.clone()); + if ! t.contains_var( *varid ) { + self.σ.insert(TypeID::Var(*varid), t.clone()); self.reapply_subst(); - Ok(vec![]) - } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { - Ok(vec![]) + Ok(()) + } else if t == &TypeTerm::TypeID(TypeID::Var(*varid)) { + Ok(()) } else { - Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) + Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(*varid)), t2: t.clone() }) } } (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { - if a1 == a2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } } (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { - if n1 == n2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } } (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { - if c1 == c2 { Ok(vec![]) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } - } - - (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => { - let mut halo = Vec::new(); - for i in 0..a1.len() { - if let Ok((r_halo, σ)) = subtype_unify( &a1[i], &a2[0] ) { - //eprintln!("unified ladders at {}, r_halo = {:?}", i, r_halo); - for (k,v) in σ.iter() { - self.σ.insert(k.clone(),v.clone()); - } - - for j in 0..a2.len() { - if i+j < a1.len() { - let mut new_addr = addr.clone(); - new_addr.push(i+j); - self.eqs.push((a1[i+j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(), - a2[j].clone().apply_substitution(&|k| σ.get(k).cloned()).clone(), - new_addr)); - } - } - return Ok(halo) - } else { - halo.push(a1[i].clone()); - //eprintln!("could not unify ladders"); - } - } - - Err(UnificationError{ addr, t1: lhs, t2: rhs }) - }, - - (t, TypeTerm::Ladder(mut a1)) => { - /* - if let Ok(mut halo) = self.eval_subtype( t.clone(), a1.first().unwrap().clone(), addr.clone() ) - a1.append(&mut halo); - Ok(a1) - } else { - */ - Err(UnificationError{ addr, t1: t, t2: TypeTerm::Ladder(a1) }) - //} - } - - (TypeTerm::Ladder(mut a1), t) => { - if let Ok(mut halo) = self.eval_subtype( a1.pop().unwrap(), t.clone(), addr.clone() ) { - - a1.append(&mut halo); - Ok(a1) - } else { - Err(UnificationError{ addr, t1: TypeTerm::Ladder(a1), t2: t }) - } - } - - (TypeTerm::App(a1), TypeTerm::App(a2)) => { - if a1.len() == a2.len() { - let mut halo_args = Vec::new(); - let mut halo_required = false; - - for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { - let mut new_addr = addr.clone(); - new_addr.push(i); - //self.eqs.push((x, y, new_addr)); - - if let Ok(halo) = self.eval_subtype( x.clone(), y.clone(), new_addr ) { - if halo.len() == 0 { - halo_args.push(y.get_lnf_vec().first().unwrap().clone()); - } else { - halo_args.push(TypeTerm::Ladder(halo)); - halo_required = true; - } - } else { - return Err(UnificationError{ addr, t1: x, t2: y }) - } - } - - if halo_required { - Ok(vec![ TypeTerm::App(halo_args) ]) - } else { - Ok(vec![]) - } - } else { - Err(UnificationError{ addr, t1: lhs, t2: rhs }) - } - } - - _ => Err(UnificationError{ addr, t1: lhs, t2: rhs}) - } - } - - pub fn eval_equation(&mut self, lhs: TypeTerm, rhs: TypeTerm, addr: Vec<usize>) -> Result<(), UnificationError> { - match (lhs.clone(), rhs.clone()) { - (TypeTerm::TypeID(TypeID::Var(varid)), t) | - (t, TypeTerm::TypeID(TypeID::Var(varid))) => { - if ! t.contains_var( varid ) { - self.σ.insert(TypeID::Var(varid), t.clone()); - self.reapply_subst(); - Ok(()) - } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { - Ok(()) - } else { - Err(UnificationError{ addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) - } - } - - (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { - if a1 == a2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } - } - (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { - if n1 == n2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } - } - (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { - if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr, t1: lhs, t2: rhs}) } + if c1 == c2 { Ok(()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } } (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | (TypeTerm::App(a1), TypeTerm::App(a2)) => { if a1.len() == a2.len() { for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate().rev() { - let mut new_addr = addr.clone(); + let mut new_addr = unification_pair.addr.clone(); new_addr.push(i); - self.eqs.push((x, y, new_addr)); + self.equal_pairs.push( + UnificationPair { + lhs: x, + rhs: y, + halo: TypeTerm::unit(), + addr: new_addr + }); } Ok(()) } else { - Err(UnificationError{ addr, t1: lhs, t2: rhs }) + Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } } - _ => Err(UnificationError{ addr, t1: lhs, t2: rhs}) + _ => Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } } - pub fn solve(mut self) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> { - while let Some( (mut lhs,mut rhs,addr) ) = self.eqs.pop() { - lhs.apply_substitution(&|v| self.σ.get(v).cloned()); - rhs.apply_substitution(&|v| self.σ.get(v).cloned()); - self.eval_equation(lhs, rhs, addr)?; - } - Ok(self.σ) - } + pub fn eval_subtype(&mut self, unification_pair: UnificationPair) -> Result< + // ok: halo type + TypeTerm, + // error + UnificationError + > { + match (unification_pair.lhs.clone(), unification_pair.rhs.clone()) { - pub fn solve_subtype(mut self) -> Result<(TypeTerm, HashMap<TypeID, TypeTerm>), UnificationError> { + /* + Variables + */ - pub fn insert_halo_at( - t: &mut TypeTerm, - mut addr: Vec<usize>, - halo_type: TypeTerm - ) { - match t { - TypeTerm::Ladder(rungs) => { - if let Some(idx) = addr.pop() { - for i in rungs.len()..idx+1 { - rungs.push(TypeTerm::unit()); + (TypeTerm::TypeID(TypeID::Var(varid)), t) => { + eprintln!("variable <= t"); + + if let Some(upper_bound) = self.upper_bounds.get(&varid).cloned() { + if let Ok(_halo) = self.eval_subtype( + UnificationPair { + lhs: t.clone(), + rhs: upper_bound, + + halo: TypeTerm::unit(), + addr: vec![] } - insert_halo_at( &mut rungs[idx], addr, halo_type ); - } else { - rungs.push(halo_type); + ) { + // found a lower upper bound + self.upper_bounds.insert(varid, t); } - }, + } else { + self.upper_bounds.insert(varid, t); + } + Ok(TypeTerm::unit()) + } - TypeTerm::App(args) => { - if let Some(idx) = addr.pop() { - insert_halo_at( &mut args[idx], addr, halo_type ); + (t, TypeTerm::TypeID(TypeID::Var(varid))) => { + eprintln!("t <= variable"); + if ! t.contains_var( varid ) { + // let x = self.σ.get(&TypeID::Var(varid)).cloned(); + if let Some(lower_bound) = self.lower_bounds.get(&varid).cloned() { + eprintln!("var already exists. check max. type"); + if let Ok(halo) = self.eval_subtype( + UnificationPair { + lhs: lower_bound.clone(), + rhs: t.clone(), + halo: TypeTerm::unit(), + addr: vec![] + } + ) { + eprintln!("found more general lower bound"); + eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone()); + // generalize variable type to supertype + self.lower_bounds.insert(varid, t.clone()); + } else if let Ok(halo) = self.eval_subtype( + UnificationPair{ + lhs: t.clone(), + rhs: lower_bound.clone(), + halo: TypeTerm::unit(), + addr: vec![] + } + ) { + eprintln!("OK, is already larger type"); + } else { + eprintln!("violated subtype restriction"); + return Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }); + } } else { - *t = TypeTerm::Ladder(vec![ - halo_type, - t.clone() - ]); + eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone()); + self.lower_bounds.insert(varid, t.clone()); + } + self.reapply_subst(); + Ok(TypeTerm::unit()) + } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + Ok(TypeTerm::unit()) + } else { + Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) + } + } + + + /* + Atoms + */ + + (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { + if a1 == a2 { Ok(TypeTerm::unit()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs}) } + } + (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { + if n1 == n2 { Ok(TypeTerm::unit()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } + } + (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { + if c1 == c2 { Ok(TypeTerm::unit()) } else { Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } + } + + + /* + Ladders + */ + + (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => { + let mut halo = Vec::new(); + for i in 0..a1.len() { + let mut new_addr = unification_pair.addr.clone(); + new_addr.push(i); + if let Ok(r_halo) = self.eval_subtype( UnificationPair { + lhs: a1[i].clone(), + rhs: a2[0].clone(), + + halo: TypeTerm::unit(), + addr: new_addr + }) { + eprintln!("unified ladders at {}, r_halo = {:?}", i, r_halo); + + for j in 0..a2.len() { + if i+j < a1.len() { + let mut new_addr = unification_pair.addr.clone(); + new_addr.push(i+j); + + let lhs = a1[i+j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); + let rhs = a2[j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); + + if let Ok(rung_halo) = self.eval_subtype( + UnificationPair { + lhs: lhs.clone(), rhs: rhs.clone(), + addr: new_addr.clone(), + halo: TypeTerm::unit() + } + ) { + if rung_halo != TypeTerm::unit() { + halo.push(rung_halo); + } + } else { + return Err(UnificationError{ addr: new_addr, t1: lhs, t2: rhs }) + } + } + } + + return Ok( + if halo.len() == 1 { + halo.pop().unwrap() + } else { + TypeTerm::Ladder(halo) + }); + } else { + halo.push(a1[i].clone()); + //eprintln!("could not unify ladders"); } } - atomic => { + Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) + }, + (t, TypeTerm::Ladder(a1)) => { + Err(UnificationError{ addr: unification_pair.addr, t1: t, t2: TypeTerm::Ladder(a1) }) + } + + (TypeTerm::Ladder(mut a1), t) => { + let mut new_addr = unification_pair.addr.clone(); + new_addr.push( a1.len() -1 ); + if let Ok(halo) = self.eval_subtype( + UnificationPair { + lhs: a1.pop().unwrap(), + rhs: t.clone(), + halo: TypeTerm::unit(), + addr: new_addr + } + ) { + a1.push(halo); + if a1.len() == 1 { + Ok(a1.pop().unwrap()) + } else { + Ok(TypeTerm::Ladder(a1)) + } + } else { + Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::Ladder(a1), t2: t }) } } - } - //let mut halo_type = TypeTerm::unit(); - let mut halo_rungs = Vec::new(); - while let Some( (mut lhs, mut rhs, mut addr) ) = self.eqs.pop() { - lhs.apply_substitution(&|v| self.σ.get(v).cloned()); - rhs.apply_substitution(&|v| self.σ.get(v).cloned()); - //eprintln!("eval subtype\n\tLHS={:?}\n\tRHS={:?}\n", lhs, rhs); - let mut new_halo = self.eval_subtype(lhs, rhs, addr.clone())?; - if new_halo.len() > 0 { - //insert_halo_at( &mut halo_type, addr, TypeTerm::Ladder(new_halo) ); - if addr.len() == 0 { - halo_rungs.push(TypeTerm::Ladder(new_halo)) + /* + Application + */ + + (TypeTerm::App(a1), TypeTerm::App(a2)) => { + eprintln!("sub unify {:?}, {:?}", a1, a2); + if a1.len() == a2.len() { + let mut halo_args = Vec::new(); + let mut n_halos_required = 0; + + for (i, (mut x, mut y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + let mut new_addr = unification_pair.addr.clone(); + new_addr.push(i); + + x = x.strip(); + + eprintln!("before strip: {:?}", y); + y = y.strip(); + eprintln!("after strip: {:?}", y); + + eprintln!("APP<> eval {:?} \n ?<=? {:?} ", x, y); + + if let Ok(halo) = self.eval_subtype( + UnificationPair { + lhs: x.clone(), + rhs: y.clone(), + halo: TypeTerm::unit(), + addr: new_addr, + } + ) { + if halo == TypeTerm::unit() { + let mut y = y.clone(); + y.apply_substitution(&|k| self.σ.get(k).cloned()); + y = y.strip(); + let mut top = y.get_lnf_vec().first().unwrap().clone(); + halo_args.push(top.clone()); + eprintln!("add top {:?}", top); + } else { + eprintln!("add halo {:?}", halo); + if n_halos_required > 0 { + let x = &mut halo_args[n_halos_required-1]; + if let TypeTerm::Ladder(argrs) = x { + let mut a = a2[n_halos_required-1].clone(); + a.apply_substitution(&|k| self.σ.get(k).cloned()); + a = a.get_lnf_vec().first().unwrap().clone(); + argrs.push(a); + } else { + *x = TypeTerm::Ladder(vec![ + x.clone(), + a2[n_halos_required-1].clone().get_lnf_vec().first().unwrap().clone() + ]); + + x.apply_substitution(&|k| self.σ.get(k).cloned()); + } + } + + halo_args.push(halo); + n_halos_required += 1; + } + } else { + return Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }); + } + } + + if n_halos_required > 0 { + eprintln!("halo args : {:?}", halo_args); + Ok(TypeTerm::App(halo_args)) + } else { + Ok(TypeTerm::unit()) + } + } else { + Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) } } + + _ => Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) + } + } + + pub fn solve(mut self) -> Result<(Vec<TypeTerm>, HashMap<TypeID, TypeTerm>), UnificationError> { + // solve equations + while let Some( mut equal_pair ) = self.equal_pairs.pop() { + equal_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()); + equal_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()); + + self.eval_equation(equal_pair)?; } - let mut halo_type = TypeTerm::Ladder(halo_rungs); - halo_type = halo_type.normalize(); - halo_type = halo_type.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); + // solve subtypes + eprintln!("------ SOLVE SUBTYPES ---- "); + for mut subtype_pair in self.subtype_pairs.clone().into_iter() { + subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()); + subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()); + let _halo = self.eval_subtype( subtype_pair.clone() )?.strip(); + } - Ok((halo_type.param_normalize(), self.σ)) + // add variables from subtype bounds + for (var_id, t) in self.upper_bounds.iter() { + eprintln!("VAR {} upper bound {:?}", var_id, t); + self.σ.insert(TypeID::Var(*var_id), t.clone().strip()); + } + + for (var_id, t) in self.lower_bounds.iter() { + eprintln!("VAR {} lower bound {:?}", var_id, t); + self.σ.insert(TypeID::Var(*var_id), t.clone().strip()); + } + + self.reapply_subst(); + + eprintln!("------ MAKE HALOS -----"); + let mut halo_types = Vec::new(); + for mut subtype_pair in self.subtype_pairs.clone().into_iter() { + subtype_pair.lhs = subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip(); + subtype_pair.rhs = subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip(); + + let halo = self.eval_subtype( subtype_pair.clone() )?.strip(); + halo_types.push(halo); + } + + // solve traits + while let Some( trait_pair ) = self.trait_pairs.pop() { + unimplemented!(); + } + + Ok((halo_types, self.σ)) } } @@ -264,16 +447,16 @@ pub fn unify( t1: &TypeTerm, t2: &TypeTerm ) -> Result<HashMap<TypeID, TypeTerm>, UnificationError> { - let mut unification = UnificationProblem::new(vec![ (t1.clone(), t2.clone()) ]); - unification.solve() + let unification = UnificationProblem::new_eq(vec![ (t1.clone(), t2.clone()) ]); + Ok(unification.solve()?.1) } pub fn subtype_unify( t1: &TypeTerm, t2: &TypeTerm ) -> Result<(TypeTerm, HashMap<TypeID, TypeTerm>), UnificationError> { - let mut unification = UnificationProblem::new(vec![ (t1.clone(), t2.clone()) ]); - unification.solve_subtype() + let unification = UnificationProblem::new_sub(vec![ (t1.clone(), t2.clone()) ]); + unification.solve().map( |(halos,σ)| ( halos.first().cloned().unwrap_or(TypeTerm::unit()), σ) ) } //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\