use {
    crate::{
        expr::{LTExpr, Statement, TypeError, TypeTag},
        lexer::{LTIRLexer, LTIRToken, LexError, InputRegionTag},
    },
    std::{
        iter::Peekable,
        sync::{Arc, RwLock},
    },
};

/// Errors reported by the parser functions in this module.
///
/// Every parser entry point returns a `ParseError` together with the
/// `InputRegionTag` at which it was detected (`InputRegionTag::default()`
/// is used where no token is available, e.g. at end of input).
#[derive(Clone, Debug, PartialEq)]
pub enum ParseError {
    /// The lexer failed to produce a token.
    LexError(LexError),
    /// Not constructed by any parser function in this module
    /// (presumably reserved for an unexpected closing delimiter).
    UnexpectedClose,
    /// The token stream ended while more input was required.
    UnexpectedEnd,
    /// A token appeared in a position where it is not allowed.
    UnexpectedToken,
    /// The text of a type annotation was rejected by the `laddertypes` parser.
    TypeParseError(laddertypes::parser::ParseError),
}

/* consume the next token and require it to be equal to `expected_token`.
 * The token is consumed whether or not it matches.
 */
pub fn parse_expect<It>(
    tokens: &mut Peekable<It>,
    expected_token: LTIRToken,
) -> Result<(), (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    match tokens.next() {
        None => Err((InputRegionTag::default(), ParseError::UnexpectedEnd)),
        Some((region, Err(err))) => Err((region, ParseError::LexError(err))),
        Some((_, Ok(found))) if found == expected_token => Ok(()),
        Some((region, Ok(_))) => Err((region, ParseError::UnexpectedToken)),
    }
}

/* parse symbol name:
 * consume the next token and return its name if it is a `Symbol`.
 */
pub fn parse_symbol<It>(tokens: &mut Peekable<It>) -> Result<String, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    match tokens.next() {
        None => Err((InputRegionTag::default(), ParseError::UnexpectedEnd)),
        Some((region, lexed)) => match lexed {
            Ok(LTIRToken::Symbol(name)) => Ok(name),
            Ok(_) => Err((region, ParseError::UnexpectedToken)),
            Err(err) => Err((region, ParseError::LexError(err))),
        },
    }
}

/* parse an optional type annotation
 *  `: T`
 *
 * The annotation arrives as a single `AssignType` token carrying the
 * type-term source text, which is parsed with the shared type dictionary.
 * If the next token is anything else (including a lexer error or end of
 * input) nothing is consumed and `Ok(None)` is returned.
 */
pub fn parse_type_tag<It>(
    typectx: &Arc<RwLock<laddertypes::dict::TypeDict>>,
    tokens: &mut Peekable<It>,
) -> Result<Option<laddertypes::TypeTerm>, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    // Consume the token only if it is an annotation; this takes ownership of
    // its text instead of cloning every peeked token just to inspect it.
    match tokens.next_if(|(_, tok)| matches!(tok, Ok(LTIRToken::AssignType(_)))) {
        Some((region, Ok(LTIRToken::AssignType(typeterm_str)))) => {
            // bind the result first so the write lock is released right away
            let parsed = typectx.write().unwrap().parse(typeterm_str.as_str());
            parsed
                .map(Some)
                .map_err(|parse_error| (region, ParseError::TypeParseError(parse_error)))
        }
        _ => Ok(None),
    }
}

/// Left-hand side of a variable binding: a single symbol or a braced group
/// of nested bindings, each optionally carrying a type annotation.
#[derive(Debug, PartialEq, Eq)]
pub enum VariableBinding {
    /// A single symbol, e.g. `x` or `x:T`.
    Atomic {
        symbol: String,
        typtag: Option<laddertypes::TypeTerm>,
    },
    /// A braced group of bindings, e.g. `{x:U; y}:T`.
    Struct {
        members: Vec<VariableBinding>,
        typtag: Option<laddertypes::TypeTerm>,
    },
}

