/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nlpcraft.probe.mgrs.model

import com.typesafe.scalalogging.LazyLogging
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.tree._
import org.apache.nlpcraft.common._
import org.apache.nlpcraft.model.NCToken
import org.apache.nlpcraft.model.intent.utils.NCDslTokenPredicate
import org.apache.nlpcraft.probe.mgrs.model.antlr4.{NCSynonymDslBaseListener, NCSynonymDslLexer, NCSynonymDslParser}
//import org.apache.nlpcraft.probe.mgrs.model.antlr4._

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
 * Compiler for model synonym DSL.
 */
object NCModelSynonymDslCompiler extends LazyLogging {
    private type Predicate = java.util.function.Function[NCToken, java.lang.Boolean]
    
    /**
     * Adapts given Scala token predicate to the Java functional interface. As a side effect,
     * when the predicate holds for a token and a non-null alias is supplied, the alias is
     * recorded in the token metadata under `TOK_META_ALIASES_KEY`.
     *
     * @param alias Optional (nullable) alias of the predicate.
     * @param func Scala token predicate to wrap.
     * @return Java function delegating to `func`.
     */
    def toJavaFunc(alias: String, func: NCToken ⇒ Boolean): Predicate = (tok: NCToken) ⇒ {
        val matched = func(tok)

        // Aliases are only recorded for tokens that actually satisfy the predicate.
        // A single token may accumulate several aliases, hence the set.
        if (matched && alias != null) {
            val tokMeta = tok.getMetadata

            // Lazily create the alias set on first use.
            if (!tokMeta.containsKey(TOK_META_ALIASES_KEY))
                tokMeta.put(TOK_META_ALIASES_KEY, new java.util.HashSet[String]())

            tokMeta.get(TOK_META_ALIASES_KEY).asInstanceOf[java.util.Set[String]].add(alias)
        }

        matched
    }
    
    /**
     * Parse tree listener that compiles synonym DSL into a single token predicate.
     *
     * Listener callbacks collect the optional alias, lval qualifier parts and rvals. Each
     * recognized predicate is pushed onto an internal stack and `exitItem` folds the stack
     * entries together for negation, conjunction, disjunction and parenthesized items.
     */
    class FiniteStateMachine extends NCSynonymDslBaseListener {
        private val predStack = new mutable.ArrayStack[NCToken ⇒ Boolean] // Stack of predicates.
        private val lvalParts = ArrayBuffer.empty[String] // lval parts collector.
        private val rvalList = ArrayBuffer.empty[String] // rval list collector.
        private var alias: String = _
        private var rval: String = _

        /**
         * Gets compiled synonym DSL.
         *
         * NOTE: this pops the resulting predicate off the internal stack, so it is meant to be
         * called once after the parse tree has been walked.
         *
         * @return Compiled synonym DSL holding the alias (may be `null`) and the final predicate.
         */
        def getCompiledSynonymDsl: NCModelSynonymDsl =
            NCModelSynonymDsl(alias, toJavaFunc(alias, predStack.pop()))

        /** Remembers the text of the most recently seen single rval. */
        override def exitRvalSingle(ctx: NCSynonymDslParser.RvalSingleContext): Unit =
            rval = ctx.getText.trim()

        /** Appends the most recently seen single rval to the rval list collector. */
        override def exitRvalList(ctx: NCSynonymDslParser.RvalListContext): Unit =
            rvalList += rval

        /** Collects one qualifier part of the lval. */
        override def exitLvalPart(ctx: NCSynonymDslParser.LvalPartContext): Unit =
            lvalParts += ctx.ID().getText.trim()

        /** Stores the alias of the synonym. */
        override def exitAlias(ctx: NCSynonymDslParser.AliasContext): Unit =
            alias = ctx.ID().getText.trim()

        /**
         * Combines predicates on the stack according to the kind of the item being exited.
         * Items that are plain predicates need no action: their predicate is already on the
         * top of the stack.
         */
        override def exitItem(ctx: NCSynonymDslParser.ItemContext): Unit = {
            if (ctx.EXCL() != null) {
                val negated = predStack.pop

                predStack.push(new Function[NCToken, Boolean] {
                    override def apply(tok: NCToken): Boolean = !negated.apply(tok)
                    override def toString: String = s"!$negated"
                })
            }
            else if (ctx.AND() != null) {
                // Note that stack is LIFO so order is flipped.
                val right = predStack.pop
                val left = predStack.pop

                predStack.push(new Function[NCToken, Boolean] {
                    // '&&' short-circuits: right operand is only evaluated when left one holds.
                    override def apply(tok: NCToken): Boolean = left.apply(tok) && right.apply(tok)
                    override def toString: String = s"$left && $right"
                })
            }
            else if (ctx.OR() != null) {
                // Note that stack is LIFO so order is flipped.
                val right = predStack.pop
                val left = predStack.pop

                predStack.push(new Function[NCToken, Boolean] {
                    // '||' short-circuits: right operand is only evaluated when left one fails.
                    override def apply(tok: NCToken): Boolean = left.apply(tok) || right.apply(tok)
                    override def toString: String = s"$left || $right"
                })
            }
            else if (ctx.RPAREN() != null && ctx.LPAREN() != null) {
                val inner = predStack.pop

                // Parentheses do not change the logic, only the string representation.
                predStack.push(new Function[NCToken, Boolean] {
                    override def apply(tok: NCToken): Boolean = inner.apply(tok)
                    override def toString: String = s"($inner)"
                })
            }

            // In all other cases the current predicate is already on the top of the stack.
        }

