blob: 4c1494380ea608f168943c091aab7993e333b93a [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.server
import kafka.api.IntegrationTestHarness
import kafka.network.SocketServer
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.protocol.ApiKeys
import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, RequestHeader, ResponseHeader}
import org.apache.kafka.common.utils.Utils
import org.apache.kafka.metadata.BrokerState
import org.apache.kafka.server.config.ServerConfigs
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.ByteBuffer
import java.util.Properties
import scala.collection.Seq
import scala.reflect.ClassTag
/**
 * Base class for integration tests that talk to brokers over raw sockets,
 * serializing Kafka protocol requests by hand and parsing the raw responses.
 *
 * Subclasses pick a target via the various `*SocketServer` accessors and use
 * [[send]] / [[receive]] / [[sendAndReceive]] / [[connectAndReceive]] to
 * exchange requests and responses.
 */
abstract class BaseRequestTest extends IntegrationTestHarness {

  // Monotonically increasing correlation id used when the caller does not supply one.
  private var correlationId = 0

  // Default cluster size; subclasses may override.
  override def brokerCount: Int = 3

  /**
   * Hook for subclasses to mutate per-broker configuration properties.
   * The default implementation changes nothing.
   */
  protected def brokerPropertyOverrides(properties: Properties): Unit = {}

  override def modifyConfigs(props: Seq[Properties]): Unit = {
    props.foreach { brokerProps =>
      // Controlled shutdown is disabled so tests can kill brokers abruptly.
      brokerProps.put(ServerConfigs.CONTROLLED_SHUTDOWN_ENABLE_CONFIG, "false")
      brokerPropertyOverrides(brokerProps)
    }
  }

  /**
   * The socket server of any broker that is currently running
   * (i.e. neither NOT_RUNNING nor SHUTTING_DOWN).
   *
   * @throws IllegalStateException if no live broker exists
   */
  def anySocketServer: SocketServer = {
    val liveBroker = brokers.find { b =>
      val state = b.brokerState
      state != BrokerState.NOT_RUNNING && state != BrokerState.SHUTTING_DOWN
    }
    liveBroker match {
      case Some(broker) => broker.socketServer
      case None => throw new IllegalStateException("No live broker is available")
    }
  }

  /**
   * The socket server of the active controller. In KRaft mode this is the
   * dedicated controller server; in legacy (ZooKeeper) mode it is the broker
   * whose controller component is active.
   *
   * @throws IllegalStateException if no controller can be found (legacy mode)
   */
  def controllerSocketServer: SocketServer = {
    if (isKRaftTest()) {
      controllerServer.socketServer
    } else {
      servers.find(_.kafkaController.isActive) match {
        case Some(active) => active.socketServer
        case None => throw new IllegalStateException("No controller broker is available")
      }
    }
  }

  /**
   * The socket server of a broker that is NOT the controller. In KRaft mode
   * every broker qualifies, so any live broker is returned.
   *
   * @throws IllegalStateException if every broker is the controller (legacy mode)
   */
  def notControllerSocketServer: SocketServer = {
    if (isKRaftTest()) {
      anySocketServer
    } else {
      servers.find(!_.kafkaController.isActive) match {
        case Some(follower) => follower.socketServer
        case None => throw new IllegalStateException("No non-controller broker is available")
      }
    }
  }

  /**
   * The socket server of the broker with the given id.
   *
   * @throws IllegalStateException if no such broker exists
   */
  def brokerSocketServer(brokerId: Int): SocketServer = {
    brokers.find(_.config.brokerId == brokerId) match {
      case Some(broker) => broker.socketServer
      case None => throw new IllegalStateException(s"Could not find broker with id $brokerId")
    }
  }

  /**
   * Return the socket server where admin request to be sent.
   *
   * For KRaft clusters that is any broker as the broker will forward the request to the active
   * controller. For Legacy clusters that is the controller broker.
   */
  def adminSocketServer: SocketServer = {
    if (isKRaftTest()) anySocketServer else controllerSocketServer
  }

