From 5d7668573a17afa4e0e3f80b47389e6291e922a1 Mon Sep 17 00:00:00 2001 From: Michael Sippel Date: Mon, 12 Aug 2024 21:18:17 +0200 Subject: [PATCH] initial implementation of solver for steiner trees --- src/lib.rs | 1 + src/morphism.rs | 1 - src/steiner_tree.rs | 162 +++++++++++++++++++++++++++++++++++++++++++ src/test/morphism.rs | 42 +++++++++-- 4 files changed, 201 insertions(+), 5 deletions(-) create mode 100644 src/steiner_tree.rs diff --git a/src/lib.rs b/src/lib.rs index bf775b4..5cdff81 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ pub mod pnf; pub mod subtype; pub mod unification; pub mod morphism; +pub mod steiner_tree; #[cfg(test)] mod test; diff --git a/src/morphism.rs b/src/morphism.rs index 33eafd5..9160610 100644 --- a/src/morphism.rs +++ b/src/morphism.rs @@ -188,7 +188,6 @@ impl MorphismBase { None } - pub fn find_morphism(&self, ty: &MorphismType) -> Option< ( M, HashMap ) > { diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs new file mode 100644 index 0000000..f5338e9 --- /dev/null +++ b/src/steiner_tree.rs @@ -0,0 +1,162 @@ +use { + std::collections::HashMap, + crate::{ + TypeID, + TypeTerm, + morphism::{ + MorphismType, + Morphism, + MorphismBase + } + } +}; + +//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ + +#[derive(Clone)] +pub struct SteinerTree { + weight: u64, + goals: Vec< TypeTerm >, + pub edges: Vec< MorphismType >, +} + +impl SteinerTree { + fn add_edge(&mut self, ty: MorphismType) { + self.weight += 1; + + let ty = ty.normalize(); + + // check if by adding this new edge, we reach a goal + let mut new_goals = Vec::new(); + let mut added = false; + + for g in self.goals.clone() { + if let Ok(σ) = crate::unify(&ty.dst_type, &g) { + if !added { + self.edges.push(ty.clone()); + + // goal reached. + for e in self.edges.iter_mut() { + e.src_type = e.src_type.apply_substitution(&|x| σ.get(x).cloned()).clone(); + e.dst_type = e.dst_type.apply_substitution(&|x| σ.get(x).cloned()).clone(); + } + added = true; + } else { + new_goals.push(g); + } + } else { + new_goals.push(g); + } + } + + if !added { + self.edges.push(ty.clone()); + } + + self.goals = new_goals; + } + + fn is_solved(&self) -> bool { + self.goals.len() == 0 + } + + fn contains(&self, t: &TypeTerm) -> Option< HashMap > { + for e in self.edges.iter() { + if let Ok(σ) = crate::unify(&e.dst_type, t) { + return Some(σ) + } + } + + None + } +} + +/* given a representation tree with the available + * represenatations `src_types`, try to find + * a sequence of morphisms that span up all + * representations in `dst_types`. + */ +pub struct SteinerTreeProblem { + src_types: Vec< TypeTerm >, + queue: Vec< SteinerTree > +} + +impl SteinerTreeProblem { + pub fn new( + src_types: Vec< TypeTerm >, + dst_types: Vec< TypeTerm > + ) -> Self { + SteinerTreeProblem { + src_types: src_types.into_iter().map(|t| t.normalize()).collect(), + queue: vec![ + SteinerTree { + weight: 0, + goals: dst_types.into_iter().map(|t| t.normalize()).collect(), + edges: Vec::new() + } + ] + } + } + + pub fn next(&mut self) -> Option< SteinerTree > { + eprintln!("queue size = {}", self.queue.len()); + + /* FIXME: by giving the highest priority to + * candidate tree with the least remaining goals, + * the optimality of the search algorithm + * is probably destroyed, but it dramatically helps + * to tame the combinatorical explosion in this algorithm. + */ + self.queue.sort_by(|t1, t2| + if t1.goals.len() < t2.goals.len() { + std::cmp::Ordering::Greater + } else if t1.goals.len() == t2.goals.len() { + if t1.weight < t2.weight { + std::cmp::Ordering::Greater + } else { + std::cmp::Ordering::Less + } + } else { + std::cmp::Ordering::Less + } + ); + self.queue.pop() + } + + pub fn solve_bfs(&mut self, dict: &crate::dict::TypeDict, morphisms: &MorphismBase) -> Option< SteinerTree > { + + // take the currently smallest tree and extend it by one step + while let Some( mut current_tree ) = self.next() { + + // check if current tree is a solution + if current_tree.goals.len() == 0 { + return Some(current_tree); + } + + // get all vertices spanned by this tree + let mut current_nodes = self.src_types.clone(); + for e in current_tree.edges.iter() { + current_nodes.push( e.dst_type.clone() ); + } + + // extend the tree by one edge and add it to the queue + for src_type in current_nodes.iter() { + for (dst_halo,dst_ty) in morphisms.enum_morphisms_with_subtyping( &src_type ) { + let dst_type = TypeTerm::Ladder(vec![ + dst_halo, dst_ty + ]).normalize(); + + if !current_nodes.contains( &dst_type ) { + let mut new_tree = current_tree.clone(); + let src_type = src_type.clone(); + new_tree.add_edge(MorphismType { src_type, dst_type }.normalize()); + self.queue.push( new_tree ); + } + } + } + } + + None + } +} + diff --git a/src/test/morphism.rs b/src/test/morphism.rs index 47bd100..b908101 100644 --- a/src/test/morphism.rs +++ b/src/test/morphism.rs @@ -1,5 +1,5 @@ use { - crate::{dict::*, morphism::*, TypeTerm} + crate::{dict::*, morphism::*, steiner_tree::*, TypeTerm} }; //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\ @@ -26,9 +26,8 @@ impl Morphism for DummyMorphism { })) } } - -#[test] -fn test_morphism_path() { + +fn morphism_test_setup() -> ( TypeDict, MorphismBase ) { let mut dict = TypeDict::new(); let mut base = MorphismBase::::new( dict.add_typename("Seq".into()) ); @@ -67,6 +66,13 @@ fn test_morphism_path() { }) ); + (dict, base) +} + +#[test] +fn test_morphism_path() { + let (mut dict, mut base) = morphism_test_setup(); + assert_eq!( base.find_morphism_path(MorphismType { src_type: dict.parse("ℕ ~ ~ ~ Char>").unwrap(), @@ -125,3 +131,31 @@ fn test_morphism_path() { ); } +#[test] +fn test_steiner_tree() { + let (mut dict, mut base) = morphism_test_setup(); + + + let mut steiner_tree_problem = SteinerTreeProblem::new( + // source reprs + vec![ + dict.parse("ℕ ~ ~ ~ Char>").unwrap(), + ], + + // destination reprs + vec![ + dict.parse("ℕ ~ ~ ~ Char>").unwrap(), + dict.parse("ℕ ~ ~ ~ Char>").unwrap(), + dict.parse("ℕ ~ ~ ~ Char>").unwrap() + ] + ); + + if let Some(solution) = steiner_tree_problem.solve_bfs( &dict, &base ) { + for e in solution.edges.iter() { + eprintln!(" :: {}\n--> {}", dict.unparse(&e.src_type), dict.unparse(&e.dst_type)); + } + } else { + eprintln!("no solution"); + } +} +