blob: fd4af6e949b61caa3728e72929b477005226146e [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.utils
import java.io.IOException
import org.apache.kafka.clients.{ClientRequest, ClientResponse, NetworkClient}
import org.apache.kafka.common.Node
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import org.apache.kafka.common.utils.{Time => JTime}
/**
 * Companion holding the implicit conversion that enables the blocking extension methods.
 * Bring it into scope with `import NetworkClientBlockingOps._`.
 */
object NetworkClientBlockingOps {

  /** Wraps `client` so that the blocking helpers of [[NetworkClientBlockingOps]] can be invoked on it directly. */
  implicit def networkClientBlockingOps(client: NetworkClient): NetworkClientBlockingOps = new NetworkClientBlockingOps(client)
}
/**
 * Provides extension methods for `NetworkClient` that are useful for implementing blocking behaviour. Use with care.
 *
 * Example usage:
 *
 * {{{
 * val networkClient: NetworkClient = ...
 * import NetworkClientBlockingOps._
 * networkClient.blockingReady(...)
 * }}}
 *
 * Implemented as a value class (`extends AnyVal`) wrapping the underlying `client`.
 */
class NetworkClientBlockingOps(val client: NetworkClient) extends AnyVal {

  /**
   * Invokes `client.ready` followed by 0 or more `client.poll` invocations until the connection to `node` is ready,
   * the timeout expires or the connection fails.
   *
   * It returns `true` if the call completes normally or `false` if the timeout expires. If the connection fails,
   * an `IOException` is thrown instead.
   *
   * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
   * care.
   *
   * @param node    the node whose connection should become ready
   * @param timeout maximum time to wait in milliseconds; must be >= 0. Timeouts so large that `now + timeout` would
   *                overflow (e.g. `Long.MaxValue`) wait until the connection is ready or fails.
   * @return `true` if the connection became ready, `false` if the timeout expired first
   * @throws IOException              if `client.connectionFailed(node)` reports a failure while waiting
   * @throws IllegalArgumentException if `timeout` is negative
   */
  def blockingReady(node: Node, timeout: Long)(implicit time: JTime): Boolean = {
    require(timeout >= 0, "timeout should be >= 0")
    // Fast path: no polling at all if `client.ready` already reports the connection as usable.
    client.ready(node, time.milliseconds()) || pollUntil(timeout) { (_, now) =>
      if (client.isReady(node, now))
        true
      else if (client.connectionFailed(node))
        throw new IOException(s"Connection to $node failed")
      else false
    }
  }

  /**
   * Invokes `client.send` followed by 1 or more `client.poll` invocations until a response is received or a
   * disconnection happens (which can happen for a number of reasons including a request timeout).
   *
   * In case of a disconnection, an `IOException` is thrown.
   *
   * Responses returned by the same `poll` calls that belong to other requests (different correlation id) are
   * discarded.
   *
   * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
   * care.
   *
   * @param request the request to send
   * @return the response whose correlation id matches that of `request`
   * @throws IOException if the matching response reports that the connection was disconnected
   */
  def blockingSendAndReceive(request: ClientRequest)(implicit time: JTime): ClientResponse = {
    client.send(request, time.milliseconds())
    // The correlation id we are waiting for is fixed, so compute it once rather than once per polled response.
    val correlationId = request.request.header.correlationId
    pollContinuously { responses =>
      val response = responses.find(_.request.request.header.correlationId == correlationId)
      response.foreach { r =>
        if (r.wasDisconnected) {
          val destination = request.request.destination
          throw new IOException(s"Connection to $destination was disconnected before the response was read")
        }
      }
      response
    }
  }

  /**
   * Invokes `client.poll` until `predicate` returns `true` or the timeout expires.
   *
   * It returns `true` if the call completes normally or `false` if the timeout expires. Exceptions thrown via
   * `predicate` are not handled and will bubble up.
   *
   * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
   * care.
   *
   * @param timeout   maximum time to wait in milliseconds (callers pass a non-negative value)
   * @param predicate receives the responses of each poll and the time at which that iteration started
   * @return `true` as soon as `predicate` returns `true`, `false` once the timeout has expired
   */
  private def pollUntil(timeout: Long)(predicate: (Seq[ClientResponse], Long) => Boolean)(implicit time: JTime): Boolean = {
    val methodStartTime = time.milliseconds()
    // Saturate rather than overflow: `methodStartTime + Long.MaxValue` would wrap to a negative expiry time, which
    // would make the loop below give up after a single poll instead of waiting (effectively) forever.
    val timeoutExpiryTime = {
      val sum = methodStartTime + timeout
      if (timeout > 0 && sum < methodStartTime) Long.MaxValue else sum
    }

    @tailrec
    def recursivePoll(iterationStartTime: Long): Boolean = {
      // Never wait in `poll` beyond the overall deadline.
      val pollTimeout = timeoutExpiryTime - iterationStartTime
      val responses = client.poll(pollTimeout, iterationStartTime).asScala
      if (predicate(responses, iterationStartTime)) true
      else {
        val afterPollTime = time.milliseconds()
        if (afterPollTime < timeoutExpiryTime) recursivePoll(afterPollTime)
        else false
      }
    }

    recursivePoll(methodStartTime)
  }

  /**
   * Invokes `client.poll` until `collect` returns `Some`. The value inside `Some` is returned.
   *
   * Exceptions thrown via `collect` are not handled and will bubble up.
   *
   * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
   * care.
   *
   * @param collect inspects the responses of each poll and returns `Some(result)` once done
   * @tparam T the type of the collected result
   * @return the first value produced by `collect`
   */
  private def pollContinuously[T](collect: Seq[ClientResponse] => Option[T])(implicit time: JTime): T = {
    @tailrec
    def recursivePoll: T = {
      // rely on request timeout to ensure we don't block forever
      val responses = client.poll(Long.MaxValue, time.milliseconds()).asScala
      collect(responses) match {
        case Some(result) => result
        case None => recursivePoll
      }
    }

    recursivePoll
  }
}