blob: 0c01fb14d9e93573f8e3f9865a9738906ac96cd2 [file] [log] [blame]
/****************************************************************
* Licensed to the Apache Software Foundation (ASF) under one *
* or more contributor license agreements. See the NOTICE file *
* distributed with this work for additional information *
* regarding copyright ownership. The ASF licenses this file *
* to you under the Apache License, Version 2.0 (the *
* "License"); you may not use this file except in compliance *
* with the License. You may obtain a copy of the License at *
* *
* http://www.apache.org/licenses/LICENSE-2.0 *
* *
* Unless required by applicable law or agreed to in writing, *
* software distributed under the License is distributed on an *
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY *
* KIND, either express or implied. See the License for the *
* specific language governing permissions and limitations *
* under the License. *
****************************************************************/
package org.apache.james.jmap.routes
import eu.timepit.refined.numeric.NonNegative
import eu.timepit.refined.refineV
import eu.timepit.refined.types.numeric.NonNegInt
import org.apache.james.jmap.core.Invocation.{Arguments, MethodCallId, MethodName}
import org.apache.james.jmap.core.{ClientId, Id, Invocation, ServerId}
import org.apache.james.jmap.json.BackReferenceDeserializer
import play.api.libs.json.{JsArray, JsError, JsObject, JsResult, JsString, JsSuccess, JsValue, Reads}
import scala.collection.IndexedSeq
import scala.collection.Map
import scala.collection.immutable.{Map => ImmutableMap}
import scala.util.Try
/**
 * One segment of a parsed [[JsonPath]].
 *
 * Implementations in this file: [[PlainPart]] (object property), [[ArrayElementPart]]
 * (array index) and [[WildcardPart]] (the `*` segment mapping over an array).
 */
sealed trait JsonPathPart
/**
 * The `*` path segment: the remaining path is evaluated against every element of an array.
 * The results are gathered into one array, and results that are arrays are spliced into it.
 */
case object WildcardPart extends JsonPathPart
/**
 * Path segment naming a property of a JSON object.
 *
 * @param name the property key to look up
 */
case class PlainPart(name: String) extends JsonPathPart {
  /**
   * Looks up [[name]] in `jsValue`.
   *
   * @param jsValue the value this segment is applied to
   * @return the property value, or a [[JsError]] when `jsValue` is not an object
   *         or has no property called [[name]]
   */
  def read(jsValue: JsValue): JsResult[JsValue] = jsValue match {
    case JsObject(fields) =>
      fields.get(name).fold[JsResult[JsValue]](JsError(s"Expected path $name was missing"))(JsSuccess(_))
    case _ => JsError("Expecting a JsObject but got a different structure")
  }
}
object ArrayElementPart {
  /**
   * Parses a path segment as an array index.
   *
   * Only non-empty runs of ASCII digits are accepted, as in the RFC 6901 `array-index` rule.
   * `Character.isDigit` would also accept other Unicode digits (which `toInt` happily parses),
   * so they are rejected explicitly. Such segments are then treated as object keys by the caller.
   * Values that do not fit in an `Int` are rejected as well.
   *
   * @param string the raw path segment
   * @return the parsed part, or `None` when `string` is not a valid array index
   */
  def parse(string: String): Option[ArrayElementPart] =
    if (string.nonEmpty && string.forall(isAsciiDigit)) {
      // Try guards against overflow, e.g. "99999999999"
      Try(string.toInt).toOption.flatMap(fromInt)
    } else {
      None
    }

  // True only for '0' to '9'; other Unicode decimal digits are not valid JSON Pointer indices.
  private def isAsciiDigit(char: Char): Boolean = char >= '0' && char <= '9'

  // Refines an Int into a non-negative position; parse only feeds digit strings, so this succeeds.
  private def fromInt(position: Int): Option[ArrayElementPart] =
    refineV[NonNegative](position).toOption.map(ArrayElementPart(_))
}
/**
 * Path segment addressing an element of a JSON array by its zero-based index.
 *
 * @param position the index of the element to read
 */
case class ArrayElementPart(position: NonNegInt) extends JsonPathPart {
  /**
   * Reads the element at [[position]] from `jsValue`.
   *
   * @param jsValue the value this segment is applied to
   * @return the element, or a [[JsError]] when `jsValue` is not an array or has
   *         no element at [[position]]
   */
  def read(jsValue: JsValue): JsResult[JsValue] = jsValue match {
    case JsArray(values) if values.isDefinedAt(position.value) => JsSuccess(values(position.value))
    case JsArray(_) => JsError(s"Supplied array have no $position element")
    case _ => JsError("Expecting a JsArray but got a different structure")
  }
}
object JsonPath {
  /**
   * Parses a result-reference path.
   *
   * The path is a JSON Pointer extended with the `*` segment. The string is split on `/`, and empty
   * segments (such as the one before the leading `/`) are ignored. Each remaining segment becomes:
   *  - `*`: a [[WildcardPart]];
   *  - digits: an [[ArrayElementPart]], see [[ArrayElementPart.parse]];
   *  - anything else: a [[PlainPart]] whose name has its JSON Pointer escapes decoded
   *    (`~1` to `/`, then `~0` to `~`).
   *
   * @param string the raw path, e.g. "/list/*&#47;id"
   * @return the parsed path
   */
  def parse(string: String): JsonPath = JsonPath(string.split('/').toList
    .flatMap {
      case "" => Nil
      case "*" => List(WildcardPart)
      case part => List(ArrayElementPart.parse(part).getOrElse(PlainPart(unescape(part))))
    })

