blob: 23d64b732d28658307a613d0161a12a8a115a7f5 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.utils
import java.io._
import java.nio._
import java.nio.channels._
import java.util.concurrent.atomic._
import java.lang.management._
import java.util.zip.CRC32
import javax.management._
import java.util.Properties
import scala.collection._
import scala.collection.mutable
import kafka.message.{NoCompressionCodec, CompressionCodec}
import org.I0Itec.zkclient.ZkClient
import joptsimple.{OptionSpec, OptionSet, OptionParser}
import util.parsing.json.JSON
/**
* Helper functions!
*/
object Utils extends Logging {
/**
* Wrap the given function in a java.lang.Runnable
* @param fun A function
* @return A Runnable that just executes the function
*/
def runnable(fun: () => Unit): Runnable = {
  // Anonymous adapter whose only job is to invoke the wrapped function.
  val adapter = new Runnable {
    override def run(): Unit = fun()
  }
  adapter
}
/**
* Wrap the given function in a java.lang.Runnable that logs any errors encountered
* @param fun A function
* @return A Runnable that just executes the function
*/
def loggedRunnable(fun: () => Unit): Runnable = {
  // Nothing thrown by `fun` escapes run(): every Throwable is handed to the logger instead.
  def runAndLog(): Unit =
    try fun()
    catch {
      // log any error and the stack trace
      case failure: Throwable => error("error in loggedRunnable", failure)
    }
  new Runnable {
    override def run(): Unit = runAndLog()
  }
}
/**
* Create a daemon thread
* @param name The name of the thread
* @param runnable The runnable to execute in the background
* @return The unstarted thread
*/
def daemonThread(name: String, runnable: Runnable): Thread = {
  // Same as newThread, with the daemon flag switched on.
  newThread(name, runnable, daemon = true)
}
/**
* Create a daemon thread
* @param name The name of the thread
* @param fun The function to execute in the thread
* @return The unstarted thread
*/
def daemonThread(name: String, fun: () => Unit): Thread = {
  // Adapt the function to a Runnable first, then reuse the Runnable-based overload.
  val task: Runnable = runnable(fun)
  daemonThread(name, task)
}
/**
* Create a new thread
* @param name The name of the thread
* @param runnable The work for the thread to do
* @param daemon Should the thread block JVM shutdown?
* @return The unstarted thread
*/
def newThread(name: String, runnable: Runnable, daemon: Boolean): Thread = {
  // The thread is configured but not started; the caller decides when to start it.
  val created = new Thread(runnable, name)
  created.setDaemon(daemon)
  created
}
/**
* Read a byte array from the given offset and size in the buffer.
* The position and limit of the buffer passed in are not modified.
* @param buffer The buffer to read from
* @param offset The absolute index in the buffer of the first byte to copy
* @param size The number of bytes to copy
* @return A newly allocated array holding the copied bytes
*/
def readBytes(buffer: ByteBuffer, offset: Int, size: Int): Array[Byte] = {
  val bytes = new Array[Byte](size)
  // Bulk-copy through a duplicate: it shares the content but has its own position,
  // so the caller's buffer state is left alone.
  val view = buffer.duplicate()
  view.position(offset)
  view.get(bytes, 0, size)
  bytes
}
/**
* Read size prefixed string where the size is stored as a 2 byte short.
* @param buffer The buffer to read from
* @param encoding The encoding in which to read the string
*/
def readShortString(buffer: ByteBuffer, encoding: String): String = {
  // The 2-byte length prefix is widened to Int; a negative prefix stands for a null string.
  val length: Int = buffer.getShort()
  if (length >= 0) {
    val payload = new Array[Byte](length)
    buffer.get(payload)
    new String(payload, encoding)
  } else {
    null
  }
}
/**
* Write a size prefixed string where the size is stored as a 2 byte short
* @param buffer The buffer to write to
* @param string The string to write
* @param encoding The encoding in which to write the string
*/
def writeShortString(buffer: ByteBuffer, string: String, encoding: String): Unit = {
  if (string == null) {
    // A negative length marks a null string.
    buffer.putShort(-1)
  } else {
    // Encode first: the prefix must be the number of encoded bytes, which differs from
    // string.length whenever a character encodes to more than one byte.
    val encoded = string.getBytes(encoding)
    if (encoded.length > Short.MaxValue)
      throw new IllegalArgumentException("String exceeds the maximum size of " + Short.MaxValue + ".")
    buffer.putShort(encoded.length.toShort)
    buffer.put(encoded)
  }
}
/**
* Read a properties file from the given path
* @param filename The path of the file to read
*/
def loadProps(filename: String): Properties = {
  val propStream = new FileInputStream(filename)
  try {
    val props = new Properties()
    props.load(propStream)
    props
  } finally {
    // Always release the file handle, including when load() throws.
    propStream.close()
  }
}
/**
* Read a required integer property value or throw an exception if no such property is found
*/
def getInt(props: Properties, name: String): Int = {
  // A required property has no fallback, so a missing key is an error.
  if (!props.containsKey(name))
    throw new IllegalArgumentException("Missing required property '" + name + "'")
  // The key is known to be present here, so the -1 default is never actually used.
  getInt(props, name, -1)
}
/**
* Read an integer from the properties instance
* @param props The properties to read from
* @param name The property name
* @param default The default value to use if the property is not found
* @return the integer value
*/
def getInt(props: Properties, name: String, default: Int): Int = {
  // An unconstrained read is a ranged read over the whole Int domain.
  val unbounded = (Int.MinValue, Int.MaxValue)
  getIntInRange(props, name, default, unbounded)
}
/**
* Read an integer from the properties instance. Throw an exception
* if the value is not in the given range (inclusive)
* @param props The properties to read from
* @param name The property name
* @param default The default value to use if the property is not found
* @param range The range in which the value must fall (inclusive)
* @throws IllegalArgumentException If the value is not in the given range
* @return the integer value
*/
def getIntInRange(props: Properties, name: String, default: Int, range: (Int, Int)): Int = {
  val (lower, upper) = range
  // The default is used only when the key is absent; a present value must parse as an Int.
  val value = if (props.containsKey(name)) props.getProperty(name).toInt else default
  // Both bounds are inclusive, and the default is range-checked just like a configured value.
  if (value >= lower && value <= upper)
    value
  else
    throw new IllegalArgumentException(name + " has value " + value + " which is not in the range " + range + ".")
}
/**
* Read a required long property value or throw an exception if no such property is found
*/
def getLong(props: Properties, name: String): Long = {
  // A required property has no fallback, so a missing key is an error.
  if (!props.containsKey(name))
    throw new IllegalArgumentException("Missing required property '" + name + "'")
  // The key is known to be present here, so the -1 default is never actually used.
  getLong(props, name, -1)
}
/**
* Read a long from the properties instance
* @param props The properties to read from
* @param name The property name
* @param default The default value to use if the property is not found
* @return the long value
*/
def getLong(props: Properties, name: String, default: Long): Long = {
  // An unconstrained read is a ranged read over the whole Long domain.
  val unbounded = (Long.MinValue, Long.MaxValue)
  getLongInRange(props, name, default, unbounded)
}
/**
* Read a long from the properties instance. Throw an exception
* if the value is not in the given range (inclusive)
* @param props The properties to read from
* @param name The property name
* @param default The default value to use if the property is not found
* @param range The range in which the value must fall (inclusive)
* @throws IllegalArgumentException If the value is not in the given range
* @return the long value
*/
def getLongInRange(props: Properties, name: String, default: Long, range: (Long, Long)): Long = {
  val (lower, upper) = range
  // The default is used only when the key is absent; a present value must parse as a Long.
  val value = if (props.containsKey(name)) props.getProperty(name).toLong else default
  // Both bounds are inclusive, and the default is range-checked just like a configured value.
  if (value >= lower && value <= upper)
    value
  else
    throw new IllegalArgumentException(name + " has value " + value + " which is not in the range " + range + ".")
}
/**
* Read a boolean value from the properties instance
* @param props The properties to read from
* @param name The property name
* @param default The default value to use if the property is not found
* @return the boolean value
*/
def getBoolean(props: Properties, name: String, default: Boolean): Boolean = {
  if (!props.containsKey(name))
    default
  else
    // Only the exact lower-case literals are accepted; anything else is a configuration error.
    props.getProperty(name) match {
      case "true" => true
      case "false" => false
      case _ =>
        throw new IllegalArgumentException("Unacceptable value for property '" + name + "', boolean values must be either 'true' or 'false'")
    }
}
/**
* Get a string property, or, if no such property is defined, return the given default value
*/
def getString(props: Properties, name: String, default: String): String = {
  // Presence is decided by the key alone; a stored value is returned as-is.
  val present = props.containsKey(name)
  if (!present) default else props.getProperty(name)
}
/**
* Get a string property or throw an exception if no such property is defined.
*/
def getString(props: Properties, name: String): String = {
  // A required property has no fallback, so a missing key is an error.
  if (!props.containsKey(name))
    throw new IllegalArgumentException("Missing required property '" + name + "'")
  props.getProperty(name)
}
/**
* Get a property of type java.util.Properties or throw an exception if no such property is defined.
*/
def getProps(props: Properties, name: String): Properties = {
  if (!props.containsKey(name))
    throw new IllegalArgumentException("Missing required property '" + name + "'")
  // The value is a comma-separated list of key=value entries.
  val parsed = new Properties
  props.getProperty(name).split(",").foreach { entry =>
    entry.split("=") match {
      // Exactly one '=' per entry: anything else (no '=', or several) is rejected.
      case Array(key, value) => parsed.put(key, value)
      case _ =>
        throw new IllegalArgumentException("Illegal format of specifying properties '" + entry + "'")
    }
  }
  parsed
}
/**
* Get a property of type java.util.Properties or return the default if no such property is defined
*/
def getProps(props: Properties, name: String, default: Properties): Properties = {
  if (!props.containsKey(name)) {
    default
  } else {
    val raw = props.getProperty(name)
    val entries = raw.split(",")
    // split drops trailing empty tokens, so a value made only of commas yields no entries.
    if (entries.length < 1)
      throw new IllegalArgumentException("Illegal format of specifying properties '" + raw + "'")
    val parsed = new Properties
    entries.foreach { entry =>
      entry.split("=") match {
        // Exactly one '=' per entry: anything else (no '=', or several) is rejected.
        case Array(key, value) => parsed.put(key, value)
        case _ =>
          throw new IllegalArgumentException("Illegal format of specifying properties '" + entry + "'")
      }
    }
    parsed
  }
}
/**
* Open a channel for the given file
*/
def openChannel(file: File, mutable: Boolean): FileChannel = {
  // Read-only channels come from a FileInputStream, writable ones from a "rw" RandomAccessFile.
  if (!mutable)
    new FileInputStream(file).getChannel()
  else
    new RandomAccessFile(file, "rw").getChannel()
}
/**
* Do the given action and log any exceptions thrown without rethrowing them
* @param log The log method to use for logging. E.g. logger.warn
* @param action The action to execute
*/
def swallow(log: (Object, Throwable) => Unit, action: => Unit) = {
  try action
  catch {
    // Report the failure through the supplied log function instead of propagating it.
    case failure: Throwable => log(failure.getMessage(), failure)
  }
}
/**
* Test if two byte buffers are equal. In this case equality means having
* the same bytes from the current position to the limit
*/
def equal(b1: ByteBuffer, b2: ByteBuffer): Boolean = {
  // two byte buffers are equal if their position is the same,
  // their remaining bytes are the same, and their contents are the same
  if (b1.position() != b2.position() || b1.remaining() != b2.remaining())
    false
  else
    // Compare the region between position and limit. Absolute reads leave both positions
    // untouched; the indices must start at the position, not at 0, or bytes outside the
    // remaining region would be compared instead.
    (b1.position() until b1.limit()).forall(i => b1.get(i) == b2.get(i))
}
/**
* Translate the given buffer into a string
* @param buffer The buffer to translate
* @param encoding The encoding to use in translating bytes to characters
*/
def toString(buffer: ByteBuffer, encoding: String): String = {
  // Reads everything between position and limit, which advances the buffer's position to its limit.
  val byteCount = buffer.remaining
  val content = new Array[Byte](byteCount)
  buffer.get(content)
  new String(content, encoding)
}
/**
* Print an error message and shutdown the JVM
* @param message The error message
*/
def croak(message: String): Unit = {
  // Report on stderr, then terminate the JVM with exit status 1.
  System.err.println(message)
  System.exit(1)
}
/**
* Recursively delete the given file/directory and any subfiles (if any exist)
* @param file The root file at which to begin deleting
*/
def rm(file: String): Unit = {
  // Resolve the path to a File and reuse the File-based overload.
  val target = new File(file)
  rm(target)
}
/**
* Recursively delete the given file/directory and any subfiles (if any exist)
* @param file The root file at which to begin deleting
*/
def rm(file: File): Unit = {
  if (file != null) {
    if (file.isDirectory) {
      // listFiles() may return null; in that case there are no children to remove.
      val children = file.listFiles()
      if (children != null)
        children.foreach(child => rm(child))
    }
    // Plain files are deleted directly, directories only after their children.
    file.delete()
  }
}
/**
* Register the given mbean with the platform mbean server,
* unregistering any mbean that was there before. Note,
* this method will not throw an exception if the registration
* fails (since there is nothing you can do and it isn't fatal),
* instead it just returns false indicating the registration failed.
* @param mbean The object to register as an mbean
* @param name The name to register this mbean with
* @return true if the registration succeeded
*/
def registerMBean(mbean: Object, name: String): Boolean = {
  try {
    val server = ManagementFactory.getPlatformMBeanServer()
    // Hold the server's monitor so the unregister/register pair runs as one step
    // with respect to other callers that synchronize on the same server.
    server synchronized {
      val objectName = new ObjectName(name)
      val alreadyRegistered = server.isRegistered(objectName)
      if (alreadyRegistered)
        server.unregisterMBean(objectName)
      server.registerMBean(mbean, objectName)
    }
    true
  } catch {
    // Registration problems are logged and reported through the return value, never rethrown.
    case e: Exception =>
      error("Failed to register Mbean " + name, e)
      false
  }
}
/**
* Unregister the mbean with the given name, if there is one registered
* @param name The mbean name to unregister
*/
def unregisterMBean(name: String): Unit = {
  val server = ManagementFactory.getPlatformMBeanServer()
  server synchronized {
    val objectName = new ObjectName(name)
    // Only unregister when something is actually registered under this name.
    val registered = server.isRegistered(objectName)
    if (registered)
      server.unregisterMBean(objectName)
  }
}
/**
* Read an unsigned integer from the current position in the buffer,
* incrementing the position by 4 bytes
* @param buffer The buffer to read from
* @return The integer read, as a long to avoid signedness
*/
def getUnsignedInt(buffer: ByteBuffer): Long = {
  // Widen to Long first, then mask off the sign extension to keep the low 32 bits.
  val signed: Long = buffer.getInt()
  signed & 0xffffffffL
}
/**
* Read an unsigned integer from the given position without modifying the buffers
* position
* @param buffer The buffer to read from
* @param index the index from which to read the integer
* @return The integer read, as a long to avoid signedness
*/
def getUnsignedInt(buffer: ByteBuffer, index: Int): Long = {
  // Absolute read: widen to Long, then mask off the sign extension to keep the low 32 bits.
  val signed: Long = buffer.getInt(index)
  signed & 0xffffffffL
}
/**
* Write the given long value as a 4 byte unsigned integer. Overflow is ignored.
* @param buffer The buffer to write to
* @param value The value to write
*/
def putUnsignedInt(buffer: ByteBuffer, value: Long): Unit = {
  // Narrowing a Long to Int keeps exactly its low 32 bits; higher bits are dropped.
  val low32: Int = value.toInt
  buffer.putInt(low32)
}
/**
* Write the given long value as a 4 byte unsigned integer. Overflow is ignored.
* @param buffer The buffer to write to
* @param index The position in the buffer at which to begin writing
* @param value The value to write
*/
def putUnsignedInt(buffer: ByteBuffer, index: Int, value: Long): Unit = {
  // Narrowing a Long to Int keeps exactly its low 32 bits; higher bits are dropped.
  val low32: Int = value.toInt
  buffer.putInt(index, low32)
}
/**
* Compute the CRC32 of the byte array
* @param bytes The array to compute the checksum for
* @return The CRC32
*/
def crc32(bytes: Array[Byte]): Long = {
  // Checksum the whole array by delegating to the ranged overload.
  val wholeLength = bytes.length
  crc32(bytes, 0, wholeLength)
}
/**
* Compute the CRC32 of the segment of the byte array given by the specified size and offset
* @param bytes The bytes to checksum
* @param offset The offset at which to begin checksumming
* @param size The number of bytes to checksum
* @return The CRC32
*/
def crc32(bytes: Array[Byte], offset: Int, size: Int): Long = {
  // A fresh CRC32 per call, so no checksum state is carried between invocations.
  val checksum = new CRC32
  checksum.update(bytes, offset, size)
  checksum.getValue
}
/**
* Compute the hash code for the given items
*/
def hashcode(as: Any*): Int = {
  if (as == null)
    0
  else
    // Null items are skipped and leave the hash unchanged. Folding visits every item
    // exactly once; the index loop this replaces only advanced past non-null items,
    // so a single null argument made it loop forever.
    as.foldLeft(1) { (h, item) =>
      if (item == null) h else 31 * h + item.hashCode
    }
}
/**
* Group the given values by keys extracted with the given function
*/
def groupby[K,V](vals: Iterable[V], f: V => K): Map[K,List[V]] = {
  val groups = new mutable.HashMap[K, List[V]]
  vals.foreach { value =>
    val key = f(value)
    // Values are prepended, so each group lists its members in reverse encounter order.
    groups.put(key, value :: groups.getOrElse(key, Nil))
  }
  groups
}
/**
* Read some bytes into the provided buffer, and return the number of bytes read. If the
* channel has been closed or we get -1 on the read for any reason, throw an EOFException
*/
def read(channel: ReadableByteChannel, buffer: ByteBuffer): Int = {
  val count = channel.read(buffer)
  // A count of -1 is turned into an exception rather than being returned to the caller.
  if (count == -1)
    throw new EOFException("Received -1 when reading from channel, socket has likely been closed.")
  count
}
/**
* Return the given value unchanged, rejecting null
* @param v The value to check
* @return v itself when it is not null
* @throws IllegalArgumentException If v is null
*/
def notNull[V](v: V) = {
  if (v != null)
    v
  else
    throw new IllegalArgumentException("Value cannot be null.")
}
/**
* Split a "host:port" string into the host name and the numeric port
* @param hostport The string to split; only the first two colon-separated fields are used
* @return The host and the port parsed as an integer
*/
def getHostPort(hostport: String) : Tuple2[String, Int] = {
  val fields = hostport.split(":")
  val host = fields(0)
  val port = fields(1).toInt
  (host, port)
}
/**
* Split a "topic-partition" string into the topic name and the numeric partition
* @param topicPartition The string to split
* @return The topic and the partition parsed as an integer
*/
def getTopicPartition(topicPartition: String) : Tuple2[String, Int] = {
  // Split on the last dash, so a topic name may itself contain dashes.
  val dash = topicPartition.lastIndexOf('-')
  val topic = topicPartition.substring(0, dash)
  val partition = topicPartition.substring(dash + 1).toInt
  (topic, partition)
}
/**
* Render the stack trace of the given throwable as a string
* @param e The throwable whose stack trace is printed
* @return The text produced by e.printStackTrace
*/
def stackTrace(e: Throwable): String = {
  val captured = new StringWriter
  e.printStackTrace(new PrintWriter(captured))
  captured.toString
}
/**
* This method gets comma separated values which contains key,value pairs and returns a map of
* key value pairs. the format of allCSVal is key1:val1, key2:val2 ....
*/
private def getCSVMap[K, V](allCSVals: String, exceptionMsg:String, successMsg:String) :Map[K, V] = {
  val parsed = new mutable.HashMap[K, V]
  if (!"".equals(allCSVals)) {
    allCSVals.split(",").foreach { token =>
      try {
        val keyAndValue = token.split(":")
        // NOTE(review): the value is always parsed as an Int and then cast to V, so a caller
        // asking for any other V still receives boxed Integers at runtime.
        val number = Integer.parseInt(keyAndValue(1).trim)
        info(successMsg + keyAndValue(0) + " : " + number)
        parsed += keyAndValue(0).asInstanceOf[K] -> number.asInstanceOf[V]
      } catch {
        // A malformed token is logged and skipped; parsing continues with the next token.
        case _: Throwable => error(exceptionMsg + ": " + token)
      }
    }
  }
  parsed
}
/**
* Split a comma separated string into its non-empty elements
* @param csvList The comma separated values, may be null
* @return The non-empty elements in their original order; empty if csvList is null
*/
def getCSVList(csvList: String): Seq[String] = {
  if (csvList != null)
    csvList.split(",").filter(token => !"".equals(token))
  else
    Seq.empty[String]
}
/**
* Parse per-topic log retention hours given as topic1:hours1,topic2:hours2,...
* @param retentionHours The comma separated topic:hours pairs
* @return A map from topic to retention hours
* @throws IllegalArgumentException If any parsed value is not greater than 0
*/
def getTopicRetentionHours(retentionHours: String) : Map[String, Int] = {
  val malformed = "Malformed token for topic.log.retention.hours in server.properties: "
  val parsedPrefix = "The retention hours for "
  val hoursByTopic: Map[String, Int] = getCSVMap(retentionHours, malformed, parsedPrefix)
  // Malformed tokens were already logged and dropped by getCSVMap; here every value must be positive.
  hoursByTopic.foreach { entry =>
    val (topic, hrs) = entry
    require(hrs > 0, "Log retention hours value for topic " + topic + " is " + hrs +
      " which is not greater than 0.")
  }
  hoursByTopic
}
/**
* Parse per-topic log roll hours given as topic1:hours1,topic2:hours2,...
* @param rollHours The comma separated topic:hours pairs
* @return A map from topic to roll hours
* @throws IllegalArgumentException If any parsed value is not greater than 0
*/
def getTopicRollHours(rollHours: String) : Map[String, Int] = {
  val malformed = "Malformed token for topic.log.roll.hours in server.properties: "
  val parsedPrefix = "The roll hours for "
  val hoursByTopic: Map[String, Int] = getCSVMap(rollHours, malformed, parsedPrefix)
  // Malformed tokens were already logged and dropped by getCSVMap; here every value must be positive.
  hoursByTopic.foreach { entry =>
    val (topic, hrs) = entry
    require(hrs > 0, "Log roll hours value for topic " + topic + " is " + hrs +
      " which is not greater than 0.")
  }
  hoursByTopic
}
/**
* Parse per-topic log file sizes given as topic1:size1,topic2:size2,...
* @param fileSizes The comma separated topic:size pairs
* @return A map from topic to log file size
* @throws IllegalArgumentException If any parsed value is not greater than 0
*/
def getTopicFileSize(fileSizes: String): Map[String, Int] = {
  val malformed = "Malformed token for topic.log.file.size in server.properties: "
  val parsedPrefix = "The log file size for "
  val sizeByTopic: Map[String, Int] = getCSVMap(fileSizes, malformed, parsedPrefix)
  // Malformed tokens were already logged and dropped by getCSVMap; here every value must be positive.
  sizeByTopic.foreach { entry =>
    val (topic, size) = entry
    require(size > 0, "Log file size value for topic " + topic + " is " + size +
      " which is not greater than 0.")
  }
  sizeByTopic
}
/**
* Parse per-topic log retention sizes given as topic1:size1,topic2:size2,...
* @param retentionSizes The comma separated topic:size pairs
* @return A map from topic to retention size
* @throws IllegalArgumentException If any parsed value is not greater than 0
*/
def getTopicRetentionSize(retentionSizes: String): Map[String, Long] = {
  val exceptionMsg = "Malformed token for topic.log.retention.size in server.properties: "
  val successMsg = "The log retention size for "
  // getCSVMap parses every value with Integer.parseInt, so the map it builds really holds
  // Ints. Requesting Long values from it directly leaves boxed Integers behind a Long type,
  // which fails with a ClassCastException on first read. Take the Ints and widen explicitly.
  val parsed: Map[String, Int] = getCSVMap(retentionSizes, exceptionMsg, successMsg)
  val map: Map[String, Long] = parsed.map { case (topic, size) => (topic, size.toLong) }
  map.foreach { case (topic, size) =>
    require(size > 0, "Log retention size value for topic " + topic + " is " + size +
      " which is not greater than 0.")
  }
  map
}
/**
* Parse per-topic flush intervals given as topic1:interval1,topic2:interval2,...
* @param allIntervals The comma separated topic:interval pairs
* @return A map from topic to flush interval
* @throws IllegalArgumentException If any parsed value is not greater than 0
*/
def getTopicFlushIntervals(allIntervals: String) : Map[String, Int] = {
  val malformed = "Malformed token for topic.flush.Intervals.ms in server.properties: "
  val parsedPrefix = "The flush interval for "
  val intervalByTopic: Map[String, Int] = getCSVMap(allIntervals, malformed, parsedPrefix)
  // Malformed tokens were already logged and dropped by getCSVMap; here every value must be positive.
  intervalByTopic.foreach { entry =>
    val (topic, interval) = entry
    require(interval > 0, "Flush interval value for topic " + topic + " is " + interval +
      " ms which is not greater than 0.")
  }
  intervalByTopic
}
/**
* Parse per-topic partition counts given as topic1:count1,topic2:count2,...
* @param allPartitions The comma separated topic:count pairs
* @return A map from topic to number of partitions
* @throws IllegalArgumentException If any parsed value is not greater than 0
*/
def getTopicPartitions(allPartitions: String) : Map[String, Int] = {
  val malformed = "Malformed token for topic.partition.counts in server.properties: "
  val parsedPrefix = "The number of partitions for topic "
  val countByTopic: Map[String, Int] = getCSVMap(allPartitions, malformed, parsedPrefix)
  // Malformed tokens were already logged and dropped by getCSVMap; here every value must be positive.
  countByTopic.foreach { entry =>
    val (topic, count) = entry
    require(count > 0, "The number of partitions for topic " + topic + " is " + count +
      " which is not greater than 0.")
  }
  countByTopic
}
/**
* Parse per-topic consumer thread counts given as topic1:count1,topic2:count2,...
* @param consumerTopicString The comma separated topic:count pairs
* @return A map from topic to number of consumer threads
* @throws IllegalArgumentException If any parsed value is not greater than 0
*/
def getConsumerTopicMap(consumerTopicString: String) : Map[String, Int] = {
  val malformed = "Malformed token for embeddedconsumer.topics in consumer.properties: "
  val parsedPrefix = "The number of consumer threads for topic "
  val threadsByTopic: Map[String, Int] = getCSVMap(consumerTopicString, malformed, parsedPrefix)
  // Malformed tokens were already logged and dropped by getCSVMap; here every value must be positive.
  threadsByTopic.foreach { entry =>
    val (topic, count) = entry
    require(count > 0, "The number of consumer threads for topic " + topic + " is " + count +
      " which is not greater than 0.")
  }
  threadsByTopic
}
/**
* Instantiate the named class through its single public constructor, called without arguments
* @param className The fully qualified class name, or null
* @return A new instance of the class, or null if className is null
* @throws IllegalArgumentException If the class does not have exactly one public constructor
*/
def getObject[T<:AnyRef](className: String): T = {
  if (className == null) {
    null.asInstanceOf[T]
  } else {
    val loaded = Class.forName(className).asInstanceOf[Class[T]]
    val publicConstructors = loaded.getConstructors
    // Exactly one public constructor is expected, and it is invoked with no arguments.
    require(publicConstructors.length == 1)
    publicConstructors.head.newInstance().asInstanceOf[T]
  }
}
/**
* Check whether a property value is usable
* @param prop The property value, may be null
* @return true if prop is neither null nor the empty string
*/
def propertyExists(prop: String): Boolean =
  prop != null && prop.compareTo("") != 0
/**
* Look up the compression codec configured under the given property name
* @param props The properties to read from
* @param codec The name of the property holding the integer codec value
* @return NoCompressionCodec if the property is not set, otherwise the codec for its integer value
*/
def getCompressionCodec(props: Properties, codec: String): CompressionCodec = {
  props.getProperty(codec) match {
    case null => NoCompressionCodec
    case codecValue => CompressionCodec.getCompressionCodec(codecValue.toInt)
  }
}
/**
* Best-effort removal of the consumer group's data under /consumers/<groupId> in zookeeper.
* Any failure is ignored.
* @param zkUrl The zookeeper connection string
* @param groupId The consumer group whose data should be deleted
*/
def tryCleanupZookeeper(zkUrl: String, groupId: String): Unit = {
  try {
    val dir = "/consumers/" + groupId
    logger.info("Cleaning up temporary zookeeper data under " + dir + ".")
    val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer)
    try {
      zk.deleteRecursive(dir)
    } finally {
      // Close the client even when the delete fails, so the connection is not leaked.
      zk.close()
    }
  } catch {
    // Cleanup is best effort: failures are deliberately ignored.
    case _: Throwable => // swallow
  }
}
/**
* Verify that every required option was supplied; on the first missing one, log an error,
* print the parser's help to stderr and exit the JVM with status 1
* @param parser The parser used to print the help text
* @param options The parsed command line options
* @param required The options that must be present
*/
def checkRequiredArgs(parser: OptionParser, options: OptionSet, required: OptionSpec[_]*): Unit = {
  required.foreach { spec =>
    val supplied = options.has(spec)
    if (!supplied) {
      error("Missing required argument \"" + spec + "\"")
      parser.printHelpOn(System.err)
      System.exit(1)
    }
  }
}
/**
* Create a circular (looping) iterator over a collection.
* @param coll An iterable over the underlying collection.
* @return A circular iterator over the collection.
*/
def circularIterator[T](coll: Iterable[T]): Iterator[T] = {
  if (coll.isEmpty)
    // Cycling over nothing can never produce an element; without this guard the search
    // for a first element would never terminate.
    Iterator.empty
  else
    // Restart a fresh pass over the collection each time the previous pass is exhausted.
    Iterator.continually(coll).flatMap(_.iterator)
}
}
/**
* Collects request and throughput metrics in windows of monitorDurationNs nanoseconds.
* New measurements go into the current window; once that window is older than the
* monitoring duration it becomes the completed window that the rate, average and max
* getters report on. getNumRequests and getTotalMetric are cumulative across windows.
*
* Until the first window has been rotated, the completed window is an empty Stats whose
* end is still -1, so the rate getters are not meaningful before that point.
*
* @param monitorDurationNs The length of a window in nanoseconds (600 seconds by default)
*/
class SnapshotStats(private val monitorDurationNs: Long = 600L * 1000L * 1000L * 1000L) {
  // `time` must be initialized before any Stats is created: Stats reads it for its start time.
  private val time: Time = SystemTime
  private val complete = new AtomicReference(new Stats())
  private val current = new AtomicReference(new Stats())
  private val total = new AtomicLong(0)
  private val numCumulatedRequests = new AtomicLong(0)

