/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.atlas.query

import org.apache.atlas.query.Expressions._
import org.apache.atlas.typesystem.types.IDataType

/**
 * Expression rewrite rule that replaces unresolved nodes with resolved ones.
 *
 * The rule is total (`isDefinedAt` is always true); any expression that no case
 * applies to is handed back unchanged.
 *
 * @param srcExpr               optional source expression; when present, names are first
 *                              looked up as fields of its data type
 * @param aliases               alias name -> aliased expression; a matching identifier
 *                              becomes a BackReference
 * @param connectClassExprToSrc when true (and a source is present), a ClassExpression whose
 *                              name is a field of the source type is turned into a
 *                              FieldExpression
 */
class Resolver(srcExpr: Option[Expression] = None, aliases: Map[String, Expression] = Map(),
               connectClassExprToSrc: Boolean = false)
    extends PartialFunction[Expression, Expression] {

    import org.apache.atlas.query.TypeUtils._

    def isDefinedAt(x: Expression) = true

    /**
     * Looks `name` up as a field of the source expression's data type.
     * Yields None when there is no source expression or the lookup finds nothing.
     */
    private def fieldOfSource(name: String): Option[Expression] =
        srcExpr
            .flatMap(src => resolveReference(src.dataType, name))
            .map(info => new FieldExpression(name, info, None))

    /**
     * Resolves a bare identifier. Candidates are tried in this order, stopping at the
     * first hit: alias, field of the source type, class type, trait type.
     * The `orElse` arguments are by-name, so later lookups only run when earlier ones miss.
     */
    private def resolveIdentifier(name: String): Option[Expression] = {
        val viaAlias: Option[Expression] =
            aliases.get(name).map(target => new BackReference(name, target, None))

        viaAlias
            .orElse(fieldOfSource(name))
            .orElse(if (resolveAsClassType(name).isDefined) Some(new ClassExpression(name)) else None)
            .orElse(if (resolveAsTraitType(name).isDefined) Some(new TraitExpression(name)) else None)
    }

    def apply(e: Expression): Expression = e match {
        // Bare identifier: leave it as-is when nothing resolves.
        case unresolvedId@IdExpression(name) =>
            resolveIdentifier(name).getOrElse(unresolvedId)

        // Class name that may actually be a field of the source type.
        case classExpr@ClassExpression(clsName) if connectClassExprToSrc && srcExpr.isDefined =>
            fieldOfSource(clsName).getOrElse(classExpr)

        // Qualified field: first a real field of the child's type, then a trait name.
        case unresolvedField@UnresolvedFieldExpression(child, fieldName) if child.resolved =>
            resolveReference(child.dataType, fieldName) match {
                case Some(info) =>
                    new FieldExpression(fieldName, info, Some(child))
                case None if resolveAsTraitType(fieldName).isDefined =>
                    // No field metadata exists for a trait name, hence the null slots.
                    new FieldExpression(fieldName, FieldInfo(child.dataType, null, null, fieldName), Some(child))
                case None =>
                    unresolvedField
            }

        // Leaf predicates that have no class expression yet are attached to the source.
        case isTraitLeafExpression(traitName, classExpression)
            if srcExpr.isDefined && classExpression.isEmpty =>
            isTraitLeafExpression(traitName, srcExpr)
        case hasFieldLeafExpression(name, classExpression)
            if srcExpr.isDefined && classExpression.isEmpty =>
            hasFieldLeafExpression(name, srcExpr)

        // Nested expressions are resolved by a fresh Resolver scoped to the (resolved) input.
        case FilterExpression(inputExpr, condExpr) if inputExpr.resolved =>
            val scoped = new Resolver(Some(inputExpr), inputExpr.namedExpressions)
            new FilterExpression(inputExpr, condExpr.transformUp(scoped))

        case SelectExpression(child, selectList) if child.resolved =>
            val scoped = new Resolver(Some(child), child.namedExpressions)
            new SelectExpression(child, selectList.map(item => item.transformUp(scoped)))

        case LoopExpression(inputExpr, loopExpr, t) if inputExpr.resolved =>
            val scoped = new Resolver(Some(inputExpr), inputExpr.namedExpressions,
                connectClassExprToSrc = true)
            new LoopExpression(inputExpr, loopExpr.transformUp(scoped), t)

        case other => other
    }
}

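/**
 * Validates qualified field references against the data type of their input expression:
 * - any FieldReferences that explicitly reference the input can be converted to implicit references
 * - any FieldReferences that explicitly reference a class or trait expression that does not
 *   match the input's type are rejected with an ExpressionException
 */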
object FieldValidator extends PartialFunction[Expression, Expression] {

    /** Message used whenever a qualifier names a source other than the input. */
    private val SrcTypeMismatch = "srcType of field doesn't match input type"

    /** Total rule: `apply` returns unmatched expressions unchanged. */
    def isDefinedAt(x: Expression) = true

    /** True when `e` is a class or trait expression. */
    def isSrc(e: Expression) = e match {
        case _: ClassExpression | _: TraitExpression => true
        case _ => false
    }

    /**
     * Builds the rewrite applied to expressions nested under an input of type `srcDataType`.
     *
     * For each of field / hasField / isTrait: a qualifier whose data type equals
     * `srcDataType` is rewritten to the implicit (leaf) form; otherwise, a qualifier that
     * is a class or trait expression raises an ExpressionException. Case order matters:
     * the rewrite case is tried before the rejecting case.
     *
     * @param srcDataType data type of the enclosing input expression
     * @return a partial function defined only on the qualified forms described above
     */
    def validateQualifiedField(srcDataType: IDataType[_]): PartialFunction[Expression, Expression] = {
        // Field access: only a childless, non-back-reference qualifier of the input type is dropped.
        case FieldExpression(fieldName, info, Some(qualifier))
            if (qualifier.children == Nil && !qualifier.isInstanceOf[BackReference] &&
                qualifier.dataType == srcDataType) =>
            FieldExpression(fieldName, info, None)
        case qualified@FieldExpression(_, _, Some(qualifier)) if isSrc(qualifier) =>
            throw new ExpressionException(qualified, SrcTypeMismatch)

        // hasField predicate.
        case hasFieldUnaryExpression(fieldName, qualifier) if qualifier.dataType == srcDataType =>
            hasFieldLeafExpression(fieldName, Some(qualifier))
        case qualified@hasFieldUnaryExpression(_, qualifier) if isSrc(qualifier) =>
            throw new ExpressionException(qualified, SrcTypeMismatch)

        // isTrait predicate.
        case isTraitUnaryExpression(traitName, qualifier) if qualifier.dataType == srcDataType =>
            isTraitLeafExpression(traitName)
        case qualified@isTraitUnaryExpression(_, qualifier) if isSrc(qualifier) =>
            throw new ExpressionException(qualified, SrcTypeMismatch)
    }

    /**
     * Visitor that accepts FieldExpression nodes and throws an ExpressionException
     * (reported against `loopExpr`) for any other node.
     */
    def validateOnlyFieldReferencesInLoopExpressions(loopExpr: LoopExpression)
    : PartialFunction[Expression, Unit] = {
        case _: FieldExpression => ()
        case other => throw new ExpressionException(loopExpr,
            s"Loop Expression can only contain field references; '${other.toString}' not supported.")
    }

    /**
     * Validates the nested expressions of filter, select and loop expressions against the
     * data type of their input. Filter and loop nodes are returned as the same instance
     * when validation changed nothing; a select node is always rebuilt.
     */
    def apply(e: Expression): Expression = e match {
        case filter@FilterExpression(inputExpr, condExpr) =>
            val checkedCond = condExpr.transformUp(validateQualifiedField(inputExpr.dataType))
            if (checkedCond.fastEquals(condExpr)) filter
            else new FilterExpression(inputExpr, checkedCond)

        case SelectExpression(child, selectList) if child.resolved =>
            val validate = validateQualifiedField(child.dataType)
            new SelectExpression(child, selectList.map(item => item.transformUp(validate)))

        case loop@LoopExpression(inputExpr, loopExpr, t) =>
            val checkedBody = loopExpr.transformUp(validateQualifiedField(inputExpr.dataType))
            val checkedLoop =
                if (checkedBody.fastEquals(loopExpr)) loop
                else new LoopExpression(inputExpr, checkedBody, t)
            // The (validated) loop body may consist of field references only.
            checkedLoop.loopingExpression.traverseUp(
                validateOnlyFieldReferencesInLoopExpressions(checkedLoop))
            checkedLoop

        case other => other
    }
}