From 5d7668573a17afa4e0e3f80b47389e6291e922a1 Mon Sep 17 00:00:00 2001
From: Michael Sippel <micha@fragmental.art>
Date: Mon, 12 Aug 2024 21:18:17 +0200
Subject: [PATCH] initial implementation of solver for steiner trees

---
 src/lib.rs           |   1 +
 src/morphism.rs      |   1 -
 src/steiner_tree.rs  | 162 +++++++++++++++++++++++++++++++++++++++++++
 src/test/morphism.rs |  42 +++++++++--
 4 files changed, 201 insertions(+), 5 deletions(-)
 create mode 100644 src/steiner_tree.rs

diff --git a/src/lib.rs b/src/lib.rs
index bf775b4..5cdff81 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -11,6 +11,7 @@ pub mod pnf;
 pub mod subtype;
 pub mod unification;
 pub mod morphism;
+pub mod steiner_tree;
 
 #[cfg(test)]
 mod test;
diff --git a/src/morphism.rs b/src/morphism.rs
index 33eafd5..9160610 100644
--- a/src/morphism.rs
+++ b/src/morphism.rs
@@ -188,7 +188,6 @@ impl<M: Morphism + Clone> MorphismBase<M> {
         None
     }
 
-
     pub fn find_morphism(&self, ty: &MorphismType)
     -> Option< ( M, HashMap<TypeID, TypeTerm> ) > {
 
diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs
new file mode 100644
index 0000000..f5338e9
--- /dev/null
+++ b/src/steiner_tree.rs
@@ -0,0 +1,162 @@
+use {
+    std::collections::HashMap,
+    crate::{
+        TypeID,
+        TypeTerm,
+        morphism::{
+            MorphismType,
+            Morphism,
+            MorphismBase
+        }
+    }
+};
+
+//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
+
+#[derive(Clone)] // candidate trees are cloned whenever the search branches
+pub struct SteinerTree { // a partial tree made of morphism edges
+    weight: u64, // incremented by one for every call of `add_edge`
+    goals: Vec< TypeTerm >, // target types that are not yet reached
+    pub edges: Vec< MorphismType >, // edges added to this tree so far
+}
+
+impl SteinerTree {
+    /// Adds the edge `ty` (normalized) to this tree and increments `weight`.
+    ///
+    /// If the destination of `ty` unifies with a remaining goal, the first
+    /// such goal is removed and the unifier is applied to every edge.
+    ///
+    /// At most one goal is consumed per call, even if several unify.
+    fn add_edge(&mut self, ty: MorphismType) {
+        self.weight += 1;
+
+        let ty = ty.normalize();
+
+        // find the first goal that is reached through the new edge,
+        // together with the substitution that unifies both types
+        let reached = self.goals.iter()
+            .enumerate()
+            .find_map(|(idx, goal)| {
+                crate::unify(&ty.dst_type, goal).ok().map(|σ| (idx, σ))
+            });
+
+        // the edge always becomes part of the tree; it is pushed before
+        // the substitution is applied so that it gets specialized as well
+        self.edges.push(ty);
+
+        if let Some((idx, σ)) = reached {
+            // goal reached: drop it, the other goals keep their order
+            self.goals.remove(idx);
+
+            for e in self.edges.iter_mut() {
+                e.src_type = e.src_type.apply_substitution(&|x| σ.get(x).cloned()).clone();
+                e.dst_type = e.dst_type.apply_substitution(&|x| σ.get(x).cloned()).clone();
+            }
+        }
+    }
+
+    /// Returns `true` once every goal has been reached,
+    /// i.e. when the list of remaining goals is empty.
+    fn is_solved(&self) -> bool {
+        self.goals.is_empty()
+    }
+
+    /// Looks for an edge whose destination type unifies with `t`.
+    ///
+    /// Returns the unifier for the first such edge, or `None` if no
+    /// edge matches. Source types of edges are not considered.
+    fn contains(&self, t: &TypeTerm) -> Option< HashMap<TypeID, TypeTerm> > {
+        self.edges.iter()
+            .find_map(|e| crate::unify(&e.dst_type, t).ok())
+    }
+}
+
+/* Given a representation tree with the available
+ * representations `src_types`, try to find
+ * a sequence of morphisms that spans all
+ * representations in `dst_types`.
+ */
+pub struct SteinerTreeProblem {
+    src_types: Vec< TypeTerm >, // normalized types every candidate tree starts from
+    queue: Vec< SteinerTree > // candidate trees that still wait to be expanded
+}
+
+impl SteinerTreeProblem {
+    /// Creates a problem asking for a tree of morphisms that leads from
+    /// `src_types` to every type in `dst_types`; all types are normalized.
+    pub fn new(
+        src_types: Vec< TypeTerm >,
+        dst_types: Vec< TypeTerm >
+    ) -> Self {
+        let src_types: Vec< TypeTerm > = src_types.into_iter().map(|t| t.normalize()).collect();
+        // A goal that already is a source type needs no edge. It is dropped here,
+        // because `solve_bfs` never adds an edge that leads to a known node.
+        let goals = dst_types.into_iter()
+            .map(|t| t.normalize())
+            .filter(|g| !src_types.contains(g)).collect();
+
+        SteinerTreeProblem {
+            src_types,
+            queue: vec![ SteinerTree { weight: 0, goals, edges: Vec::new() } ]
+        }
+    }
+
+    /// Removes and returns the candidate tree with the highest priority
+    /// (fewest remaining goals first, then lowest weight), or `None` if
+    /// the queue is empty.
+    pub fn next(&mut self) -> Option< SteinerTree > {
+        eprintln!("queue size = {}", self.queue.len());
+
+        /* FIXME: by giving the highest priority to
+         * candidate tree with the least remaining goals,
+         * the optimality of the search algorithm
+         * is probably destroyed, but it dramatically helps
+         * to tame the combinatorial explosion in this algorithm.
+         */
+        // Sort descending so that the best candidate is last and can be popped.
+        // A key-based sort is a total order; ties compare as `Equal`.
+        self.queue.sort_by_key(|t| std::cmp::Reverse((t.goals.len(), t.weight)));
+        self.queue.pop()
+    }
+
+    /// Repeatedly takes the best candidate tree from the queue and extends
+    /// it by one edge in every possible way until a tree without remaining
+    /// goals is found.
+    ///
+    /// `morphisms` supplies the available edges; `dict` is currently unused.
+    /// Returns `None` when the queue runs empty without a solution.
+    pub fn solve_bfs<M: Morphism + Clone>(&mut self, dict: &crate::dict::TypeDict, morphisms: &MorphismBase<M>) -> Option< SteinerTree > {
+
+        // take the currently smallest tree and extend it by one step
+        while let Some( current_tree ) = self.next() {
+
+            // check if current tree is a solution
+            if current_tree.is_solved() {
+                return Some(current_tree);
+            }
+
+            // get all vertices spanned by this tree
+            let mut current_nodes = self.src_types.clone();
+            current_nodes.extend( current_tree.edges.iter().map(|e| e.dst_type.clone()) );
+
+            // extend the tree by one edge and add it to the queue
+            for src_type in current_nodes.iter() {
+                for (dst_halo,dst_ty) in morphisms.enum_morphisms_with_subtyping( &src_type ) {
+                    let dst_type = TypeTerm::Ladder(vec![
+                        dst_halo, dst_ty
+                    ]).normalize();
+
+                    if !current_nodes.contains( &dst_type ) {
+                        let mut new_tree = current_tree.clone();
+                        let src_type = src_type.clone();
+                        new_tree.add_edge(MorphismType { src_type, dst_type }.normalize());
+                        self.queue.push( new_tree );
+                    }
+                }
+            }
+        }
+
+        None
+    }
+}
+
diff --git a/src/test/morphism.rs b/src/test/morphism.rs
index 47bd100..b908101 100644
--- a/src/test/morphism.rs
+++ b/src/test/morphism.rs
@@ -1,5 +1,5 @@
 use {
-    crate::{dict::*, morphism::*, TypeTerm}
+    crate::{dict::*, morphism::*, steiner_tree::*, TypeTerm}
 };
 
 //<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
@@ -26,9 +26,8 @@ impl Morphism for DummyMorphism {
         }))
     }
 }
