diff --git a/src/morphism.rs b/src/morphism.rs index 7b2ac5e..9f29ba8 100644 --- a/src/morphism.rs +++ b/src/morphism.rs @@ -21,8 +21,8 @@ pub struct MorphismType { impl MorphismType { pub fn normalize(self) -> Self { MorphismType { - src_type: self.src_type.normalize().param_normalize(), - dst_type: self.dst_type.normalize().param_normalize(), + src_type: self.src_type.strip().param_normalize(), + dst_type: self.dst_type.strip().param_normalize(), /* subtype_bounds: Vec::new(), trait_bounds: Vec::new() @@ -149,7 +149,7 @@ impl<M: Morphism + Clone> MorphismBase<M> { if let Some( map_morph ) = item_morph_inst.m.map_morphism( seq_type.clone() ) { dst_types.push( MorphismInstance { - halo: TypeTerm::Ladder(dst_halo_ladder).normalize(), + halo: TypeTerm::Ladder(dst_halo_ladder).strip().param_normalize(), m: map_morph, σ: item_morph_inst.σ } @@ -221,7 +221,7 @@ impl<M: Morphism + Clone> MorphismBase<M> { n.halo = n.halo.clone().apply_substitution( &|k| σ.get(k).cloned() - ).clone().strip(); + ).clone().strip().param_normalize(); n.σ = new_σ; } @@ -256,7 +256,7 @@ impl<M: Morphism + Clone> MorphismBase<M> { n.halo = n.halo.clone().apply_substitution( &|k| next_morph_inst.σ.get(k).cloned() - ).clone().strip(); + ).clone().strip().param_normalize(); n.σ = new_σ; } diff --git a/src/test/morphism.rs b/src/test/morphism.rs index 294ac00..e6db4a3 100644 --- a/src/test/morphism.rs +++ b/src/test/morphism.rs @@ -330,51 +330,6 @@ fn test_morphism_path_posint() { ] ) ); - - - -/* - assert_eq!( - base.find_morphism_path(MorphismType { - src_type: dict.parse("Symbol ~ ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap(), - dst_type: dict.parse("Symbol ~ ℕ ~ <PosInt 16 BigEndian> ~ <Seq <Digit 16> ~ Char>").unwrap() - }), - Some( - vec![ - dict.parse("Symbol ~ ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap().normalize(), - dict.parse("Symbol ~ ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ ℤ_2^64 ~ machine.UInt64>").unwrap().normalize(), - dict.parse("Symbol ~ ℕ ~ <PosInt 10 LittleEndian> ~ <Seq <Digit 10> ~ ℤ_2^64 ~ machine.UInt64>").unwrap().normalize(), - dict.parse("Symbol ~ ℕ ~ <PosInt 16 LittleEndian> ~ <Seq <Digit 16> ~ ℤ_2^64 ~ machine.UInt64>").unwrap().normalize(), - dict.parse("Symbol ~ ℕ ~ <PosInt 16 BigEndian> ~ <Seq <Digit 16> ~ ℤ_2^64 ~ machine.UInt64>").unwrap().normalize(), - dict.parse("Symbol ~ ℕ ~ <PosInt 16 BigEndian> ~ <Seq <Digit 16> ~ Char>").unwrap().normalize(), - ] - ) - ); - */ -/* - assert_eq!( - base.find_morphism_with_subtyping( - &MorphismType { - src_type: dict.parse("Symbol ~ ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap(), - dst_type: dict.parse("Symbol ~ ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ ℤ_2^64 ~ machine.UInt64>").unwrap() - } - ), - - Some(( - DummyMorphism(MorphismType{ - src_type: dict.parse("<Seq <Digit Radix> ~ Char>").unwrap(), - dst_type: dict.parse("<Seq <Digit Radix> ~ ℤ_2^64 ~ machine.UInt64>").unwrap() - }), - - dict.parse("Symbol ~ ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10>>").unwrap(), - - vec![ - (dict.get_typeid(&"Radix".into()).unwrap(), - dict.parse("10").unwrap()) - ].into_iter().collect::<std::collections::HashMap<TypeID, TypeTerm>>() - )) - ); -*/ } #[test] @@ -421,8 +376,20 @@ fn test_morphism_path_listedit() ); base.add_morphism( DummyMorphism(MorphismType{ - src_type: dict.parse("<List Char>").unwrap(), - dst_type: dict.parse("<List Char~ReprTree>").unwrap() + src_type: dict.parse("Char").unwrap(), + dst_type: dict.parse("Char ~ ReprTree").unwrap() + }) + ); + base.add_morphism( + DummyMorphism(MorphismType{ + src_type: dict.parse("Char ~ ReprTree").unwrap(), + dst_type: dict.parse("Char").unwrap() + }) + ); + base.add_morphism( + DummyMorphism(MorphismType{ + src_type: dict.parse("<List~Vec Char>").unwrap(), + dst_type: dict.parse("<List Char>").unwrap() }) ); base.add_morphism( @@ -483,12 +450,12 @@ fn test_morphism_path_listedit() }, MorphismInstance { m: DummyMorphism(MorphismType{ - src_type: dict.parse("<List~Vec ReprTree>").unwrap(), - dst_type: dict.parse("<List ReprTree> ~ EditTree").unwrap() + src_type: dict.parse("<List~Vec Char~ReprTree>").unwrap(), + dst_type: dict.parse("<List Char> ~ EditTree").unwrap() }), - halo: dict.parse("<Seq~List <Digit 10>~Char>").unwrap(), + halo: dict.parse("<Seq~List <Digit 10>>").unwrap(), σ: HashMap::new() - } + }, ]) ); } diff --git a/src/test/unification.rs b/src/test/unification.rs index de31d5b..6021dda 100644 --- a/src/test/unification.rs +++ b/src/test/unification.rs @@ -132,7 +132,61 @@ fn test_unification() { } #[test] -fn test_subtype_unification() { +fn test_subtype_unification1() { + let mut dict = BimapTypeDict::new(); + dict.add_varname(String::from("T")); + + assert_eq!( + UnificationProblem::new_sub(vec![ + (dict.parse("A ~ B").unwrap(), + dict.parse("B").unwrap()), + ]).solve(), + Ok(( + vec![ dict.parse("A").unwrap() ], + vec![].into_iter().collect() + )) + ); + + assert_eq!( + UnificationProblem::new_sub(vec![ + (dict.parse("A ~ B ~ C ~ D").unwrap(), + dict.parse("C ~ D").unwrap()), + ]).solve(), + Ok(( + vec![ dict.parse("A ~ B").unwrap() ], + vec![].into_iter().collect() + )) + ); + + assert_eq!( + UnificationProblem::new_sub(vec![ + (dict.parse("A ~ B ~ C ~ D").unwrap(), + dict.parse("T ~ D").unwrap()), + ]).solve(), + Ok(( + vec![ TypeTerm::unit() ], + vec![ + (dict.get_typeid(&"T".into()).unwrap(), dict.parse("A ~ B ~ C").unwrap()) + ].into_iter().collect() + )) + ); + + assert_eq!( + UnificationProblem::new_sub(vec![ + (dict.parse("A ~ B ~ C ~ D").unwrap(), + dict.parse("B ~ T ~ D").unwrap()), + ]).solve(), + Ok(( + vec![ dict.parse("A").unwrap() ], + vec![ + (dict.get_typeid(&"T".into()).unwrap(), dict.parse("C").unwrap()) + ].into_iter().collect() + )) + ); +} + +#[test] +fn test_subtype_unification2() { let mut dict = BimapTypeDict::new(); dict.add_varname(String::from("T")); @@ -142,7 +196,7 @@ fn test_subtype_unification() { assert_eq!( UnificationProblem::new_sub(vec![ - (dict.parse("<Seq~T <Digit 10> ~ Char>").unwrap(), + (dict.parse("<Seq~T <Digit 10> ~ Char ~ Ascii>").unwrap(), dict.parse("<Seq~<LengthPrefix x86.UInt64> Char ~ Ascii>").unwrap()), ]).solve(), Ok(( @@ -232,9 +286,9 @@ fn test_trait_not_subtype() { &dict.parse("A ~ B ~ C").expect("") ), Err(UnificationError { - addr: vec![], - t1: dict.parse("A ~ B").expect(""), - t2: dict.parse("A ~ B ~ C").expect("") + addr: vec![1], + t1: dict.parse("B").expect(""), + t2: dict.parse("C").expect("") }) ); @@ -244,13 +298,33 @@ fn test_trait_not_subtype() { &dict.parse("<Seq~List~Vec Char~ReprTree>").expect("") ), Err(UnificationError { - addr: vec![1], - t1: dict.parse("<Digit 10> ~ Char").expect(""), - t2: dict.parse("Char ~ ReprTree").expect("") + addr: vec![1,1], + t1: dict.parse("Char").expect(""), + t2: dict.parse("ReprTree").expect("") }) ); } +#[test] +fn test_reprtree_list_subtype() { + let mut dict = BimapTypeDict::new(); + + dict.add_varname("Item".into()); + + assert_eq!( + subtype_unify( + &dict.parse("<List~Vec <Digit 10>~Char~ReprTree>").expect(""), + &dict.parse("<List~Vec Item~ReprTree>").expect("") + ), + Ok(( + TypeTerm::unit(), + vec![ + (dict.get_typeid(&"Item".into()).unwrap(), dict.parse("<Digit 10>~Char").unwrap()) + ].into_iter().collect() + )) + ); +} + #[test] pub fn test_subtype_delim() { let mut dict = BimapTypeDict::new(); diff --git a/src/unification.rs b/src/unification.rs index eb55646..767bbde 100644 --- a/src/unification.rs +++ b/src/unification.rs @@ -137,85 +137,119 @@ impl UnificationProblem { } } + + + pub fn add_lower_subtype_bound(&mut self, v: u64, new_lower_bound: TypeTerm) -> Result<(),()> { + + if new_lower_bound == TypeTerm::TypeID(TypeID::Var(v)) { + return Ok(()); + } + + if new_lower_bound.contains_var(v) { + // loop + return Err(()); + } + + if let Some(lower_bound) = self.lower_bounds.get(&v).cloned() { +// eprintln!("var already exists. check max. type"); + if let Ok(halo) = self.eval_subtype( + UnificationPair { + lhs: lower_bound.clone(), + rhs: new_lower_bound.clone(), + halo: TypeTerm::unit(), + addr: vec![] + } + ) { +// eprintln!("found more general lower bound"); +// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone()); + // generalize variable type to supertype + self.lower_bounds.insert(v, new_lower_bound); + Ok(()) + } else if let Ok(halo) = self.eval_subtype( + UnificationPair{ + lhs: new_lower_bound, + rhs: lower_bound, + halo: TypeTerm::unit(), + addr: vec![] + } + ) { +// eprintln!("OK, is already larger type"); + Ok(()) + } else { +// eprintln!("violated subtype restriction"); + Err(()) + } + } else { +// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone()); + self.lower_bounds.insert(v, new_lower_bound); + Ok(()) + } + } + + + pub fn add_upper_subtype_bound(&mut self, v: u64, new_upper_bound: TypeTerm) -> Result<(),()> { + if new_upper_bound == TypeTerm::TypeID(TypeID::Var(v)) { + return Ok(()); + } + + if new_upper_bound.contains_var(v) { + // loop + return Err(()); + } + + if let Some(upper_bound) = self.upper_bounds.get(&v).cloned() { + if let Ok(_halo) = self.eval_subtype( + UnificationPair { + lhs: new_upper_bound.clone(), + rhs: upper_bound, + halo: TypeTerm::unit(), + addr: vec![] + } + ) { + // found a lower upper bound + self.upper_bounds.insert(v, new_upper_bound); + Ok(()) + } else { + Err(()) + } + } else { + self.upper_bounds.insert(v, new_upper_bound); + Ok(()) + } + } + pub fn eval_subtype(&mut self, unification_pair: UnificationPair) -> Result< // ok: halo type TypeTerm, // error UnificationError > { + eprintln!("eval_subtype {:?} <=? {:?}", unification_pair.lhs, unification_pair.rhs); match (unification_pair.lhs.clone(), unification_pair.rhs.clone()) { /* Variables */ - (t, TypeTerm::TypeID(TypeID::Var(varid))) => { -// eprintln!("t <= variable"); - if ! t.contains_var( varid ) { - // let x = self.σ.get(&TypeID::Var(varid)).cloned(); - if let Some(lower_bound) = self.lower_bounds.get(&varid).cloned() { -// eprintln!("var already exists. check max. type"); - if let Ok(halo) = self.eval_subtype( - UnificationPair { - lhs: lower_bound.clone(), - rhs: t.clone(), - halo: TypeTerm::unit(), - addr: vec![] - } - ) { -// eprintln!("found more general lower bound"); -// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone()); - // generalize variable type to supertype - self.lower_bounds.insert(varid, t.clone()); - } else if let Ok(halo) = self.eval_subtype( - UnificationPair{ - lhs: t.clone(), - rhs: lower_bound.clone(), - halo: TypeTerm::unit(), - addr: vec![] - } - ) { -// eprintln!("OK, is already larger type"); - } else { -// eprintln!("violated subtype restriction"); - return Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }); - } - } else { -// eprintln!("set var {}'s lowerbound to {:?}", varid, t.clone()); - self.lower_bounds.insert(varid, t.clone()); - } - self.reapply_subst(); - Ok(TypeTerm::unit()) - } else if t == TypeTerm::TypeID(TypeID::Var(varid)) { + (t, TypeTerm::TypeID(TypeID::Var(v))) => { + //eprintln!("t <= variable"); + if self.add_lower_subtype_bound(v, t.clone()).is_ok() { Ok(TypeTerm::unit()) } else { - Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(varid)), t2: t }) + Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(v)), t2: t }) } } - (TypeTerm::TypeID(TypeID::Var(varid)), t) => { -// eprintln!("variable <= t"); - - if let Some(upper_bound) = self.upper_bounds.get(&varid).cloned() { - if let Ok(_halo) = self.eval_subtype( - UnificationPair { - lhs: t.clone(), - rhs: upper_bound, - halo: TypeTerm::unit(), - addr: vec![] - } - ) { - // found a lower upper bound - self.upper_bounds.insert(varid, t); - } + (TypeTerm::TypeID(TypeID::Var(v)), t) => { + //eprintln!("variable <= t"); + if self.add_upper_subtype_bound(v, t.clone()).is_ok() { + Ok(TypeTerm::unit()) } else { - self.upper_bounds.insert(varid, t); + Err(UnificationError{ addr: unification_pair.addr, t1: TypeTerm::TypeID(TypeID::Var(v)), t2: t }) } - Ok(TypeTerm::unit()) } - /* Atoms */ @@ -236,60 +270,91 @@ impl UnificationProblem { */ (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) => { - let mut halo = Vec::new(); - if a1.len() < a2.len() { - return Err(UnificationError { - addr: vec![], - t1: unification_pair.lhs, - t2: unification_pair.rhs - }); - } - for i in 0..a1.len() { - let mut new_addr = unification_pair.addr.clone(); - new_addr.push(i); - if let Ok(r_halo) = self.eval_subtype( UnificationPair { - lhs: a1[i].clone(), - rhs: a2[0].clone(), + let mut l1_iter = a1.into_iter().enumerate().rev(); + let mut l2_iter = a2.into_iter().rev(); - halo: TypeTerm::unit(), - addr: new_addr - }) { - for j in 0..a2.len() { - if i+j < a1.len() { - let mut new_addr = unification_pair.addr.clone(); - new_addr.push(i+j); + let mut halo_ladder = Vec::new(); - let lhs = a1[i+j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); - let rhs = a2[j].clone();//.apply_substitution(&|k| self.σ.get(k).cloned()).clone(); + while let Some(rhs) = l2_iter.next() { + //eprintln!("take rhs = {:?}", rhs); + if let Some((i, lhs)) = l1_iter.next() { + //eprintln!("take lhs ({}) = {:?}", i, lhs); + let mut addr = unification_pair.addr.clone(); + addr.push(i); + //eprintln!("addr = {:?}", addr); - if let Ok(rung_halo) = self.eval_subtype( + match (lhs.clone(), rhs.clone()) { + (t, TypeTerm::TypeID(TypeID::Var(v))) => { + + if self.add_upper_subtype_bound(v,t.clone()).is_ok() { + let mut new_upper_bound_ladder = vec![ t ]; + + if let Some(next_rhs) = l2_iter.next() { + + // TODO + + } else { + // take everything + + while let Some((i,t)) = l1_iter.next() { + new_upper_bound_ladder.push(t); + } + } + + new_upper_bound_ladder.reverse(); + if self.add_upper_subtype_bound(v, TypeTerm::Ladder(new_upper_bound_ladder)).is_ok() { + // ok + } else { + return Err(UnificationError { + addr, + t1: lhs, + t2: rhs + }); + } + } else { + return Err(UnificationError { + addr, + t1: lhs, + t2: rhs + }); + } + } + (lhs, rhs) => { + if let Ok(ψ) = self.eval_subtype( UnificationPair { lhs: lhs.clone(), rhs: rhs.clone(), - addr: new_addr.clone(), - halo: TypeTerm::unit() + addr:addr.clone(), halo: TypeTerm::unit() } ) { - halo.push(rung_halo); + // ok. + //eprintln!("rungs are subtypes. continue"); + halo_ladder.push(ψ); } else { - return Err(UnificationError{ addr: new_addr, t1: lhs, t2: rhs }) + return Err(UnificationError { + addr, + t1: lhs, + t2: rhs + }); } } } - - return Ok( - if halo.len() == 1 { - halo.pop().unwrap() - } else { - TypeTerm::Ladder(halo).strip() - }); } else { - halo.push(a1[i].clone()); - //eprintln!("could not unify ladders"); + // not a subtype, + return Err(UnificationError { + addr: vec![], + t1: unification_pair.lhs, + t2: unification_pair.rhs + }); } } + //eprintln!("left ladder fully consumed"); - Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }) + for (i,t) in l1_iter { + halo_ladder.push(t); + } + halo_ladder.reverse(); + Ok(TypeTerm::Ladder(halo_ladder).strip().param_normalize()) }, (t, TypeTerm::Ladder(a1)) => { @@ -339,7 +404,7 @@ impl UnificationProblem { // eprintln!("after strip: {:?}", y); // eprintln!("APP<> eval {:?} \n ?<=? {:?} ", x, y); - if let Ok(halo) = self.eval_subtype( + match self.eval_subtype( UnificationPair { lhs: x.clone(), rhs: y.clone(), @@ -347,15 +412,16 @@ impl UnificationProblem { addr: new_addr, } ) { + Ok(halo) => { if halo == TypeTerm::unit() { let mut y = y.clone(); y.apply_substitution(&|k| self.σ.get(k).cloned()); y = y.strip(); let mut top = y.get_lnf_vec().first().unwrap().clone(); halo_args.push(top.clone()); -// eprintln!("add top {:?}", top); + //eprintln!("add top {:?}", top); } else { -// eprintln!("add halo {:?}", halo); + //eprintln!("add halo {:?}", halo); if n_halos_required > 0 { let x = &mut halo_args[n_halos_required-1]; if let TypeTerm::Ladder(argrs) = x { @@ -376,13 +442,13 @@ impl UnificationProblem { halo_args.push(halo); n_halos_required += 1; } - } else { - return Err(UnificationError{ addr: unification_pair.addr, t1: unification_pair.lhs, t2: unification_pair.rhs }); + }, + Err(err) => { return Err(err); } } } if n_halos_required > 0 { -// eprintln!("halo args : {:?}", halo_args); + //eprintln!("halo args : {:?}", halo_args); Ok(TypeTerm::App(halo_args)) } else { Ok(TypeTerm::unit())