blob: 906c4b27b2e4aca9c6a92d4b0bbf6684c29dcaf0 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.server
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.ByteBuffer
import java.util.Properties
import kafka.integration.KafkaServerTestHarness
import kafka.network.SocketServer
import kafka.utils._
import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol}
import org.apache.kafka.common.requests.{AbstractRequest, RequestHeader, ResponseHeader}
import org.junit.Before
abstract class BaseRequestTest extends KafkaServerTestHarness {
private var correlationId = 0
// If required, set number of brokers
protected def numBrokers: Int = 3
// If required, override properties by mutating the passed Properties object
protected def propertyOverrides(properties: Properties) {}
def generateConfigs() = {
val props = TestUtils.createBrokerConfigs(numBrokers, zkConnect, enableControlledShutdown = false,
interBrokerSecurityProtocol = Some(securityProtocol),
trustStoreFile = trustStoreFile, saslProperties = saslProperties)
props.foreach(propertyOverrides)
props.map(KafkaConfig.fromProps)
}
@Before
override def setUp() {
super.setUp()
TestUtils.waitUntilTrue(() => servers.head.metadataCache.getAliveBrokers.size == numBrokers, "Wait for cache to update")
}
def socketServer = {
servers.find { server =>
val state = server.brokerState.currentState
state != NotRunning.state && state != BrokerShuttingDown.state
}.map(_.socketServer).getOrElse(throw new IllegalStateException("No live broker is available"))
}
def connect(s: SocketServer = socketServer, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT): Socket = {
new Socket("localhost", s.boundPort(protocol))
}
private def sendRequest(socket: Socket, request: Array[Byte]) {
val outgoing = new DataOutputStream(socket.getOutputStream)
outgoing.writeInt(request.length)
outgoing.write(request)
outgoing.flush()
}
private def receiveResponse(socket: Socket): Array[Byte] = {
val incoming = new DataInputStream(socket.getInputStream)
val len = incoming.readInt()
val response = new Array[Byte](len)
incoming.readFully(response)
response
}
def requestAndReceive(socket: Socket, request: Array[Byte]): Array[Byte] = {
sendRequest(socket, request)
receiveResponse(socket)
}
def send(request: AbstractRequest, apiKey: ApiKeys, version: Short): ByteBuffer = {
val socket = connect()
try {
send(socket, request, apiKey, version)
} finally {
socket.close()
}
}
/**
* Serializes and send the request to the given api. A ByteBuffer containing the response is returned.
*/
def send(socket: Socket, request: AbstractRequest, apiKey: ApiKeys, version: Short): ByteBuffer = {
correlationId += 1
val serializedBytes = {
val header = new RequestHeader(apiKey.id, version, "", correlationId)
val byteBuffer = ByteBuffer.allocate(header.sizeOf() + request.sizeOf)
header.writeTo(byteBuffer)
request.writeTo(byteBuffer)
byteBuffer.array()
}
val response = requestAndReceive(socket, serializedBytes)
val responseBuffer = ByteBuffer.wrap(response)
ResponseHeader.parse(responseBuffer) // Parse the header to ensure its valid and move the buffer forward
responseBuffer
}
}