impl VariableBinding {
    /// Flatten a (possibly nested) binding into its atomic `(symbol, type)`
    /// pairs, in left-to-right order.
    ///
    /// Note: the type annotation attached to a `Struct` binding itself is
    /// discarded; only the annotations of the atomic members are returned.
    pub fn flatten(self) -> Vec<(String, Option<laddertypes::TypeTerm>)> {
        match self {
            VariableBinding::Atomic { symbol, typtag } => vec![(symbol, typtag)],
            // NOTE(review): the struct-level `typtag` is dropped here
            VariableBinding::Struct { members, typtag: _ } => members
                .into_iter()
                .flat_map(VariableBinding::flatten)
                .collect(),
        }
    }
}

/* parse a symbol binding of the form
 *     `x`
 * or  `x : T`
 * or a braced group of bindings (see `parse_binding_block`), optionally
 * followed by a type annotation, e.g. `{x:U; y}:T`
 */
pub fn parse_binding_expr<It>(
    typectx: &Arc<RwLock<laddertypes::dict::TypeDict>>,
    tokens: &mut Peekable<It>,
) -> Result< VariableBinding, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    // Dispatch on the next token without consuming it; the selected
    // sub-parser consumes it. Error paths leave the token in the stream.
    match tokens.peek() {
        Some((_, Ok(LTIRToken::BlockOpen))) => Ok(VariableBinding::Struct {
            members: parse_binding_block(typectx, tokens)?,
            typtag: parse_type_tag(typectx, tokens)?,
        }),
        Some((_, Ok(LTIRToken::Symbol(_)))) => Ok(VariableBinding::Atomic {
            symbol: parse_symbol(tokens)?,
            typtag: parse_type_tag(typectx, tokens)?,
        }),
        Some((region, Ok(_))) => Err((*region, ParseError::UnexpectedToken)),
        Some((region, Err(err))) => Err((*region, ParseError::LexError(err.clone()))),
        None => Err((InputRegionTag::default(), ParseError::UnexpectedEnd)),
    }
}

/* parse a block of symbol bindings
 * `{ x:T; y:U; ... }`
 *
 * `StatementSep` tokens between members are skipped, so `{ x y }` is
 * accepted too. If the input ends before `BlockClose`, the error carries the
 * region of the last token inspected at the top of the loop.
 */
pub fn parse_binding_block<It>(
    typectx: &Arc<RwLock<laddertypes::dict::TypeDict>>,
    tokens: &mut Peekable<It>,
) -> Result< Vec<VariableBinding>, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    parse_expect(tokens, LTIRToken::BlockOpen)?;

    let mut members = Vec::new();
    let mut last_seen = InputRegionTag::default();

    loop {
        let (region, peeked) = match tokens.peek() {
            Some(entry) => entry,
            None => return Err((last_seen, ParseError::UnexpectedEnd)),
        };
        last_seen = *region;

        match peeked {
            Ok(LTIRToken::BlockClose) => {
                tokens.next();
                return Ok(members);
            }
            Ok(LTIRToken::StatementSep) => {
                tokens.next();
            }
            Ok(_) => members.push(parse_binding_expr(typectx, tokens)?),
            Err(err) => return Err((last_seen, ParseError::LexError(err.clone()))),
        }
    }
}

/* parse a single statement:
 *   `!` <symbol> <expr> StatementSep                        -> Statement::Assignment
 *   `let` <symbol> [<type tag>] AssignValue <expr> StatementSep -> Statement::LetAssign
 *   `while` ExprOpen <expr> ExprClose <statement block>     -> Statement::WhileLoop
 *   `return` <expr> StatementSep                            -> Statement::Return
 *   <expr> StatementSep                                     -> Statement::Expr
 */