        /**
         * Converts textual rval into the most specific value object it can be read as.
         *
         * @param rv Textual rval.
         * @return `null` for "null", boolean for "true"/"false", otherwise the first successful
         *      of `int`, `long` or `double` parsing (with '_' stripped), or `rv` itself
         *      when none of the numeric conversions applies.
         */
        private def mkRvalObject(rv: String): Any = rv match {
            case "null" ⇒ null
            case "true" ⇒ true
            case "false" ⇒ false
            // Only numeric values below...
            case _ ⇒
                // Strip '_' from numeric values.
                val num = rv.replaceAll("_", "")

                // Evaluates 'body' falling back to 'fallback' on number format error only.
                def orElse(body: ⇒ Any)(fallback: ⇒ Any): Any =
                    try
                        body
                    catch {
                        case _: NumberFormatException ⇒ fallback
                    }

                orElse(java.lang.Integer.parseInt(num)) { // Try 'int'.
                    orElse(java.lang.Long.parseLong(num)) { // Try 'long'.
                        orElse(java.lang.Double.parseDouble(num)) { // Try 'double'.
                            rv // String by default.
                        }
                    }
                }
        }

        /**
         * Builds token predicate from the just-parsed predicate node and the collected
         * lval parts and rval(s), pushes it onto the stack and resets the collectors.
         */
        override def exitPredicate(ctx: NCSynonymDslParser.PredicateContext): Unit = {
            // Single-child lval node holds the name itself; otherwise the name is its 2nd child
            // (presumably preceded by the qualifier - confirm against the grammar).
            def getLvalNode(tree: ParseTree): String =
                tree.getChild(if (tree.getChildCount == 1) 0 else 1).getText.trim

            // Three children are read as 'lval op rval'. Otherwise children 0, 2 and 4 are read
            // as function name, lval and operator (presumably 'func(lval) op rval').
            val (lvalFunc, lval, op) =
                if (ctx.children.size() == 3)
                    (null: String, getLvalNode(ctx.getChild(0)), ctx.getChild(1).getText.trim)
                else
                    (
                        ctx.getChild(0).getText.trim,
                        getLvalNode(ctx.getChild(2)),
                        ctx.getChild(4).getText.trim
                    )

            val pred = new NCDslTokenPredicate(
                // Independent snapshot: 'lvalParts.asJava' is only a live view over the buffer
                // that gets cleared below, which would empty the parts seen by the predicate
                // should it keep the passed-in list.
                new java.util.ArrayList[String](lvalParts.asJava),
                lvalFunc,
                lval,
                op,
                if (rvalList.isEmpty) mkRvalObject(rval) else rvalList.map(mkRvalObject).asJava
            )

            predStack.push(new Function[NCToken, Boolean] {
                override def apply(tok: NCToken): Boolean = pred.apply(tok)
                override def toString: String = pred.toString
            })

            // Reset.
            lvalParts.clear()
            rvalList.clear()
            rval = null
        }
    }
    
    /**
     * Custom error handler that logs the syntax error along with the offending expression
     * and aborts the compilation by throwing `NCE`.
     *
     * @param dsl Synonym DSL expression being compiled, used for error reporting.
     */
    class CompilerErrorListener(dsl: String) extends BaseErrorListener {
        /**
         * Makes a line of dashes with '^' marking the error position.
         *
         * @param len Length of the expression the pointer is drawn for.
         * @param pos 0-based character position of the error, as reported by ANTLR.
         * @return Pointer string with '^' at index `pos`. Negative positions are treated as 0
         *      and positions at or past `len` (e.g. error at the end of input) put '^'
         *      right after the dashes.
         */
        private def makeCharPosPointer(len: Int, pos: Int): String = {
            val caretIdx = math.max(0, pos)
            val before = "-" * math.min(caretIdx, len)
            val after = "-" * math.max(0, len - caretIdx - 1)

            before + "^" + after
        }

        /**
         * Logs the syntax error with a pointer to its position and throws `NCE`.
         *
         * @param recognizer Recognizer that detected the error (unused).
         * @param offendingSymbol Offending symbol (unused).
         * @param line Line number of the error.
         * @param charPos Character position of the error within the line.
         * @param msg Error message.
         * @param e Recognition exception, if any (unused).
         */
        override def syntaxError(
            recognizer: Recognizer[_, _],
            offendingSymbol: scala.Any,
            line: Int,
            charPos: Int,
            msg: String,
            e: RecognitionException): Unit = {

            val errMsg = s"Synonym DSL syntax error at line $line:$charPos - $msg"

            logger.error(errMsg)
            logger.error(s"  |-- ${c("Expression:")} $dsl")
            // Both prefixes have the same width so the pointer lines up with the expression.
            logger.error(s"  +-- ${c("Error:")}      ${makeCharPosPointer(dsl.length, charPos)}")

            throw new NCE(errMsg)
        }
    }

    /**
     * Compiles given synonym DSL expression.
     *
     * @param dsl Synonym DSL to parse. Must not be `null`.
     * @return Compiled synonym DSL.
     */
    def parse(dsl: String): NCModelSynonymDsl = {
        require(dsl != null)

        // Lexer with the default error reporting replaced by the custom one.
        val dslLexer = new NCSynonymDslLexer(CharStreams.fromString(dsl))

        dslLexer.removeErrorListeners()
        dslLexer.addErrorListener(new CompilerErrorListener(dsl))

        // Parser over the lexer's token stream, with the same error reporting policy.
        val dslParser = new NCSynonymDslParser(new CommonTokenStream(dslLexer))

        dslParser.removeErrorListeners()
        dslParser.addErrorListener(new CompilerErrorListener(dsl))

        // Listener accumulating the compiled predicate while the tree is walked.
        val compiler = new FiniteStateMachine

        // Build the parse tree first, then walk it.
        val tree = dslParser.synonym()

        (new ParseTreeWalker).walk(compiler, tree)

        compiler.getCompiledSynonymDsl
    }
}