  /**
   * Open a plain TCP connection to the given socket server on the given listener.
   * The caller owns the returned socket and must close it.
   */
  def connect(socketServer: SocketServer = anySocketServer,
              listenerName: ListenerName = listenerName): Socket = {
    val port = socketServer.boundPort(listenerName)
    new Socket("localhost", port)
  }

  // Writes a size-prefixed frame (4-byte big-endian length, then the payload).
  private def sendRequest(socket: Socket, request: Array[Byte]): Unit = {
    val out = new DataOutputStream(socket.getOutputStream)
    out.writeInt(request.length)
    out.write(request)
    out.flush()
  }

  /**
   * Read one size-prefixed response frame from the socket, strip the response
   * header, and parse the body as a response of type `T`.
   *
   * @throws ClassCastException if the parsed response is not of type `T`
   */
  def receive[T <: AbstractResponse](socket: Socket, apiKey: ApiKeys, version: Short)
                                    (implicit classTag: ClassTag[T]): T = {
    val in = new DataInputStream(socket.getInputStream)
    val frameSize = in.readInt()
    val payload = new Array[Byte](frameSize)
    in.readFully(payload)
    val buffer = ByteBuffer.wrap(payload)
    // Consume the header bytes so the buffer is positioned at the response body.
    ResponseHeader.parse(buffer, apiKey.responseHeaderVersion(version))
    val parsed = AbstractResponse.parseResponse(apiKey, buffer, version)
    parsed match {
      case expected: T => expected
      case other =>
        throw new ClassCastException(s"Expected response with type ${classTag.runtimeClass}, but found ${other.getClass}")
    }
  }

  /**
   * Send a request over an already-open socket and wait for its response.
   */
  def sendAndReceive[T <: AbstractResponse](request: AbstractRequest,
                                            socket: Socket,
                                            clientId: String = "client-id",
                                            correlationId: Option[Int] = None)
                                           (implicit classTag: ClassTag[T]): T = {
    send(request, socket, clientId, correlationId)
    receive[T](socket, request.apiKey, request.version)
  }

  /**
   * Connect to the given destination, perform a single request/response
   * round trip, and always close the socket afterwards.
   */
  def connectAndReceive[T <: AbstractResponse](request: AbstractRequest,
                                               destination: SocketServer = anySocketServer,
                                               listenerName: ListenerName = listenerName)
                                              (implicit classTag: ClassTag[T]): T = {
    val socket = connect(destination, listenerName)
    try {
      sendAndReceive[T](request, socket)
    } finally {
      socket.close()
    }
  }

  /**
   * Serializes and sends the request to the given api.
   */
  def send(request: AbstractRequest,
           socket: Socket,
           clientId: String = "client-id",
           correlationId: Option[Int] = None): Unit = {
    val header = nextRequestHeader(request.apiKey, request.version, clientId, correlationId)
    sendWithHeader(request, header, socket)
  }

  /**
   * Serialize the request with an explicit, caller-supplied header and send it.
   */
  def sendWithHeader(request: AbstractRequest, header: RequestHeader, socket: Socket): Unit = {
    val bytes = Utils.toArray(request.serializeWithHeader(header))
    sendRequest(socket, bytes)
  }

  /**
   * Build the header for the next request, using the supplied correlation id
   * if present, otherwise incrementing and using the internal counter.
   */
  def nextRequestHeader[T <: AbstractResponse](apiKey: ApiKeys,
                                               apiVersion: Short,
                                               clientId: String = "client-id",
                                               correlationIdOpt: Option[Int] = None): RequestHeader = {
    val effectiveCorrelationId = correlationIdOpt match {
      case Some(explicit) => explicit
      case None =>
        this.correlationId += 1
        this.correlationId
    }
    new RequestHeader(apiKey, apiVersion, clientId, effectiveCorrelationId)
  }
}