From db7e0173c40043dc17c0e2628ef6508aff423045 Mon Sep 17 00:00:00 2001
From: Michael Sippel <micha@fragmental.art>
Date: Fri, 27 Sep 2024 12:15:40 +0200
Subject: [PATCH] add steiner tree solver based on shortest path

---
 src/steiner_tree.rs | 93 ++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 87 insertions(+), 6 deletions(-)

diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs
index f5338e9..f854dd9 100644
--- a/src/steiner_tree.rs
+++ b/src/steiner_tree.rs
@@ -17,10 +17,14 @@ use {
 pub struct SteinerTree {
     weight: u64,
     goals: Vec< TypeTerm >,
-    pub edges: Vec< MorphismType >,
+    edges: Vec< MorphismType >,
 }
 
 impl SteinerTree {
+    pub fn into_edges(self) -> Vec< MorphismType > {
+        self.edges
+    }
+
     fn add_edge(&mut self, ty: MorphismType) {
         self.weight += 1;
 
@@ -71,6 +75,72 @@ impl SteinerTree {
     }
 }
 
+
+pub struct PathApproxSteinerTreeSolver {
+    root: TypeTerm,
+    leaves: Vec< TypeTerm >
+}
+
+impl PathApproxSteinerTreeSolver {
+    pub fn new(
+        root: TypeTerm,
+        leaves: Vec<TypeTerm>
+    ) -> Self {
+        PathApproxSteinerTreeSolver {
+            root, leaves
+        }
+    }
+
+    pub fn solve<M: Morphism + Clone>(self, morphisms: &MorphismBase<M>) -> Option< SteinerTree > {
+        let mut tree = Vec::<MorphismType>::new();
+
+        for goal in self.leaves {
+            // try to find shortest path from root to current leaf
+            if let Some(new_path) = morphisms.find_morphism_path(
+                MorphismType {
+                    src_type: self.root.clone(),
+                    dst_type: goal.clone()
+                }
+            ) {
+                // reduce new path so that it does not collide with any existing path
+                let mut src_type = self.root.clone();
+                let mut new_path_iter = new_path.into_iter().peekable();
+
+                // check all existing nodes..
+                for mt in tree.iter() {
+//                    assert!( mt.src_type == &src_type );
+                    if let Some(t) = new_path_iter.peek() {
+                        if &mt.dst_type == t {
+                            // eliminate this node from new path
+                            src_type = new_path_iter.next().unwrap().clone();
+                        }
+                    } else {
+                        break;
+                    }
+                }
+
+                for dst_type in new_path_iter {
+                    tree.push(MorphismType {
+                        src_type: src_type.clone(),
+                        dst_type: dst_type.clone()
+                    });
+                    src_type = dst_type;
+                }
+            } else {
+                eprintln!("could not find path\nfrom {:?}\nto {:?}", &self.root, &goal);
+                return None;
+            }
+        }
+
+        Some(SteinerTree {
+            weight: 0,
+            goals: vec![],
+            edges: tree
+        })
+    }
+}
+
+
 /* given a representation tree with the available
  * represenatations `src_types`, try to find
  * a sequence of morphisms that span up all
@@ -122,8 +192,15 @@ impl SteinerTreeProblem {
         );
         self.queue.pop()
     }
+/*
+    pub fn solve_approx_path<M: Morphism + Clone>(&mut self, morphisms: &MorphismBase<M>) -> Option< SteinerTree > {
+        if let Some(master) = self.src_types.first() {
 
-    pub fn solve_bfs<M: Morphism + Clone>(&mut self, dict: &crate::dict::TypeDict, morphisms: &MorphismBase<M>) -> Option< SteinerTree > {
+            
+        }
+    }
+*/
+    pub fn solve_bfs<M: Morphism + Clone>(&mut self, morphisms: &MorphismBase<M>) -> Option< SteinerTree > {
 
         // take the currently smallest tree and extend it by one step
         while let Some( mut current_tree ) = self.next() {
@@ -140,16 +217,20 @@ impl SteinerTreeProblem {
             }
 
             // extend the tree by one edge and add it to the queue
-            for src_type in current_nodes.iter() {
+            for src_type in current_nodes {
                 for (dst_halo,dst_ty) in morphisms.enum_morphisms_with_subtyping( &src_type ) {
                     let dst_type = TypeTerm::Ladder(vec![
                         dst_halo, dst_ty
                     ]).normalize();
 
-                    if !current_nodes.contains( &dst_type ) {
+                    if current_tree.contains( &dst_type ).is_none() {
                         let mut new_tree = current_tree.clone();
-                        let src_type = src_type.clone();
-                        new_tree.add_edge(MorphismType { src_type, dst_type }.normalize());
+                        {
+                            let src_type = src_type.clone();
+                            let dst_type = dst_type.clone();
+                            new_tree.add_edge(MorphismType { src_type, dst_type }.normalize());
+                        }
+
                         self.queue.push( new_tree );
                     }
                 }