pub fn parse_statement<It>(
    typectx: &Arc<RwLock<laddertypes::dict::TypeDict>>,
    tokens: &mut Peekable<It>,
) -> Result<crate::expr::Statement, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    if let Some((region, peektok)) = tokens.peek() {
        match peektok {
            Ok(LTIRToken::Symbol(sym)) => {
                match sym.as_str() {
                    "!" => {
                        tokens.next();
                        // todo accept address-expression instead of symbol
                        let name = parse_symbol(tokens)?;
                        let val_expr = parse_expr(typectx, tokens)?;
                        parse_expect(tokens, LTIRToken::StatementSep)?;

                        Ok(Statement::Assignment {
                            var_id: name,
                            val_expr,
                        })
                    }
                    "let" => {
                        tokens.next();
                        // TODO: accept a full binding pattern via `parse_binding_expr`
                        let name = parse_symbol(tokens)?;
                        let typ = parse_type_tag(typectx, tokens)?;
                        // the value assignment is mandatory; a missing one is an error
                        parse_expect(tokens, LTIRToken::AssignValue)?;
                        let val_expr = parse_expr(typectx, tokens)?;
                        parse_expect(tokens, LTIRToken::StatementSep)?;

                        Ok(Statement::LetAssign {
                            typ: typ.map(Ok),
                            var_id: name,
                            val_expr,
                        })
                    }
                    "while" => {
                        tokens.next();
                        parse_expect(tokens, LTIRToken::ExprOpen)?;
                        let cond = parse_expr(typectx, tokens)?;
                        parse_expect(tokens, LTIRToken::ExprClose)?;
                        Ok(Statement::WhileLoop {
                            condition: cond,
                            body: parse_statement_block(typectx, tokens)?,
                        })
                    }
                    "return" => {
                        tokens.next();
                        let expr = parse_expr(typectx, tokens)?;
                        parse_expect(tokens, LTIRToken::StatementSep)?;
                        // return the expression that was just parsed; do not
                        // parse (and swallow) the following tokens again
                        Ok(Statement::Return(expr))
                    }
                    _ => parse_expr_statement(typectx, tokens),
                }
            }
            Ok(_) => parse_expr_statement(typectx, tokens),
            Err(err) => Err((*region, ParseError::LexError(err.clone()))),
        }
    } else {
        Err((InputRegionTag::default(), ParseError::UnexpectedEnd))
    }
}

/* parse an expression statement: `<expr> StatementSep` */
fn parse_expr_statement<It>(
    typectx: &Arc<RwLock<laddertypes::dict::TypeDict>>,
    tokens: &mut Peekable<It>,
) -> Result<crate::expr::Statement, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    let expr = parse_expr(typectx, tokens)?;
    parse_expect(tokens, LTIRToken::StatementSep)?;
    Ok(Statement::Expr(expr))
}

/* parse a block of statements
 *   BlockOpen <statement>* BlockClose
 * Ending the input before `BlockClose` yields `UnexpectedEnd`.
 */
pub fn parse_statement_block<It>(
    typectx: &Arc<RwLock<laddertypes::dict::TypeDict>>,
    tokens: &mut Peekable<It>,
) -> Result<Vec<Statement>, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    parse_expect(tokens, LTIRToken::BlockOpen)?;

    let mut body = Vec::new();
    loop {
        match tokens.peek() {
            None => return Err((InputRegionTag::default(), ParseError::UnexpectedEnd)),
            Some((region, Err(err))) => {
                return Err((*region, ParseError::LexError(err.clone())));
            }
            Some((_, Ok(LTIRToken::BlockClose))) => {
                tokens.next();
                return Ok(body);
            }
            Some((_, Ok(_))) => body.push(parse_statement(typectx, tokens)?),
        }
    }
}

/* parse a single atom: a symbol, a character literal or a number literal.
 * Characters and numbers both become `LTExpr::lit_uint` via an `as u64` cast.
 * NOTE(review): if `Num` carries a signed value, negative numbers wrap here.
 */
pub fn parse_atom<It>(
    tokens: &mut Peekable<It>,
) -> Result<crate::expr::LTExpr, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    let (region, lexed) = tokens
        .next()
        .ok_or_else(|| (InputRegionTag::default(), ParseError::UnexpectedEnd))?;
    let token = lexed.map_err(|err| (region, ParseError::LexError(err)))?;

    match token {
        LTIRToken::Symbol(sym) => Ok(LTExpr::symbol(sym.as_str())),
        LTIRToken::Char(c) => Ok(LTExpr::lit_uint(c as u64)),
        LTIRToken::Num(n) => Ok(LTExpr::lit_uint(n as u64)),
        _ => Err((region, ParseError::UnexpectedToken)),
    }
}

