add substitution trait

This commit is contained in:
Michael Sippel 2025-03-24 14:06:16 +01:00
parent bda36b4856
commit cae616b7ae
Signed by: senvas
GPG key ID: F96CF119C34B64A6
9 changed files with 92 additions and 63 deletions

View file

@ -2,6 +2,8 @@
pub mod bimap;
pub mod dict;
pub mod term;
pub mod substitution;
pub mod lexer;
pub mod parser;
pub mod unparser;
@ -25,6 +27,7 @@ mod pretty;
pub use {
dict::*,
term::*,
substitution::*,
sugar::*,
unification::*,
morphism::*

View file

@ -57,13 +57,13 @@ impl<M: Morphism + Clone> MorphismInstance<M> {
src_type: TypeTerm::Ladder(vec![
self.halo.clone(),
self.m.get_type().src_type.clone()
]).apply_substitution(&|k| self.σ.get(k).cloned())
]).apply_subst(&self.σ)
.clone(),
dst_type: TypeTerm::Ladder(vec![
self.halo.clone(),
self.m.get_type().dst_type.clone()
]).apply_substitution(&|k| self.σ.get(k).cloned())
]).apply_subst(&self.σ)
.clone(),
/*
trait_bounds: Vec::new(),

View file

@ -57,7 +57,7 @@ impl<'a, M:Morphism+Clone> ShortestPathProblem<'a, M> {
|| dst_type.contains_var(*varid) {
new_σ.insert(
k.clone(),
v.clone().apply_substitution(&|k| σ.get(k).cloned()).clone().strip()
v.clone().apply_subst(&σ).clone().strip()
);
}
}
@ -66,18 +66,12 @@ impl<'a, M:Morphism+Clone> ShortestPathProblem<'a, M> {
if let TypeID::Var(varid) = k {
if src_type.contains_var(*varid)
|| dst_type.contains_var(*varid) {
new_σ.insert(
k.clone(),
v.clone().apply_substitution(&|k| σ.get(k).cloned()).clone().strip()
);
new_σ.insert( k.clone(), v.clone().apply_subst(&σ).clone().strip() );
}
}
}
n.halo = n.halo.clone().apply_substitution(
&|k| σ.get(k).cloned()
).clone().strip().param_normalize();
n.halo = n.halo.clone().apply_subst(&σ).clone().strip().param_normalize();
n.σ = new_σ;
}
@ -98,20 +92,20 @@ impl<'a, M:Morphism+Clone> ShortestPathProblem<'a, M> {
for (k,v) in next_morph_inst.σ.iter() {
new_σ.insert(
k.clone(),
v.clone().apply_substitution(&|k| next_morph_inst.σ.get(k).cloned()).clone()
v.clone().apply_subst(&next_morph_inst.σ).clone()
);
}
for (k,v) in n.σ.iter() {
new_σ.insert(
k.clone(),
v.clone().apply_substitution(&|k| next_morph_inst.σ.get(k).cloned()).clone()
v.clone().apply_subst(&next_morph_inst.σ).clone()
);
}
n.halo = n.halo.clone().apply_substitution(
&|k| next_morph_inst.σ.get(k).cloned()
).clone().strip().param_normalize();
n.halo = n.halo.clone()
.apply_subst( &next_morph_inst.σ ).clone()
.strip().param_normalize();
n.σ = new_σ;
}

View file

@ -38,8 +38,8 @@ impl SteinerTree {
// goal reached.
for e in self.edges.iter_mut() {
e.src_type = e.src_type.apply_substitution(&|x| σ.get(x).cloned()).clone();
e.dst_type = e.dst_type.apply_substitution(&|x| σ.get(x).cloned()).clone();
e.src_type = e.src_type.apply_subst(&σ).clone();
e.dst_type = e.dst_type.apply_subst(&σ).clone();
}
added = true;
} else {

62
src/substitution.rs Normal file
View file

@ -0,0 +1,62 @@
use crate::{
TypeID,
TypeTerm
};
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
pub trait Substitution {
fn get(&self, t: &TypeID) -> Option< TypeTerm >;
}
impl<S: Fn(&TypeID)->Option<TypeTerm>> Substitution for S {
fn get(&self, t: &TypeID) -> Option< TypeTerm > {
(self)(t)
}
}
impl Substitution for std::collections::HashMap< TypeID, TypeTerm > {
fn get(&self, t: &TypeID) -> Option< TypeTerm > {
(self as &std::collections::HashMap< TypeID, TypeTerm >).get(t).cloned()
}
}
impl TypeTerm {
/// recursively apply substitution to all subterms,
/// which will replace all occurences of variables which map
/// some type-term in `subst`
pub fn apply_substitution(
&mut self,
σ: &impl Substitution
) -> &mut Self {
self.apply_subst(σ)
}
pub fn apply_subst(
&mut self,
σ: &impl Substitution
) -> &mut Self {
match self {
TypeTerm::TypeID(typid) => {
if let Some(t) = σ.get(typid) {
*self = t;
}
}
TypeTerm::Ladder(rungs) => {
for r in rungs.iter_mut() {
r.apply_subst(σ);
}
}
TypeTerm::App(args) => {
for r in args.iter_mut() {
r.apply_subst(σ);
}
}
_ => {}
}
self
}
}

View file

@ -92,35 +92,6 @@ impl TypeTerm {
}
}
/// recursively apply substitution to all subterms,
/// which will replace all occurences of variables which map
/// some type-term in `subst`
pub fn apply_substitution(
&mut self,
subst: &impl Fn(&TypeID) -> Option<TypeTerm>
) -> &mut Self {
match self {
TypeTerm::TypeID(typid) => {
if let Some(t) = subst(typid) {
*self = t;
}
}
TypeTerm::Ladder(rungs) => {
for r in rungs.iter_mut() {
r.apply_substitution(subst);
}
}
TypeTerm::App(args) => {
for r in args.iter_mut() {
r.apply_substitution(subst);
}
}
_ => {}
}
self
}
/* strip away empty ladders
* & unwrap singletons

View file

@ -1,7 +1,7 @@
use {
crate::{dict::*, term::*, parser::*, unparser::*},
std::iter::FromIterator
crate::{dict::*, term::*, parser::*, unparser::*, substitution::*},
std::iter::FromIterator,
};
//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
@ -24,8 +24,7 @@ fn test_subst() {
assert_eq!(
dict.parse("<Seq T~U>").unwrap()
.apply_substitution(&|typid|{ σ.get(typid).cloned() }).clone(),
dict.parse("<Seq T~U>").unwrap().apply_subst(&σ).clone(),
dict.parse("<Seq ~<Seq Char>>").unwrap()
);
}

View file

@ -23,8 +23,8 @@ fn test_unify(ts1: &str, ts2: &str, expect_unificator: bool) {
let σ = σ.unwrap();
assert_eq!(
t1.apply_substitution(&|v| σ.get(v).cloned()),
t2.apply_substitution(&|v| σ.get(v).cloned())
t1.apply_subst(&σ),
t2.apply_subst(&σ)
);
} else {
assert!(! σ.is_ok());

View file

@ -79,7 +79,7 @@ impl UnificationProblem {
let mut new_σ = HashMap::new();
for (v, tt) in self.σ.iter() {
let mut tt = tt.clone().normalize();
tt.apply_substitution(&|v| self.σ.get(v).cloned());
tt.apply_subst(&self.σ);
tt = tt.normalize();
//eprintln!("update σ : {:?} --> {:?}", v, tt);
new_σ.insert(v.clone(), tt);
@ -415,7 +415,7 @@ impl UnificationProblem {
Ok(halo) => {
if halo == TypeTerm::unit() {
let mut y = y.clone();
y.apply_substitution(&|k| self.σ.get(k).cloned());
y.apply_subst(&self.σ);
y = y.strip();
let mut top = y.get_lnf_vec().first().unwrap().clone();
halo_args.push(top.clone());
@ -426,7 +426,7 @@ impl UnificationProblem {
let x = &mut halo_args[n_halos_required-1];
if let TypeTerm::Ladder(argrs) = x {
let mut a = a2[n_halos_required-1].clone();
a.apply_substitution(&|k| self.σ.get(k).cloned());
a.apply_subst(&self.σ);
a = a.get_lnf_vec().first().unwrap().clone();
argrs.push(a);
} else {
@ -435,7 +435,7 @@ impl UnificationProblem {
a2[n_halos_required-1].clone().get_lnf_vec().first().unwrap().clone()
]);
x.apply_substitution(&|k| self.σ.get(k).cloned());
x.apply_subst(&self.σ);
}
}
@ -465,8 +465,8 @@ impl UnificationProblem {
pub fn solve(mut self) -> Result<(Vec<TypeTerm>, HashMap<TypeID, TypeTerm>), UnificationError> {
// solve equations
while let Some( mut equal_pair ) = self.equal_pairs.pop() {
equal_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned());
equal_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned());
equal_pair.lhs.apply_subst(&self.σ);
equal_pair.rhs.apply_subst(&self.σ);
self.eval_equation(equal_pair)?;
}
@ -474,8 +474,8 @@ impl UnificationProblem {
// solve subtypes
// eprintln!("------ SOLVE SUBTYPES ---- ");
for mut subtype_pair in self.subtype_pairs.clone().into_iter() {
subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned());
subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned());
subtype_pair.lhs.apply_subst(&self.σ);
subtype_pair.rhs.apply_subst(&self.σ);
let _halo = self.eval_subtype( subtype_pair.clone() )?.strip();
}
@ -495,8 +495,8 @@ impl UnificationProblem {
// eprintln!("------ MAKE HALOS -----");
let mut halo_types = Vec::new();
for mut subtype_pair in self.subtype_pairs.clone().into_iter() {
subtype_pair.lhs = subtype_pair.lhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip();
subtype_pair.rhs = subtype_pair.rhs.apply_substitution(&|v| self.σ.get(v).cloned()).clone().strip();
subtype_pair.lhs = subtype_pair.lhs.apply_subst(&self.σ).clone().strip();
subtype_pair.rhs = subtype_pair.rhs.apply_subst(&self.σ).clone().strip();
let halo = self.eval_subtype( subtype_pair.clone() )?.strip();
halo_types.push(halo);