blob: 25f6b49e8f887d02ffd2d3b2ab99adb9a35a66f3 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.utils
import java.io._
import java.net._
import java.nio._
import java.nio.channels._
import java.util.Random
import java.util.Properties
import junit.framework.Assert._
import kafka.server._
import kafka.producer._
import kafka.message._
import org.I0Itec.zkclient.ZkClient
import kafka.consumer.ConsumerConfig
/**
* Utility functions to help with testing
*/
object TestUtils {
val Letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
val Digits = "0123456789"
val LettersAndDigits = Letters + Digits
/* A consistent random number generator to make tests repeatable */
val seededRandom = new Random(192348092834L)
val random = new Random()
/**
* Choose a number of random available ports
*/
def choosePorts(count: Int): List[Int] = {
val sockets =
for(i <- 0 until count)
yield new ServerSocket(0)
val socketList = sockets.toList
val ports = socketList.map(_.getLocalPort)
socketList.map(_.close)
ports
}
/**
* Choose an available port
*/
def choosePort(): Int = choosePorts(1).head
/**
* Create a temporary directory
*/
def tempDir(): File = {
val ioDir = System.getProperty("java.io.tmpdir")
val f = new File(ioDir, "kafka-" + random.nextInt(1000000))
f.mkdirs()
f.deleteOnExit()
f
}
/**
* Create a temporary file
*/
def tempFile(): File = {
val f = File.createTempFile("kafka", ".tmp")
f.deleteOnExit()
f
}
/**
* Create a temporary file and return an open file channel for this file
*/
def tempChannel(): FileChannel = new RandomAccessFile(tempFile(), "rw").getChannel()
/**
* Create a kafka server instance with appropriate test settings
* USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST
* @param config The configuration of the server
*/
def createServer(config: KafkaConfig): KafkaServer = {
val server = new KafkaServer(config)
server.startup()
server
}
/**
* Create a test config for the given node id
*/
def createBrokerConfigs(numConfigs: Int): List[Properties] = {
for((port, node) <- choosePorts(numConfigs).zipWithIndex)
yield createBrokerConfig(node, port)
}
/**
* Create a test config for the given node id
*/
def createBrokerConfig(nodeId: Int, port: Int): Properties = {
val props = new Properties
props.put("brokerid", nodeId.toString)
props.put("port", port.toString)
props.put("log.dir", TestUtils.tempDir().getAbsolutePath)
props.put("log.flush.interval", "1")
props.put("zk.connect", TestZKUtils.zookeeperConnect)
props
}
/**
* Create a test config for a consumer
*/
def createConsumerProperties(zkConnect: String, groupId: String, consumerId: String,
consumerTimeout: Long = -1): Properties = {
val props = new Properties
props.put("zk.connect", zkConnect)
props.put("groupid", groupId)
props.put("consumerid", consumerId)
props.put("consumer.timeout.ms", consumerTimeout.toString)
props.put("zk.sessiontimeout.ms", "400")
props.put("zk.synctime.ms", "200")
props.put("autocommit.interval.ms", "1000")
props
}
/**
* Wrap the message in a message set
* @param payload The bytes of the message
*/
def singleMessageSet(payload: Array[Byte]) =
new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(payload))
/**
* Generate an array of random bytes
* @param numBytes The size of the array
*/
def randomBytes(numBytes: Int): Array[Byte] = {
val bytes = new Array[Byte](numBytes)
seededRandom.nextBytes(bytes)
bytes
}
/**
* Generate a random string of letters and digits of the given length
* @param len The length of the string
* @return The random string
*/
def randomString(len: Int): String = {
val b = new StringBuilder()
for(i <- 0 until len)
b.append(LettersAndDigits.charAt(seededRandom.nextInt(LettersAndDigits.length)))
b.toString
}
/**
* Check that the buffer content from buffer.position() to buffer.limit() is equal
*/
def checkEquals(b1: ByteBuffer, b2: ByteBuffer) {
assertEquals("Buffers should have equal length", b1.limit - b1.position, b2.limit - b2.position)
for(i <- 0 until b1.limit - b1.position)
assertEquals("byte " + i + " byte not equal.", b1.get(b1.position + i), b2.get(b1.position + i))
}
/**
* Throw an exception if the two iterators are of differing lengths or contain
* different messages on their Nth element
*/
def checkEquals[T](expected: Iterator[T], actual: Iterator[T]) {
var length = 0
while(expected.hasNext && actual.hasNext) {
length += 1
assertEquals(expected.next, actual.next)
}
if (expected.hasNext)
{
var length1 = length;
while (expected.hasNext)
{
expected.next
length1 += 1
}
assertFalse("Iterators have uneven length-- first has more: "+length1 + " > " + length, true);
}
if (actual.hasNext)
{
var length2 = length;
while (actual.hasNext)
{
actual.next
length2 += 1
}
assertFalse("Iterators have uneven length-- second has more: "+length2 + " > " + length, true);
}
}
/**
* Throw an exception if an iterable has different length than expected
*
*/
def checkLength[T](s1: Iterator[T], expectedLength:Int) {
var n = 0
while (s1.hasNext) {
n+=1
s1.next
}
assertEquals(expectedLength, n)
}
/**
* Throw an exception if the two iterators are of differing lengths or contain
* different messages on their Nth element
*/
def checkEquals[T](s1: java.util.Iterator[T], s2: java.util.Iterator[T]) {
while(s1.hasNext && s2.hasNext)
assertEquals(s1.next, s2.next)
assertFalse("Iterators have uneven length--first has more", s1.hasNext)
assertFalse("Iterators have uneven length--second has more", s2.hasNext)
}
def stackedIterator[T](s: Iterator[T]*): Iterator[T] = {
new Iterator[T] {
var cur: Iterator[T] = null
val topIterator = s.iterator
def hasNext() : Boolean = {
while (true) {
if (cur == null) {
if (topIterator.hasNext)
cur = topIterator.next
else
return false
}
if (cur.hasNext)
return true
cur = null
}
// should never reach her
throw new RuntimeException("should not reach here")
}
def next() : T = cur.next
}
}
/**
* Create a hexidecimal string for the given bytes
*/
def hexString(bytes: Array[Byte]): String = hexString(ByteBuffer.wrap(bytes))
/**
* Create a hexidecimal string for the given bytes
*/
def hexString(buffer: ByteBuffer): String = {
val builder = new StringBuilder("0x")
for(i <- 0 until buffer.limit)
builder.append(String.format("%x", Integer.valueOf(buffer.get(buffer.position + i))))
builder.toString
}
/**
* Create a producer for the given host and port
*/
def createProducer(host: String, port: Int): SyncProducer = {
val props = new Properties()
props.put("host", host)
props.put("port", port.toString)
props.put("buffer.size", "65536")
props.put("connect.timeout.ms", "100000")
props.put("reconnect.interval", "10000")
return new SyncProducer(new SyncProducerConfig(props))
}
def updateConsumerOffset(config : ConsumerConfig, path : String, offset : Long) = {
val zkClient = new ZkClient(config.zkConnect, config.zkSessionTimeoutMs, config.zkConnectionTimeoutMs, ZKStringSerializer)
ZkUtils.updatePersistentPath(zkClient, path, offset.toString)
}
def getMessageIterator(iter: Iterator[MessageAndOffset]): Iterator[Message] = {
new IteratorTemplate[Message] {
override def makeNext(): Message = {
if (iter.hasNext)
return iter.next.message
else
return allDone()
}
}
}
}
object TestZKUtils {
val zookeeperConnect = "127.0.0.1:2182"
}