blob: 9ca36a3d8dface05e1e876a54cd2553ba7c80a69 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.common
import org.apache.kafka.clients.{ClientRequest, ClientResponse, NetworkClient, RequestCompletionHandler}
import org.apache.kafka.common.Node
import org.apache.kafka.common.errors.{AuthenticationException, DisconnectException}
import org.apache.kafka.common.protocol.ApiKeys
import org.apache.kafka.common.requests.AbstractRequest
import org.apache.kafka.server.util.MockTime
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
import org.mockito.ArgumentMatchers.{any, anyLong, same}
import org.mockito.ArgumentMatchers
import org.mockito.Mockito.{mock, verify, verifyNoMoreInteractions, when}
import java.util
import scala.collection.mutable
class InterBrokerSendThreadTest {
// Mock clock handed to the send thread under test and used to timestamp the requests built in the tests.
private val time = new MockTime()
// Mockito mock standing in for the real NetworkClient; the tests stub it and verify calls on it.
private val networkClient: NetworkClient = mock(classOf[NetworkClient])
// Completion handler attached to the test requests; it records the response it was completed with.
private val completionHandler = new StubCompletionHandler
// Request timeout passed to the send thread and to the ClientRequest instances built in the tests.
private val requestTimeoutMs = 1000
/**
 * [[InterBrokerSendThread]] subclass used by the tests: requests are served from an in-memory
 * queue and anything thrown while polling is routed to a callback.
 *
 * @param networkClient     client passed on to the parent thread; defaults to the mock declared
 *                          by the enclosing test class
 * @param exceptionCallback invoked with any `Throwable` thrown by the parent's `pollOnce`;
 *                          the default simply rethrows it
 */
class TestInterBrokerSendThread(networkClient: NetworkClient = networkClient,
                                exceptionCallback: Throwable => Unit = t => throw t)
  extends InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) {

  // Requests waiting to be handed out, in insertion order.
  private val queue = mutable.Queue[RequestAndCompletionHandler]()

  /** Appends `request` to the queue of pending requests. */
  def enqueue(request: RequestAndCompletionHandler): Unit =
    queue.enqueue(request)

  /** Hands out at most one queued request per call; yields nothing when the queue is empty. */
  override def generateRequests(): Iterable[RequestAndCompletionHandler] =
    if (queue.nonEmpty) Some(queue.dequeue()) else None

  /** Delegates to the parent implementation and forwards anything it throws to `exceptionCallback`. */
  override def pollOnce(maxTimeoutMs: Long): Unit =
    try {
      super.pollOnce(maxTimeoutMs)
    } catch {
      case failure: Throwable => exceptionCallback(failure)
    }
}
/** A DisconnectException thrown by poll() after shutdown must not escape pollOnce. */
@Test
def shutdownThreadShouldNotCauseException(): Unit = {
  // InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first, so a poll issued
  // while the thread is still running can throw DisconnectException; simulate that here.
  when(networkClient.poll(anyLong(), anyLong())).thenThrow(new DisconnectException())

  // Record whatever escapes pollOnce instead of rethrowing it.
  var captured: Option[Throwable] = None
  val thread = new TestInterBrokerSendThread(networkClient, e => captured = Some(e))

  thread.shutdown()
  thread.pollOnce(100)

  verify(networkClient).poll(anyLong(), anyLong())
  assertEquals(None, captured)
}
/**
 * With nothing queued, one unit of work must poll the NetworkClient and touch it in no other
 * way, and the completion handler must not observe a disconnected response.
 */
@Test
def shouldNotSendAnythingWhenNoRequests(): Unit = {
  val sendThread = new TestInterBrokerSendThread()

  // poll is always called, even when there is nothing to send
  when(networkClient.poll(anyLong(), anyLong()))
    .thenReturn(new util.ArrayList[ClientResponse]())

  sendThread.doWork()

  verify(networkClient).poll(anyLong(), anyLong())
  // ...but there should be no further invocations on NetworkClient; without this check the
  // test would still pass if the thread made additional client calls.
  verifyNoMoreInteractions(networkClient)
  assertFalse(completionHandler.executedWithDisconnectedResponse)
}
/** A queued request for a ready node is turned into a ClientRequest and sent. */
@Test
def shouldCreateClientRequestAndSendWhenNodeIsReady(): Unit = {
  val requestBuilder = new StubRequestBuilder()
  val destination = new Node(1, "", 8080)
  val queued = RequestAndCompletionHandler(time.milliseconds(), destination, requestBuilder, completionHandler)
  val thread = new TestInterBrokerSendThread()
  val builtRequest = new ClientRequest("dest", requestBuilder, 0, "1", 0, true, requestTimeoutMs, queued.handler)

  // Issues the newClientRequest call expected for `queued` on `client`, so that stubbing and
  // verification share exactly the same argument matchers.
  def expectedNewClientRequest(client: NetworkClient): ClientRequest =
    client.newClientRequest(
      ArgumentMatchers.eq("1"),
      same(queued.request),
      anyLong(),
      ArgumentMatchers.eq(true),
      ArgumentMatchers.eq(requestTimeoutMs),
      same(queued.handler))

  when(expectedNewClientRequest(networkClient)).thenReturn(builtRequest)
  when(networkClient.ready(destination, time.milliseconds())).thenReturn(true)
  when(networkClient.poll(anyLong(), anyLong())).thenReturn(new util.ArrayList[ClientResponse]())

  thread.enqueue(queued)
  thread.doWork()

  expectedNewClientRequest(verify(networkClient))
  verify(networkClient).ready(any[Node], anyLong())
  verify(networkClient).send(same(builtRequest), anyLong())
  verify(networkClient).poll(anyLong(), anyLong())
  assertFalse(completionHandler.executedWithDisconnectedResponse)
}
/**
 * When the node is not ready and its connection is reported as failed, the completion handler
 * must be invoked with a disconnected response.
 */
