add Substitution trait

This commit is contained in:
Michael Sippel 2025-03-24 14:06:16 +01:00
parent fab6818fe9
commit 08a9bad0ad
Signed by: senvas
GPG key ID: F96CF119C34B64A6
6 changed files with 80 additions and 45 deletions

View file

@ -2,6 +2,8 @@
pub mod bimap;
pub mod dict;
pub mod term;
pub mod substitution;
pub mod lexer;
pub mod parser;
pub mod unparser;
@ -21,6 +23,7 @@ mod pretty;
pub use {
dict::*,
term::*,
substitution::*,
sugar::*,
unification::*,
};

62
src/substitution.rs Normal file
View file

@ -0,0 +1,62 @@
use crate::{
TypeID,
TypeTerm
};
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
pub trait Substitution {
fn get(&self, t: &TypeID) -> Option< TypeTerm >;
}
impl<S: Fn(&TypeID)->Option<TypeTerm>> Substitution for S {
fn get(&self, t: &TypeID) -> Option< TypeTerm > {
(self)(t)
}
}
impl Substitution for std::collections::HashMap< TypeID, TypeTerm > {
fn get(&self, t: &TypeID) -> Option< TypeTerm > {
(self as &std::collections::HashMap< TypeID, TypeTerm >).get(t).cloned()
}
}
impl TypeTerm {
/// recursively apply substitution to all subterms,
/// which will replace all occurences of variables which map
/// some type-term in `subst`
pub fn apply_substitution(
&mut self,
σ: &impl Substitution
) -> &mut Self {
self.apply_subst(σ)
}
pub fn apply_subst(
&mut self,
σ: &impl Substitution
) -> &mut Self {
match self {
TypeTerm::TypeID(typid) => {
if let Some(t) = σ.get(typid) {
*self = t;
}
}
TypeTerm::Ladder(rungs) => {
for r in rungs.iter_mut() {
r.apply_subst(σ);
}
}
TypeTerm::App(args) => {
for r in args.iter_mut() {
r.apply_subst(σ);
}
}
_ => {}
}
self
}
}

View file

@ -95,35 +95,6 @@ impl TypeTerm {
}
}
/// recursively apply substitution to all subterms,
/// which will replace all occurences of variables which map
/// some type-term in `subst`
pub fn apply_substitution(
&mut self,
subst: &impl Fn(&TypeID) -> Option<TypeTerm>
) -> &mut Self {
match self {
TypeTerm::TypeID(typid) => {
if let Some(t) = subst(typid) {
*self = t;
}
}
TypeTerm::Ladder(rungs) => {
for r in rungs.iter_mut() {
r.apply_substitution(subst);
}
}
TypeTerm::App(args) => {
for r in args.iter_mut() {
r.apply_substitution(subst);
}
}
_ => {}
}
self
}
/* strip away empty ladders
* & unwrap singletons

View file

@ -1,7 +1,7 @@
use {
crate::{dict::*, term::*, parser::*, unparser::*},
std::iter::FromIterator
crate::{dict::*, term::*, parser::*, unparser::*, substitution::*},
std::iter::FromIterator,
};
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
@ -24,8 +24,7 @@ fn test_subst() {
assert_eq!(
dict.parse("<Seq T~U>").unwrap()
.apply_substitution(&|typid|{ σ.get(typid).cloned() }).clone(),
dict.parse("<Seq T~U>").unwrap().apply_subst(&σ).clone(),
dict.parse("<Seq ~<Seq Char>>").unwrap()
);
}

View file

@ -23,8 +23,8 @@ fn test_unify(ts1: &str, ts2: &str, expect_unificator: bool) {
let σ = σ.unwrap();
assert_eq!(
t1.apply_substitution(&|v| σ.get(v).cloned()),
t2.apply_substitution(&|v| σ.get(v).cloned())
t1.apply_subst(&σ),
t2.apply_subst(&σ)
);
} else {
assert!(! σ.is_ok());

View file

@ -79,7 +79,7 @@ impl UnificationProblem {
let mut new_σ = HashMap::new();
for (v, tt) in self.σ.iter() {
let mut tt = tt.clone().normalize();
tt.apply_substitution(&|v| self.σ.get(v).cloned());
tt.apply_subst(&self.σ);
tt = tt.normalize();
//eprintln!("update σ : {:?} --> {:?}", v, tt);
new_σ.insert(v.clone(), tt);
@ -414,7 +414,7 @@ impl UnificationProblem {
Ok(halo) => {
if halo == TypeTerm::unit() {
let mut y = y.clone();
y.apply_substitution(&|k| self.σ.get(k).cloned());
y.apply_subst(&self.σ);
y = y.strip();
let mut top = y.get_lnf_vec().first().unwrap().clone();
halo_args.push(top.clone());
@ -425,7 +425,7 @@ impl UnificationProblem {
let x = &mut halo_args[n_halos_required-1];
if let TypeTerm::Ladder(argrs) = x {
let mut a = a2[n_halos_required-1].clone();
a.apply_substitution(&|k| self.σ.get(k).cloned());
a.apply_subst(&self.σ);
a = a.get_lnf_vec().first().unwrap().clone();
argrs.push(a);
} else {
@ -434,7 +434,7 @@ impl UnificationProblem {
a2[n_halos_required-1].clone().get_lnf_vec().first().unwrap().clone()
]);
x.apply_substitution(&|k| self.σ.get(k).cloned());
x.apply_subst(&self.σ);
}
}
@ -464,8 +464,8 @@ impl UnificationProblem {
pub fn solve(mut self) -> Result<(Vec<TypeTerm>, HashMap<TypeID, TypeTerm>), UnificationError> {
// solve equations
while let Some( mut equal_pair ) = self.equal_pairs.pop() {
equal_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned());
equal_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned());
equal_pair.lhs.apply_subst(&self.σ);
equal_pair.rhs.apply_subst(&self.σ);
self.eval_equation(equal_pair)?;
}
@ -473,8 +473,8 @@ impl UnificationProblem {
// solve subtypes
// eprintln!("------ SOLVE SUBTYPES ---- ");
for mut subtype_pair in self.subtype_pairs.clone().into_iter() {
subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned());
subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned());
subtype_pair.lhs.apply_subst(&self.σ);
subtype_pair.rhs.apply_subst(&self.σ);
let _halo = self.eval_subtype( subtype_pair.clone() )?.strip();
}
@ -494,8 +494,8 @@ impl UnificationProblem {
// eprintln!("------ MAKE HALOS -----");
let mut halo_types = Vec::new();
for mut subtype_pair in self.subtype_pairs.clone().into_iter() {
subtype_pair.lhs = subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip();
subtype_pair.rhs = subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip();
subtype_pair.lhs = subtype_pair.lhs.apply_subst(&self.σ).clone().strip();
subtype_pair.rhs = subtype_pair.rhs.apply_subst(&self.σ).clone().strip();
let halo = self.eval_subtype( subtype_pair.clone() )?.strip();
halo_types.push(halo);