  // RFC 6901 decoding. The order matters: "~01" must become "~1", not "/".
  private def unescape(segment: String): String =
    segment.replace("~1", "/").replace("~0", "~")
}
/**
 * A parsed path: the sequence of segments applied, in order, to a JSON value.
 *
 * @param parts the path segments; an empty list designates the value itself
 */
case class JsonPath(parts: List[JsonPathPart]) {
  /**
   * Applies this path to `jsValue`.
   *
   * When a [[WildcardPart]] is met, the remaining segments are applied to every element of the
   * current array. The results are gathered into a single array, and results that are arrays
   * themselves are spliced into it.
   *
   * @param jsValue the value to walk
   * @return the designated value, or the first [[JsError]] met while walking the path
   */
  def evaluate(jsValue: JsValue): JsResult[JsValue] = parts match {
    case Nil => JsSuccess(jsValue)
    case head :: tail =>
      val remainingPath = JsonPath(tail)
      head match {
        case part: PlainPart => part.read(jsValue).flatMap(remainingPath.evaluate)
        case part: ArrayElementPart => part.read(jsValue).flatMap(remainingPath.evaluate)
        case WildcardPart => remainingPath.readWildcard(jsValue)
      }
  }

  // Evaluates this path against each element of the array `jsValue`; fails with the first element error.
  private def readWildcard(jsValue: JsValue): JsResult[JsValue] = jsValue match {
    case JsArray(arrayItems) =>
      val evaluationResults: IndexedSeq[JsResult[JsValue]] = arrayItems.map(evaluate)
      evaluationResults.collectFirst { case error: JsError => error }
        .getOrElse(JsSuccess(expandArray(evaluationResults.collect { case JsSuccess(value, _) => value })))
    case _ => JsError("Expecting an array")
  }

  // Builds one array from the evaluated values, splicing in the elements of values that are arrays.
  private def expandArray(values: IndexedSeq[JsValue]): JsArray =
    JsArray(values.flatMap {
      case JsArray(nestedArray) => nestedArray
      case other => List(other)
    })
}
/**
 * A result reference to the response of an earlier method call.
 *
 * @param name     the method name the referenced response must carry
 * @param path     the path evaluated against the arguments of the referenced response
 * @param resultOf the method call id of the referenced response
 */
case class BackReference(name: MethodName, path: JsonPath, resultOf: MethodCallId) {
  /**
   * Evaluates [[path]] against the arguments of `invocation`.
   *
   * @param invocation the response designated by [[resultOf]]
   * @return the referenced value, or a [[JsError]] when `invocation` does not have method
   *         name [[name]] or the path cannot be evaluated
   */
  def resolve(invocation: Invocation): JsResult[JsValue] =
    if (invocation.methodName.equals(name)) {
      path.evaluate(invocation.arguments.value)
    } else {
      JsError(s"$resultOf references a ${invocation.methodName} method")
    }
}
/**
 * Signals that a result reference or creation id reference in a method call could not be resolved.
 *
 * @param message description of the failure; also returned by `getMessage`
 */
case class InvalidResultReferenceException(message: String) extends IllegalArgumentException(message)
/**
 * State accumulated while processing the method calls of a request.
 *
 * Instances are immutable: the `record*` methods return an updated copy. The recorded data is
 * used to resolve result references (`#`-prefixed argument names) and creation id references
 * (`#`-prefixed strings and object keys) in the arguments of later method calls.
 *
 * @param creationIds server ids recorded so far, keyed by client id
 * @param invocations invocations recorded so far, keyed by method call id
 */