  /**
  * Record one request that took requestNs, rotating the window if it has expired
  */
  def recordRequestMetric(requestNs: Long): Unit = {
    val stats = current.get
    stats.add(requestNs)
    total.getAndAdd(requestNs)
    numCumulatedRequests.getAndAdd(1)
    rotateIfExpired(stats)
  }

  /**
  * Record an amount of data for throughput accounting, rotating the window if it has expired
  */
  def recordThroughputMetric(data: Long): Unit = {
    val stats = current.get
    stats.addData(data)
    rotateIfExpired(stats)
  }

  /**
  * If the given window is older than the monitoring duration, replace it as the current
  * window and publish it as the completed one
  */
  private def rotateIfExpired(stats: Stats): Unit = {
    val ageNs = time.nanoseconds - stats.start
    // if the current stats are too old it is time to swap
    if (ageNs >= monitorDurationNs) {
      // Only the caller that wins the compare-and-set publishes the old window.
      val swapped = current.compareAndSet(stats, new Stats())
      if (swapped) {
        // Stamp the end time before publishing, so a reader of `complete` never sees a
        // rotated window whose end is still -1.
        stats.end.set(time.nanoseconds)
        complete.set(stats)
      }
    }
  }

  def getNumRequests(): Long = numCumulatedRequests.get