/* parse an expression.
 *
 * Consecutive sub-expressions are collected and returned as an
 * `LTExpr::Application` whose head is the first one and whose body holds the
 * rest (a lone atom therefore becomes an application with an empty body).
 * Parsing stops, without consuming the token, at `ExprClose`, `BlockClose`,
 * `StatementSep`, or end of input.
 *
 * Special forms:
 *   Lambda <binding> MapsTo <expr>   abstraction; only valid as the first element
 *   ExprOpen <expr>* ExprClose       parenthesized sub-expression(s)
 *   BlockOpen ... BlockClose         statement block
 *   `if` ExprOpen <expr> ExprClose <block> [`else` <expr>]
 */
pub fn parse_expr<It>(
    typectx: &Arc<RwLock<laddertypes::dict::TypeDict>>,
    tokens: &mut Peekable<It>,
) -> Result<crate::expr::LTExpr, (InputRegionTag, ParseError)>
where It: Iterator<Item = (InputRegionTag, Result<LTIRToken, LexError>)>
{
    let mut children = Vec::new();

    while let Some((region, tok)) = tokens.peek() {
        match tok {
            Ok(LTIRToken::Lambda) => {
                // an abstraction must start the expression
                if !children.is_empty() {
                    return Err((*region, ParseError::UnexpectedToken));
                }
                tokens.next();

                let variable_bindings = parse_binding_expr(typectx, tokens)?;
                parse_expect(tokens, LTIRToken::MapsTo)?;
                let body = parse_expr(typectx, tokens)?;

                return Ok(LTExpr::Abstraction {
                    args: variable_bindings
                        .flatten()
                        .into_iter()
                        .map(|(s, t)| (s, t.map(Ok)))
                        .collect(),
                    body: Box::new(body),
                });
            }
            Ok(LTIRToken::ExprOpen) => {
                let open_region = *region;
                tokens.next();
                loop {
                    match tokens.peek() {
                        Some((_, Ok(LTIRToken::ExprClose))) => {
                            tokens.next();
                            break;
                        }
                        Some(_) => children.push(parse_expr(typectx, tokens)?),
                        // input ended before the matching ExprClose
                        None => return Err((open_region, ParseError::UnexpectedEnd)),
                    }
                }
            }
            Ok(LTIRToken::ExprClose)
            | Ok(LTIRToken::BlockClose)
            | Ok(LTIRToken::StatementSep) => {
                break;
            }
            Ok(LTIRToken::BlockOpen) => {
                children.push(LTExpr::block(parse_statement_block(typectx, tokens)?));
            }
            Ok(LTIRToken::Symbol(name)) if name == "if" => {
                tokens.next();
                parse_expect(tokens, LTIRToken::ExprOpen)?;
                let cond = parse_expr(typectx, tokens)?;
                parse_expect(tokens, LTIRToken::ExprClose)?;
                let if_expr = LTExpr::block(parse_statement_block(typectx, tokens)?);

                // without an `else`, the else-branch defaults to an empty block
                let has_else = matches!(
                    tokens.peek(),
                    Some((_, Ok(LTIRToken::Symbol(kw)))) if kw == "else"
                );
                let else_expr = if has_else {
                    tokens.next();
                    parse_expr(typectx, tokens)?
                } else {
                    LTExpr::block(vec![])
                };

                children.push(LTExpr::Branch {
                    condition: Box::new(cond),
                    if_expr: Box::new(if_expr),
                    else_expr: Box::new(else_expr),
                });
            }
            Ok(_) => {
                children.push(parse_atom(tokens)?);
            }
            Err(err) => {
                return Err((*region, ParseError::LexError(err.clone())));
            }
        }
    }

    if children.is_empty() {
        Err((InputRegionTag::default(), ParseError::UnexpectedEnd))
    } else {
        let head = children.remove(0);
        Ok(LTExpr::Application {
            typ: None,
            head: Box::new(head),
            body: children,
        })
    }
}



