// blob: 0370564788a2dcb4895449fef77110c5edf10503 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.utils
import java.io.IOException
import org.apache.kafka.clients.{ClientRequest, ClientResponse, NetworkClient}
import org.apache.kafka.common.Node
import org.apache.kafka.common.requests.AbstractRequest
import org.apache.kafka.common.utils.Time
import scala.annotation.tailrec
import scala.collection.JavaConverters._
/**
 * Holds the implicit conversion that makes the blocking extension methods of the
 * `NetworkClientBlockingOps` class available on a plain `NetworkClient`.
 */
object NetworkClientBlockingOps {

  /**
   * Wraps `client` in a `NetworkClientBlockingOps`.
   *
   * @param client the `NetworkClient` to wrap
   * @return a wrapper around `client` exposing the blocking extension methods
   */
  implicit def networkClientBlockingOps(client: NetworkClient): NetworkClientBlockingOps = {
    new NetworkClientBlockingOps(client)
  }

}
/**
* Provides extension methods for `NetworkClient` that are useful for implementing blocking behaviour. Use with care.
*
* Example usage:
*
* {{{
* val networkClient: NetworkClient = ...
* import NetworkClientBlockingOps._
* networkClient.blockingReady(...)
* }}}
*/
class NetworkClientBlockingOps(val client: NetworkClient) extends AnyVal {

  /**
   * Checks whether the node is currently connected, first calling `client.poll` (with a zero timeout) to ensure
   * that any pending disconnects have been processed.
   *
   * This method can be used to check the status of a connection prior to calling `blockingReady` to be able
   * to tell whether the latter completed a new connection.
   *
   * @param node the node whose connection state is queried
   * @param time clock used to obtain the current time passed to `client.poll` and `client.isReady`
   * @return the result of `client.isReady` for `node` after the poll
   */
  def isReady(node: Node)(implicit time: Time): Boolean = {
    val currentTime = time.milliseconds()
    client.poll(0, currentTime)
    client.isReady(node, currentTime)
  }

  /**
   * Invokes `client.poll` to discard pending disconnects, followed by `client.ready` and 0 or more `client.poll`
   * invocations until the connection to `node` is ready, the timeout expires or the connection fails.
   *
   * It returns `true` if the connection is ready when the call completes, or `false` if the timeout expires
   * without the connection being ready. If the connection fails, an `IOException` is thrown instead. Note that if
   * the `NetworkClient` has been configured with a positive connection timeout, it is possible for this method to
   * raise an `IOException` for a previous connection which has recently disconnected.
   *
   * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
   * care.
   *
   * @param node    the node to connect to
   * @param timeout maximum time to wait, in the unit of `time.milliseconds()`; must be >= 0
   * @param time    clock used to compute the deadline and the per-iteration poll timeouts
   * @return `true` if the connection to `node` is ready, `false` if the timeout expired first
   * @throws IllegalArgumentException if `timeout` is negative
   * @throws IOException if `client.connectionFailed(node)` reports a failure while waiting
   */
  def blockingReady(node: Node, timeout: Long)(implicit time: Time): Boolean = {
    require(timeout >= 0, "timeout should be >= 0")

    val startTime = time.milliseconds()
    // `timeout` is non-negative, so the sum can only be smaller than `startTime` if it overflowed. Saturate
    // instead of wrapping around to a negative deadline (which would yield a negative poll timeout and an
    // immediate `false`).
    val uncheckedExpiry = startTime + timeout
    val expiryTime = if (uncheckedExpiry < startTime) Long.MaxValue else uncheckedExpiry

    @tailrec
    def awaitReady(iterationStartTime: Long): Boolean = {
      if (client.isReady(node, iterationStartTime))
        true
      else if (client.connectionFailed(node))
        throw new IOException(s"Connection to $node failed")
      else {
        val pollTimeout = expiryTime - iterationStartTime
        client.poll(pollTimeout, iterationStartTime)
        val afterPollTime = time.milliseconds()
        // Once the deadline has passed, report the actual connection state rather than an unconditional `false`:
        // the connection may have completed during the final poll (with `timeout == 0` that is the only poll).
        if (afterPollTime < expiryTime) awaitReady(afterPollTime)
        else client.isReady(node, afterPollTime)
      }
    }

    // `isReady(node)` is the extension method above (it polls first); `client.ready` initiates a connection
    // if needed; only then do we start waiting.
    isReady(node) || client.ready(node, startTime) || awaitReady(startTime)
  }

  /**
   * Invokes `client.send` followed by 1 or more `client.poll` invocations until a response is received or a
   * disconnection happens (which can happen for a number of reasons including a request timeout).
   *
   * In case of a disconnection, an `IOException` is thrown. If the matching response carries a non-null
   * `versionMismatch()`, that exception is thrown.
   *
   * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
   * care.
   *
   * @param request the request to send; the response is matched to it by correlation id
   * @param time    clock used to obtain the current time passed to `client.send` and `client.poll`
   * @return the response whose request header has the same correlation id as `request`
   * @throws IOException if the matching response reports `wasDisconnected`
   */
  def blockingSendAndReceive(request: ClientRequest)(implicit time: Time): ClientResponse = {
    client.send(request, time.milliseconds())

    pollContinuously { responses =>
      val response = responses.find { candidate =>
        candidate.requestHeader.correlationId == request.correlationId
      }
      response.foreach { r =>
        if (r.wasDisconnected)
          throw new IOException(s"Connection to ${request.destination} was disconnected before the response was read")
        // `versionMismatch()` is null when there is no mismatch; `Option` maps that to `None`.
        Option(r.versionMismatch()).foreach(mismatch => throw mismatch)
      }
      response
    }
  }

  /**
   * Invokes `client.poll` until `collect` returns `Some`. The value inside `Some` is returned.
   *
   * Exceptions thrown via `collect` are not handled and will bubble up.
   *
   * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with
   * care.
   *
   * @param collect function applied to the responses of each poll; returning `Some` ends the loop
   * @param time    clock used to obtain the current time passed to `client.poll`
   * @return the value wrapped in the first `Some` returned by `collect`
   */
  private def pollContinuously[T](collect: Seq[ClientResponse] => Option[T])(implicit time: Time): T = {
    @tailrec
    def recursivePoll(): T = {
      // rely on request timeout to ensure we don't block forever
      val responses = client.poll(Long.MaxValue, time.milliseconds()).asScala
      collect(responses) match {
        case Some(result) => result
        case None => recursivePoll()
      }
    }

    recursivePoll()
  }

}