@Test
def shouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady(): Unit = {
  val requestBuilder = new StubRequestBuilder
  val unreachableNode = new Node(1, "", 8080)
  val queued = RequestAndCompletionHandler(time.milliseconds(), unreachableNode, requestBuilder, completionHandler)
  val thread = new TestInterBrokerSendThread()
  val builtRequest = new ClientRequest("dest", requestBuilder, 0, "1", 0, true, requestTimeoutMs, queued.handler)
  val noResponses = new util.ArrayList[ClientResponse]()

  when(networkClient.newClientRequest(
    ArgumentMatchers.eq("1"),
    same(queued.request),
    anyLong(),
    ArgumentMatchers.eq(true),
    ArgumentMatchers.eq(requestTimeoutMs),
    same(queued.handler))
  ).thenReturn(builtRequest)

  // Stub the client so the node is reported as not ready and its connection as failed,
  // with an authentication exception attached.
  when(networkClient.ready(unreachableNode, time.milliseconds())).thenReturn(false)
  when(networkClient.connectionDelay(any[Node], anyLong())).thenReturn(0)
  when(networkClient.poll(anyLong(), anyLong())).thenReturn(noResponses)
  when(networkClient.connectionFailed(unreachableNode)).thenReturn(true)
  when(networkClient.authenticationException(unreachableNode)).thenReturn(new AuthenticationException(""))

  thread.enqueue(queued)
  thread.doWork()

  verify(networkClient).newClientRequest(
    ArgumentMatchers.eq("1"),
    same(queued.request),
    anyLong(),
    ArgumentMatchers.eq(true),
    ArgumentMatchers.eq(requestTimeoutMs),
    same(queued.handler))
  verify(networkClient).ready(any[Node], anyLong())
  verify(networkClient).connectionDelay(any[Node], anyLong())
  verify(networkClient).poll(anyLong(), anyLong())
  verify(networkClient).connectionFailed(any[Node])
  verify(networkClient).authenticationException(any[Node])
  assertTrue(completionHandler.executedWithDisconnectedResponse)
}
/**
 * A request that stays unsent for longer than the request timeout must be removed from the
 * unsent requests and completed with a disconnected response.
 */
@Test
def testFailingExpiredRequests(): Unit = {
  val requestBuilder = new StubRequestBuilder()
  val destination = new Node(1, "", 8080)
  val queued = RequestAndCompletionHandler(time.milliseconds(), destination, requestBuilder, completionHandler)
  val thread = new TestInterBrokerSendThread()
  val builtRequest = new ClientRequest(
    "dest", requestBuilder, 0, "1", time.milliseconds(), true, requestTimeoutMs, queued.handler)

  // advance the mock clock past requestTimeoutMs (1000) so the request counts as expired
  time.sleep(1500)

  // Issues the newClientRequest call expected for `queued` on `client`, so that stubbing and
  // verification share exactly the same argument matchers.
  def expectedNewClientRequest(client: NetworkClient): ClientRequest =
    client.newClientRequest(
      ArgumentMatchers.eq("1"),
      same(queued.request),
      ArgumentMatchers.eq(queued.creationTimeMs),
      ArgumentMatchers.eq(true),
      ArgumentMatchers.eq(requestTimeoutMs),
      same(queued.handler))

  when(expectedNewClientRequest(networkClient)).thenReturn(builtRequest)
  // make the node unready so the request is not cleared
  when(networkClient.ready(destination, time.milliseconds())).thenReturn(false)
  when(networkClient.connectionDelay(any[Node], anyLong())).thenReturn(0)
  when(networkClient.poll(anyLong(), anyLong())).thenReturn(new util.ArrayList[ClientResponse]())
  // rule out disconnects so the request stays for the expiry check
  when(networkClient.connectionFailed(destination)).thenReturn(false)

  thread.enqueue(queued)
  thread.doWork()

  expectedNewClientRequest(verify(networkClient))
  verify(networkClient).ready(any[Node], anyLong())
  verify(networkClient).connectionDelay(any[Node], anyLong())
  verify(networkClient).poll(anyLong(), anyLong())
  verify(networkClient).connectionFailed(any[Node])
  assertFalse(thread.hasUnsentRequests)
  assertTrue(completionHandler.executedWithDisconnectedResponse)
}
/** Request builder for `ApiKeys.END_TXN` whose `build` is deliberately left unimplemented. */
private class StubRequestBuilder extends AbstractRequest.Builder(ApiKeys.END_TXN) {
  // Calling build throws NotImplementedError.
  override def build(version: Short): Nothing = {
    ???
  }
}
/** Completion handler stub that remembers the response it was last completed with. */
private class StubCompletionHandler extends RequestCompletionHandler {
  // Whether the most recent response reported a disconnect; false until onComplete has run.
  var executedWithDisconnectedResponse = false
  // Most recent response passed to onComplete; null until onComplete has run.
  var response: ClientResponse = _

  /** Records `response` and whether it was a disconnected response. */
  override def onComplete(response: ClientResponse): Unit = {
    executedWithDisconnectedResponse = response.wasDisconnected()
    this.response = response
  }
}
}