From 02d8815acd9c29934cd2c95180cd5da7214d9595 Mon Sep 17 00:00:00 2001
From: Michael Sippel <micha@fragmental.art>
Date: Wed, 1 May 2024 15:10:29 +0200
Subject: [PATCH] add param_normalize() to get Parameter-Normal-Form (PNF)

---
 src/lib.rs      |   1 +
 src/pnf.rs      | 113 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/test/mod.rs |   1 +
 src/test/pnf.rs |  41 ++++++++++++++++++
 4 files changed, 156 insertions(+)
 create mode 100644 src/pnf.rs
 create mode 100644 src/test/pnf.rs

diff --git a/src/lib.rs b/src/lib.rs
index 1a270cc..970c555 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -7,6 +7,7 @@ pub mod parser;
 pub mod unparser;
 pub mod curry;
 pub mod lnf;
+pub mod pnf;
 pub mod subtype;
 pub mod unification;
 
diff --git a/src/pnf.rs b/src/pnf.rs
new file mode 100644
index 0000000..4576be5
--- /dev/null
+++ b/src/pnf.rs
@@ -0,0 +1,113 @@
+use crate::term::TypeTerm;
+
+//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
+
+impl TypeTerm {
+    /// transmute type into Parameter-Normal-Form (PNF)
+    ///
+    /// Example:
+    /// ```ignore
+    /// <Seq <Digit 10>>~<Seq Char>
+    /// ⇒ <Seq <Digit 10>~Char>
+    /// ```
+    pub fn param_normalize(self) -> Self {
+        match self {
+            TypeTerm::Ladder(mut rungs) => {
+                if rungs.len() > 0 {
+                    // normalize all rungs separately
+                    for r in rungs.iter_mut() {
+                        *r = r.clone().param_normalize();
+                    }
+
+                    // take top-rung
+                    match rungs.remove(0) {
+                        TypeTerm::App(params_top) => {
+                            let mut params_ladders = Vec::new();
+                            let mut tail : Vec<TypeTerm> = Vec::new();
+
+                            // append all other rungs to ladders inside
+                            // the application
+                            for p in params_top {
+                                params_ladders.push(vec![ p ]);
+                            }
+
+                            for r in rungs {
+                                match r {
+                                    TypeTerm::App(mut params_rung) => {
+                                        if params_rung.len() > 0 {
+                                            let mut first_param = params_rung.remove(0); 
+
+                                            if first_param == params_ladders[0][0] {
+                                               for (l, p) in params_ladders.iter_mut().skip(1).zip(params_rung) {
+                                                   l.push(p.param_normalize());
+                                               }
+                                            } else {
+                                               params_rung.insert(0, first_param);
+                                               tail.push(TypeTerm::App(params_rung));
+                                            }
+                                        }
+                                    }
+
+                                    TypeTerm::Ladder(mut rs) => {
+                                        for r in rs {
+                                            tail.push(r.param_normalize());
+                                        }
+                                    }
+
+                                    atomic => {
+                                        tail.push(atomic);
+                                    }
+                                }
+                            }
+
+                            let head = TypeTerm::App(
+                                params_ladders.into_iter()
+                                    .map(
+                                        |mut l| {
+                                            l.dedup();
+                                            match l.len() {
+                                                0 => TypeTerm::unit(),
+                                                1 => l.remove(0),
+                                                _ => TypeTerm::Ladder(l).param_normalize()
+                                            }
+                                        }
+                                    )
+                                    .collect()
+                            );
+
+                            if tail.len() > 0 {
+                                tail.insert(0, head);
+                                TypeTerm::Ladder(tail)
+                            } else {
+                                head
+                            }
+                        }
+
+                        TypeTerm::Ladder(mut r) => {
+                            r.append(&mut rungs);
+                            TypeTerm::Ladder(r)
+                        }
+
+                        atomic => {
+                            rungs.insert(0, atomic);
+                            TypeTerm::Ladder(rungs)
+                        }
+                    }
+                } else {
+                    TypeTerm::unit()
+                }
+            }
+
+            TypeTerm::App(params) => {
+                TypeTerm::App(
+                    params.into_iter()
+                        .map(|p| p.param_normalize())
+                        .collect())
+            }
+
+            atomic => atomic
+        }
+    }
+}
+
+//<<<<>>>><<>><><<>><<<*>>><<>><><<>><<<<>>>>\\
diff --git a/src/test/mod.rs b/src/test/mod.rs
index d116412..29c14bc 100644
--- a/src/test/mod.rs
+++ b/src/test/mod.rs
@@ -3,6 +3,7 @@ pub mod lexer;
 pub mod parser;
 pub mod curry;
 pub mod lnf;
+pub mod pnf;
 pub mod subtype;
 pub mod substitution;
 pub mod unification;
diff --git a/src/test/pnf.rs b/src/test/pnf.rs
new file mode 100644
index 0000000..2303b3e
--- /dev/null
+++ b/src/test/pnf.rs
@@ -0,0 +1,41 @@
+use crate::dict::TypeDict;
+
+#[test]
+fn test_param_normalize() {
+    let mut dict = TypeDict::new();
+
+    assert_eq!(
+        dict.parse("A~B~C").expect("parse error"),
+        dict.parse("A~B~C").expect("parse error").param_normalize(),
+    );
+
+    assert_eq!(
+        dict.parse("<A B>~C").expect("parse error"),
+        dict.parse("<A B>~C").expect("parse error").param_normalize(),
+    );
+
+    assert_eq!(
+        dict.parse("<A B~C>").expect("parse error"),
+        dict.parse("<A B>~<A C>").expect("parse error").param_normalize(),
+    );
+
+    assert_eq!(
+        dict.parse("<A B~C D~E>").expect("parse error"),
+        dict.parse("<A B D>~<A C D>~<A C E>").expect("parse errror").param_normalize(),
+    );
+
+    assert_eq!(
+        dict.parse("<Seq <Digit 10>~Char>").expect("parse error"),
+        dict.parse("<Seq <Digit 10>>~<Seq Char>").expect("parse errror").param_normalize(),
+    );
+
+    assert_eq!(
+        dict.parse("<A <B C~D~E> F~G H H>").expect("parse error"),
+        dict.parse("<A <B C> F H H>
+                   ~<A <B D> F H H>
+                   ~<A <B E> F H H>
+                   ~<A <B E> G H H>").expect("parse errror")
+               .param_normalize(),
+    );
+}
+