- 
-#[test]
-fn test_morphism_path() {
+
+fn morphism_test_setup() -> ( TypeDict, MorphismBase<DummyMorphism> ) {
     let mut dict = TypeDict::new();
     let mut base = MorphismBase::<DummyMorphism>::new( dict.add_typename("Seq".into()) );
 
@@ -67,6 +66,13 @@ fn test_morphism_path() {
         })
     );
 
+    (dict, base)
+}
+
+#[test]
+fn test_morphism_path() {
+    let (mut dict, mut base) = morphism_test_setup();
+
     assert_eq!(
         base.find_morphism_path(MorphismType {
             src_type: dict.parse("ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap(),
@@ -125,3 +131,31 @@ fn test_morphism_path() {
     );
 }
 
+#[test]
+fn test_steiner_tree() {
+    let (mut dict, base) = morphism_test_setup();
+
+    // NOTE: this test only prints the result and does not assert anything yet
+    let mut steiner_tree_problem = SteinerTreeProblem::new(
+        // source reprs
+        vec![
+            dict.parse("ℕ ~ <PosInt 10 BigEndian> ~ <Seq <Digit 10> ~ Char>").unwrap(),
+        ],
+
+        // destination reprs
+        vec![
+            dict.parse("ℕ ~ <PosInt 2 BigEndian> ~ <Seq <Digit 2> ~ Char>").unwrap(),
+            dict.parse("ℕ ~ <PosInt 10 LittleEndian> ~ <Seq <Digit 10> ~ Char>").unwrap(),
+            dict.parse("ℕ ~ <PosInt 16 LittleEndian> ~ <Seq <Digit 16> ~ Char>").unwrap()
+        ]
+    );
+
+    if let Some(solution) = steiner_tree_problem.solve_bfs( &dict, &base ) {
+        for e in solution.edges.iter() {
+            eprintln!(" :: {}\n--> {}", dict.unparse(&e.src_type), dict.unparse(&e.dst_type));
+        }
+    } else {
+        eprintln!("no solution");
+    }
+}
+