  def getRequestsPerSecond: Double = {
    val stats = complete.get
    stats.numRequests / stats.durationSeconds
  }

  def getThroughput: Double = {
    val stats = complete.get
    stats.totalData / stats.durationSeconds
  }

  def getAvgMetric: Double = {
    val stats = complete.get
    if (stats.numRequests == 0) {
      0
    }
    else {
      // Integral division: the average is truncated before it is widened to Double.
      stats.totalRequestMetric / stats.numRequests
    }
  }

  def getTotalMetric: Long = total.get

  def getMaxMetric: Double = complete.get.maxRequestMetric

  /**
  * The measurements of a single window. Updates go through add/addData, which hold a
  * private lock; the fields themselves are read without it.
  */
  class Stats {
    val start = time.nanoseconds
    // -1 until the window has been rotated out
    var end = new AtomicLong(-1)
    var numRequests = 0
    var totalRequestMetric: Long = 0L
    var maxRequestMetric: Long = 0L
    var totalData: Long = 0L
    private val lock = new Object()

    def addData(data: Long): Unit = {
      lock synchronized {
        totalData += data
      }
    }

    def add(requestNs: Long): Unit = {
      lock synchronized {
        numRequests +=1
        totalRequestMetric += requestNs
        maxRequestMetric = scala.math.max(maxRequestMetric, requestNs)
      }
    }

    def durationSeconds: Double = (end.get - start) / (1000.0 * 1000.0 * 1000.0)

    def durationMs: Double = (end.get - start) / (1000.0 * 1000.0)
  }
}
/**
* A wrapper that synchronizes JSON in scala, which is not threadsafe.
*/
object SyncJSON extends Logging {
  // Installed as the parser's global number conversion: every JSON number goes through toInt.
  val myConversionFunc = (input: String) => input.toInt
  JSON.globalNumberParser = myConversionFunc
  val lock = new Object

  /**
  * Parse the given JSON string while holding the lock, so that only one parse runs at a time
  * @param input The JSON text
  * @return The result of JSON.parseFull
  * @throws RuntimeException If parsing throws; the original failure is attached as the cause
  */
  def parseFull(input: String): Option[Any] = lock synchronized {
    try JSON.parseFull(input)
    catch {
      case cause: Throwable =>
        throw new RuntimeException("Can't parse json string: %s".format(input), cause)
    }
  }
}