case class ProcessingContext(private val creationIds: ImmutableMap[ClientId, ServerId],
                             private val invocations: ImmutableMap[MethodCallId, Invocation]) {

  /** Returns a copy of this context that additionally maps `clientId` to `serverId`. */
  def recordCreatedId(clientId: ClientId, serverId: ServerId): ProcessingContext =
    ProcessingContext(creationIds + (clientId -> serverId), invocations)

  /** Returns a copy of this context that additionally records `invocation` under its method call id. */
  def recordInvocation(invocation: Invocation): ProcessingContext =
    ProcessingContext(creationIds, invocations + (invocation.methodCallId -> invocation))

  /**
   * Resolves the result references and creation id references found in the arguments of `invocation`.
   *
   * @param invocation the method call whose arguments are rewritten
   * @return the invocation with its arguments rewritten, or an [[InvalidResultReferenceException]]
   *         when a reference cannot be resolved or the rewritten arguments are not a JSON object
   */
  def resolveBackReferences(invocation: Invocation): Either[InvalidResultReferenceException, Invocation] =
    backReferenceResolver.reads(invocation.arguments.value) match {
      case JsError(e) => Left(InvalidResultReferenceException(e.toString()))
      case JsSuccess(resolvedArguments: JsObject, _) => Right(Invocation(methodName = invocation.methodName,
        methodCallId = invocation.methodCallId,
        arguments = Arguments(resolvedArguments)))
      case others: JsSuccess[JsValue] => Left(InvalidResultReferenceException(s"Unexpected value $others"))
    }

  // Rewrites a JSON value recursively: arrays and objects are traversed, `#`-prefixed strings are
  // replaced by the matching server id when one can be resolved, and other values are left untouched.
  private val backReferenceResolver: Reads[JsValue] = {
    case JsArray(items) => resolveArrayItems(items)
    case JsObject(fields) => resolveObjectFields(fields)
    case JsString(value) if value.startsWith("#") => JsSuccess(JsString(substituteCreationId(value)))
    case others: JsValue => JsSuccess(others)
  }

  // Resolves every item of an array, failing with the first item error.
  private def resolveArrayItems(items: IndexedSeq[JsValue]): JsResult[JsValue] = {
    val results: IndexedSeq[JsResult[JsValue]] = items.map(backReferenceResolver.reads)
    results.collectFirst { case error: JsError => error }
      .getOrElse(JsSuccess(JsArray(results.collect { case JsSuccess(value, _) => value })))
  }

  // Resolves every field of an object, failing with the first field error in iteration order.
  // NOTE(review): if two fields end up with the same key (e.g. "foo" next to a resolved "#foo"),
  // `toMap` silently keeps only the last one; RFC 8620 section 3.7 asks for invalidArguments instead.
  private def resolveObjectFields(fields: Map[String, JsValue]): JsResult[JsObject] = {
    val resolutions = fields.map(resolveField)
    resolutions.collectFirst { case Left(error) => error }
      .getOrElse(JsSuccess(JsObject(resolutions.collect { case Right(field) => field }.toMap)))
  }

  // Resolves one object field.
  // - A `#`-prefixed key whose value deserializes as a result reference becomes the key without
  //   `#`, mapped to the referenced data.
  // - Any other `#`-prefixed key is treated as a possible creation id and replaced by its server id
  //   when that can be resolved.
  // - Every other value is resolved recursively.
  private def resolveField(field: (String, JsValue)): Either[JsError, (String, JsValue)] = {
    val (key, value) = field
    if (key.startsWith("#")) {
      BackReferenceDeserializer.deserializeBackReference(value) match {
        case JsSuccess(backReference, _) =>
          toEither(resolve(backReference)).map(resolved => (key.substring(1), resolved))
        // Not a result reference: keep parsing, the key could be a creation id
        case _: JsError =>
          toEither(backReferenceResolver.reads(value)).map(resolved => (substituteCreationId(key), resolved))
      }
    } else {
      toEither(backReferenceResolver.reads(value)).map(resolved => (key, resolved))
    }
  }

  // Converts a JsResult into an Either that carries the same JsError on failure.
  private def toEither(result: JsResult[JsValue]): Either[JsError, JsValue] = result match {
    case error: JsError => Left(error)
    case JsSuccess(value, _) => Right(value)
  }

  // Evaluates a result reference against the recorded invocation it points to.
  private def resolve(backReference: BackReference): JsResult[JsValue] =
    invocations.get(backReference.resultOf)
      .map(backReference.resolve)
      .getOrElse(JsError("Back reference could not be resolved"))

  // Returns the server id designated by `value`, or `value` unchanged when it cannot be resolved.
  private def substituteCreationId(value: String): String =
    resolveCreationId(value).fold(_ => value, serverId => serverId.value.value)

  private def resolveCreationId(creationId: String): Either[IllegalArgumentException, ServerId] =
    Id.validate(creationId)
      .flatMap(id => resolveServerId(ClientId(id)))

  // If `retrieveOriginalClientId` yields a client id (presumably the creation id without its `#` —
  // confirm against ClientId), that id must have been recorded in `creationIds`. Otherwise `id`
  // itself is used as the server id.
  private def resolveServerId(id: ClientId): Either[IllegalArgumentException, ServerId] =
    id.retrieveOriginalClientId
      .map(originalClientId => originalClientId.flatMap(previousClientId => creationIds.get(previousClientId)
        .toRight(new IllegalArgumentException(s"$id was not used in previously defined creationIds"))))
      .getOrElse(Right(ServerId(id.value)))
}