/* unit tests for the binding and symbol parsers; compiled only for `cargo test` */
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    #[test]
    fn test_parse_atomic_binding() {
        let mut lexer = crate::lexer::LTIRLexer::from("x".chars()).peekable();
        let typectx = Arc::new(RwLock::new(laddertypes::dict::TypeDict::new()));
        let bindings = crate::parser::parse_binding_expr( &typectx, &mut lexer );

        assert_eq!(
            bindings,
            Ok(crate::parser::VariableBinding::Atomic{
                symbol: "x".into(),
                typtag: None
            })
        );
    }

    #[test]
    fn test_parse_typed_atomic_binding() {
        let mut lexer = crate::lexer::LTIRLexer::from("x:T".chars()).peekable();
        let typectx = Arc::new(RwLock::new(laddertypes::dict::TypeDict::new()));
        let bindings = crate::parser::parse_binding_expr( &typectx, &mut lexer );

        assert_eq!(
            bindings,
            Ok(crate::parser::VariableBinding::Atomic{
                symbol: "x".into(),
                typtag: Some(typectx.write().unwrap().parse("T").unwrap())
            })
        );
    }

    #[test]
    fn test_parse_struct_binding() {
        let mut lexer = crate::lexer::LTIRLexer::from("{x y}".chars()).peekable();
        let typectx = Arc::new(RwLock::new(laddertypes::dict::TypeDict::new()));
        let bindings = crate::parser::parse_binding_expr( &typectx, &mut lexer );

        assert_eq!(
            bindings,
            Ok(crate::parser::VariableBinding::Struct{
                members: vec![
                    crate::parser::VariableBinding::Atomic{ symbol: "x".into(), typtag: None },
                    crate::parser::VariableBinding::Atomic{ symbol: "y".into(), typtag: None }
                ],
                typtag: None
            })
        );
    }

    #[test]
    fn test_parse_typed_struct_binding1() {
        let mut lexer = crate::lexer::LTIRLexer::from("{x y}:T".chars()).peekable();
        let typectx = Arc::new(RwLock::new(laddertypes::dict::TypeDict::new()));
        let bindings = crate::parser::parse_binding_expr( &typectx, &mut lexer );

        assert_eq!(
            bindings,
            Ok(crate::parser::VariableBinding::Struct{
                members: vec![
                    crate::parser::VariableBinding::Atomic{ symbol: "x".into(), typtag: None },
                    crate::parser::VariableBinding::Atomic{ symbol: "y".into(), typtag: None }
                ],
                typtag: Some(typectx.write().unwrap().parse("T").unwrap())
            })
        );
    }

    #[test]
    fn test_parse_typed_struct_binding2() {
        let mut lexer = crate::lexer::LTIRLexer::from("{x:U; y}:T".chars()).peekable();
        let typectx = Arc::new(RwLock::new(laddertypes::dict::TypeDict::new()));
        let bindings = crate::parser::parse_binding_expr( &typectx, &mut lexer );

        let type_u = typectx.write().unwrap().parse("U").unwrap();
        let type_t = typectx.write().unwrap().parse("T").unwrap();

        assert_eq!(
            bindings,
            Ok(crate::parser::VariableBinding::Struct{
                members: vec![
                    crate::parser::VariableBinding::Atomic{
                        symbol: "x".into(),
                        typtag: Some(type_u)
                    },
                    crate::parser::VariableBinding::Atomic{ symbol: "y".into(), typtag: None }
                ],
                typtag: Some(type_t)
            })
        );
    }

    #[test]
    fn test_parse_symbol_rejects_non_symbol() {
        let mut lexer = crate::lexer::LTIRLexer::from("{".chars()).peekable();

        assert!(matches!(
            crate::parser::parse_symbol(&mut lexer),
            Err((_, crate::parser::ParseError::UnexpectedToken))
        ));
    }

    #[test]
    fn test_parse_binding_at_end_of_input() {
        let mut lexer = crate::lexer::LTIRLexer::from("".chars()).peekable();
        let typectx = Arc::new(RwLock::new(laddertypes::dict::TypeDict::new()));

        assert!(matches!(
            crate::parser::parse_binding_expr(&typectx, &mut lexer),
            Err((_, crate::parser::ParseError::UnexpectedEnd))
        ));
    }
}