diff --git a/src/lib.rs b/src/lib.rs index 517b36d..1a270cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod unparser; pub mod curry; pub mod lnf; pub mod subtype; +pub mod unification; #[cfg(test)] mod test; @@ -15,5 +16,6 @@ mod test; pub use { dict::*, term::*, + unification::* }; diff --git a/src/test/mod.rs b/src/test/mod.rs index 4cde4e3..d116412 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -5,4 +5,5 @@ pub mod curry; pub mod lnf; pub mod subtype; pub mod substitution; +pub mod unification; diff --git a/src/test/unification.rs b/src/test/unification.rs new file mode 100644 index 0000000..40b7d68 --- /dev/null +++ b/src/test/unification.rs @@ -0,0 +1,79 @@ + +use { + crate::{dict::*, term::*, unification::*}, + std::iter::FromIterator +}; + +//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ + +fn test_unify(ts1: &str, ts2: &str, expect_unificator: bool) { + let mut dict = TypeDict::new(); + dict.add_varname(String::from("T")); + dict.add_varname(String::from("U")); + dict.add_varname(String::from("V")); + dict.add_varname(String::from("W")); + + let mut t1 = dict.parse(ts1).unwrap(); + let mut t2 = dict.parse(ts2).unwrap(); + let σ = crate::unify( &t1, &t2 ); + + if expect_unificator { + assert!(σ.is_ok()); + + let σ = σ.unwrap(); + + assert_eq!( + t1.apply_substitution(&|v| σ.get(v).cloned()), + t2.apply_substitution(&|v| σ.get(v).cloned()) + ); + } else { + assert!(! σ.is_ok()); + } +} + +#[test] +fn test_unification_error() { + let mut dict = TypeDict::new(); + dict.add_varname(String::from("T")); + + assert_eq!( + crate::unify( + &dict.parse("").unwrap(), + &dict.parse("").unwrap() + ), + + Err(UnificationError { + addr: vec![0], + t1: dict.parse("A").unwrap(), + t2: dict.parse("B").unwrap() + }) + ); + + assert_eq!( + crate::unify( + &dict.parse(" T>").unwrap(), + &dict.parse(" T>").unwrap() + ), + + Err(UnificationError { + addr: vec![1, 1], + t1: dict.parse("A").unwrap(), + t2: dict.parse("B").unwrap() + }) + ); +} + +#[test] +fn test_unification() { + test_unify("A", "A", true); + test_unify("A", "B", false); + test_unify("", "", true); + test_unify("", "", true); + + test_unify( + ">~~", + ">~~", + true + ); +} + diff --git a/src/unification.rs b/src/unification.rs new file mode 100644 index 0000000..6d7598a --- /dev/null +++ b/src/unification.rs @@ -0,0 +1,107 @@ +use { + std::collections::HashMap, + crate::{term::*, dict::*} +}; + +//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct UnificationError { + pub addr: Vec, + pub t1: TypeTerm, + pub t2: TypeTerm +} + +impl UnificationError { + pub fn new(t1: &TypeTerm, t2: &TypeTerm) -> Self { + UnificationError { + addr: vec![], + t1: t1.clone(), + t2: t2.clone() + } + } +} +/* +struct UnificationProblem { + eqs: Vec<(TypeTerm, TypeTerm)>, + σ: HashMap +} + +impl UnificationProblem { + pub fn new() -> Self { + UnificationProblem { + eqs: Vec::new(), + σ: HashMap::new() + } + } + + pub fn eval_equation(&mut self, lhs: &TypeTerm, rhs: &TypeTerm) -> Option { + match (lhs, rhs) { + + } + } + + pub fn solve(self) -> Result, UnificationError> { + + } +} +*/ +pub fn unify( + t1: &TypeTerm, + t2: &TypeTerm +) -> Result, UnificationError> { + let mut σ = HashMap::new(); + + match (t1, t2) { + (TypeTerm::TypeID(TypeID::Var(varid)), t) | + (t, TypeTerm::TypeID(TypeID::Var(varid))) => { + σ.insert(TypeID::Var(*varid), t.clone()); + Ok(σ) + } + + (TypeTerm::TypeID(a1), TypeTerm::TypeID(a2)) => { + if a1 == a2 { Ok(σ) } else { Err(UnificationError::new(&t1, &t2)) } + } + (TypeTerm::Num(n1), TypeTerm::Num(n2)) => { + if n1 == n2 { Ok(σ) } else { Err(UnificationError::new(&t1, &t2)) } + } + (TypeTerm::Char(c1), TypeTerm::Char(c2)) => { + if c1 == c2 { Ok(σ) } else { Err(UnificationError::new(&t1, &t2)) } + } + + (TypeTerm::Ladder(a1), TypeTerm::Ladder(a2)) | + (TypeTerm::App(a1), TypeTerm::App(a2)) => { + if a1.len() == a2.len() { + for (i, (x, y)) in a1.iter().cloned().zip(a2.iter().cloned()).enumerate() { + let (mut x, mut y) = (x.clone(), y.clone()); + x.apply_substitution(&|v| σ.get(v).cloned()); + y.apply_substitution(&|v| σ.get(v).cloned()); + + match unify(&x, &y) { + Ok(τ) => { + for (v,t) in τ { + σ.insert(v,t); + } + } + Err(mut err) => { + err.addr.insert(0, i); + return Err(err); + } + } + } + Ok(σ) + } else { + Err(UnificationError::new(&t1, &t2)) + } + } + + (TypeTerm::Ladder(l1), TypeTerm::Ladder(l2)) => { + Err(UnificationError::new(&t1, &t2)) + } + + _ => Err(UnificationError::new(t1, t2)) + } +} + +//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ +