# blob: ca1466c261709d2e5c841c807c56b977ba2e6c20 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import threading, struct, datetime, time
from exceptions import Timeout
class Struct:
    """Generic record whose attributes are described by a struct type object.

    ``_type`` must expose ``name`` and ``fields``.  Each field exposes
    ``name``, a ``default()`` callable, and ``type.is_present(value)``.
    Construction mirrors Python call semantics: positional arguments fill
    fields in order, keyword arguments fill the remaining fields by name, and
    anything left unset gets ``field.default()``.
    """

    def __init__(self, _type, *args, **kwargs):
        """Populate one attribute per field of ``_type``.

        Raises:
            TypeError: too many positional arguments, a field supplied both
                positionally and by keyword, or an unknown keyword.
        """
        if len(args) > len(_type.fields):
            raise TypeError("%s() takes at most %s arguments (%s given)" %
                            (_type.name, len(_type.fields), len(args)))
        self._type = _type
        for idx, field in enumerate(_type.fields):
            if idx < len(args):
                arg = args[idx]
                if field.name in kwargs:
                    raise TypeError("%s() got multiple values for keyword argument '%s'" %
                                    (_type.name, field.name))
            elif field.name in kwargs:
                # pop() so that whatever remains afterwards is unexpected.
                arg = kwargs.pop(field.name)
            else:
                arg = field.default()
            setattr(self, field.name, arg)
        if kwargs:
            # sorted() makes the reported name deterministic and, unlike
            # kwargs.keys()[0], also works where keys() is not a list.
            unexpected = sorted(kwargs)[0]
            raise TypeError("%s() got an unexpected keyword argument '%s'" %
                            (_type.name, unexpected))

    def __getitem__(self, name):
        """Allow ``struct["field"]`` as an alias for attribute access."""
        return getattr(self, name)

    def __setitem__(self, name, value):
        """Assign an existing attribute; unknown names raise AttributeError."""
        if not hasattr(self, name):
            raise AttributeError("'%s' object has no attribute '%s'" %
                                 (self._type.name, name))
        setattr(self, name, value)

    def __repr__(self):
        """Render as ``Name(field=value, ...)``, listing only present fields."""
        shown = []
        for f in self._type.fields:
            v = self[f.name]
            if f.type.is_present(v):
                shown.append("%s=%r" % (f.name, v))
        return "%s(%s)" % (self._type.name, ", ".join(shown))
class Message:
    """A message body plus an optional list of headers.

    ``Message(h1, h2, ..., body)``: the last positional argument is the body
    and any preceding ones are headers.  Headers are matched by their
    ``NAME`` attribute; at most one header per NAME is kept by ``set``.
    """

    def __init__(self, *args):
        if args:
            self.body = args[-1]
        else:
            self.body = None
        if len(args) > 1:
            self.headers = list(args[:-1])
        else:
            # None (not []) when no headers were supplied.
            self.headers = None
        self.id = None

    def has(self, name):
        """Return True if a header whose NAME equals *name* is present."""
        return self.get(name) is not None

    def get(self, name):
        """Return the first header whose NAME equals *name*, or None."""
        if self.headers:
            for h in self.headers:
                if h.NAME == name:
                    return h
        return None

    def set(self, header):
        """Replace the first header with the same NAME, else append *header*."""
        if self.headers is None:
            self.headers = []
        for idx, existing in enumerate(self.headers):
            if existing.NAME == header.NAME:
                self.headers[idx] = header
                return
        self.headers.append(header)

    def clear(self, name):
        """Remove the first header whose NAME equals *name*, if any.

        A message constructed without headers has ``headers is None``; this is
        a no-op then instead of failing on ``len(None)``.
        """
        if not self.headers:
            return
        for idx, existing in enumerate(self.headers):
            if existing.NAME == name:
                # Returning right away keeps the in-loop deletion safe.
                del self.headers[idx]
                return

    def __repr__(self):
        args = []
        if self.headers:
            args.extend(map(repr, self.headers))
        if self.body:
            args.append(repr(self.body))
        if self.id is not None:
            args.append("id=%s" % self.id)
        return "Message(%s)" % ", ".join(args)
def serial(o):
    """Coerce *o* to a Serial; an existing Serial is returned unchanged."""
    if not isinstance(o, Serial):
        o = Serial(o)
    return o
class Serial:
    """A 32-bit sequence number ordered by serial-number (RFC 1982) arithmetic.

    Values are reduced modulo 2**32, and the number space is treated as a
    circle: ``a < b`` when ``b`` lies less than half the space ahead of ``a``.
    A Serial compares equal to a plain integer with the same reduced value.
    Objects that are not integers or Serials always compare as unequal, and
    the Serial sorts as greater than them.
    """

    # Integer types a Serial may be compared with; ``long`` exists only on
    # Python 2.
    try:
        _INTEGER_TYPES = (int, long)
    except NameError:
        _INTEGER_TYPES = (int,)

    def __init__(self, value):
        self.value = value & 0xFFFFFFFF

    def __hash__(self):
        return hash(self.value)

    def _compare(self, other):
        """Return <0, 0 or >0 as self is before, equal to, or after *other*."""
        if other.__class__ not in self._INTEGER_TYPES + (Serial,):
            return 1
        if isinstance(other, Serial):
            other_value = other.value
        else:
            other_value = other & 0xFFFFFFFF
        delta = (self.value - other_value) & 0xFFFFFFFF
        mag = delta & 0x7FFFFFFF
        if delta & 0x80000000:
            # Exactly half the space apart (delta == 2**31) is undefined in
            # RFC 1982. Report "less" rather than 0, so two different values
            # (with different hashes) are never considered equal.
            return -(mag or 1)
        return mag

    def __cmp__(self, other):
        # Kept for Python 2 callers of cmp().
        return self._compare(other)

    # Rich comparisons, so ordering also works where __cmp__ is ignored
    # (Python 3).
    def __eq__(self, other):
        return self._compare(other) == 0

    def __ne__(self, other):
        return self._compare(other) != 0

    def __lt__(self, other):
        return self._compare(other) < 0

    def __le__(self, other):
        return self._compare(other) <= 0

    def __gt__(self, other):
        return self._compare(other) > 0

    def __ge__(self, other):
        return self._compare(other) >= 0

    def __add__(self, other):
        return Serial(self.value + other)

    def __sub__(self, other):
        """Serial - Serial gives a plain int difference; Serial - int gives a Serial."""
        if isinstance(other, Serial):
            return self.value - other.value
        else:
            return Serial(self.value - other)

    def __repr__(self):
        return "serial(%s)" % self.value

    def __str__(self):
        return str(self.value)
class Range:
    """Inclusive span of serial numbers from ``lower`` through ``upper``.

    Both bounds are converted with ``serial()``, so every comparison and
    increment below uses Serial (wrap-around) semantics.
    """

    def __init__(self, lower, upper = None):
        self.lower = serial(lower)
        if upper is not None:
            self.upper = serial(upper)
        else:
            # A single-argument Range covers exactly one number.
            self.upper = self.lower

    def __contains__(self, n):
        return self.lower <= n <= self.upper

    def __iter__(self):
        """Yield every serial number from lower up to and including upper."""
        current = self.lower
        while current <= self.upper:
            yield current
            current += 1

    def touches(self, r):
        """Return True when *r* overlaps this range or abuts one of its edges."""
        # Adjacency: one range's bound sits just outside the other range.
        if (self.lower - 1 in r or self.upper + 1 in r or
                r.lower - 1 in self or r.upper + 1 in self):
            return True
        # Overlap: one range's bound lies inside the other range.
        # XXX: some of these checks may be redundant.
        return (self.lower in r or self.upper in r or
                r.lower in self or r.upper in self)

    def span(self, r):
        """Return the smallest Range covering both this range and *r*."""
        low = min(self.lower, r.lower)
        high = max(self.upper, r.upper)
        return Range(low, high)

    def intersect(self, r):
        """Return the overlap of this range and *r*, or None if they are disjoint."""
        low = max(self.lower, r.lower)
        high = min(self.upper, r.upper)
        if low > high:
            return None
        return Range(low, high)

    def __repr__(self):
        return "%s-%s" % (self.lower, self.upper)
class RangedSet:
    """Ordered list of ranges, merging any range that touches a new one.

    ``add_range`` keeps ``self.ranges`` sorted by position and coalesces
    ranges that overlap or are adjacent (as decided by ``Range.touches``).
    """

    def __init__(self, *args):
        self.ranges = []
        for value in args:
            self.add(value)

    def __contains__(self, n):
        for candidate in self.ranges:
            if n in candidate:
                break
        else:
            return False
        return True

    def add_range(self, range):
        """Insert *range*, absorbing every stored range it touches."""
        pos = 0
        while pos < len(self.ranges):
            existing = self.ranges[pos]
            if range.touches(existing):
                # Absorb the neighbour; re-examine the same slot next.
                del self.ranges[pos]
                range = range.span(existing)
                continue
            if range.upper < existing.lower:
                # Strictly before this range: the sorted position is found.
                self.ranges.insert(pos, range)
                return
            pos += 1
        self.ranges.append(range)

    def add(self, lower, upper = None):
        """Add the span lower..upper (a single number when upper is None)."""
        self.add_range(Range(lower, upper))

    def empty(self):
        """Return True when no stored range has lower <= upper."""
        for candidate in self.ranges:
            if candidate.lower <= candidate.upper:
                return False
        return True

    def max(self):
        """Upper bound of the last range, or None when there are no ranges."""
        if not self.ranges:
            return None
        return self.ranges[-1].upper

    def min(self):
        """Lower bound of the first range, or None when there are no ranges."""
        if not self.ranges:
            return None
        return self.ranges[0].lower

    def __iter__(self):
        return iter(self.ranges)

    def __repr__(self):
        return "%s" % (self.ranges,)
class Future:
    """A value that can be supplied later and waited on by another thread.

    A producer calls ``set(value)`` or ``error(error)``; a consumer calls
    ``get(timeout)``, which blocks on a ``threading.Event`` until one of
    those happens.
    """

    def __init__(self, initial=None, exception=Exception):
        # ``exception`` is the class get() raises, built from the error()
        # payload.
        self.value = initial
        self._error = None
        self._set = threading.Event()
        self.exception = exception

    def error(self, error):
        """Record *error* and wake waiters; get() will then raise it."""
        self._error = error
        self._set.set()

    def set(self, value):
        """Store *value* and wake waiters."""
        self.value = value
        self._set.set()

    def get(self, timeout=None):
        """Wait up to *timeout* seconds (None waits forever) and return the value.

        Raises:
            Timeout: if neither set() nor error() happened in time.
            self.exception: if error() was called with a non-None error;
                this takes precedence over any stored value.
        """
        self._set.wait(timeout)
        # Query the event explicitly: Event.wait() only returns the flag on
        # Python 2.7+. isSet() (not is_set()) keeps pre-2.6 compatibility.
        if not self._set.isSet():
            raise Timeout()
        if self._error is not None:
            raise self.exception(self._error)
        return self.value

    def is_set(self):
        """Return True once set() or error() has been called."""
        return self._set.isSet()
try:
    from uuid import uuid4
    from uuid import UUID
except ImportError:
    # Fallback for interpreters whose standard library lacks the uuid module.
    class UUID:
        """Minimal stand-in for uuid.UUID that stores the 16 raw bytes."""

        def __init__(self, hex=None, bytes=None):
            # Exactly one of hex / bytes must be supplied.
            if [hex, bytes].count(None) != 1:
                raise TypeError("need one of hex or bytes")
            if bytes is not None:
                self.bytes = bytes
            elif hex is not None:
                # Split the final 12-digit group into 4 + 8 digits so the six
                # pieces line up with the "!LHHHHL" struct layout.
                fields = hex.split("-")
                fields[4:5] = [fields[4][:4], fields[4][4:]]
                self.bytes = struct.pack("!LHHHHL", *[int(x, 16) for x in fields])

        def __cmp__(self, other):
            if isinstance(other, UUID):
                return cmp(self.bytes, other.bytes)
            else:
                return -1

        def __str__(self):
            return "%08x-%04x-%04x-%04x-%04x%08x" % struct.unpack("!LHHHHL", self.bytes)

        def __repr__(self):
            return "UUID(%r)" % str(self)

        def __hash__(self):
            return self.bytes.__hash__()

    import os, random, socket, time
    # NOTE(review): random.Random is not a cryptographic RNG; these UUIDs
    # should not be relied on as unguessable tokens.
    rand = random.Random()
    rand.seed((os.getpid(), time.time(), socket.gethostname()))

    def random_uuid():
        """Return 16 random octets laid out as an RFC 4122 version-4 UUID."""
        octets = [rand.randint(0, 255) for i in range(16)]
        # RFC 4122 section 4.4: the version (0100) occupies the four most
        # significant bits of time_hi_and_version, which is octet 6.
        octets[6] &= 0x0F
        octets[6] |= 0x40
        # RFC 4122: the two most significant bits of clock_seq_hi_and_reserved
        # (octet 8) are set to 1 and 0 (the "10" variant).
        octets[8] &= 0x3F
        octets[8] |= 0x80
        return "".join(map(chr, octets))

    def uuid4():
        return UUID(bytes=random_uuid())
def parseUUID(str):
    """Build a UUID from its hyphenated hexadecimal text form."""
    hex_text = str
    parsed = UUID(hex=hex_text)
    return parsed
class timestamp(float):
    """Seconds since the epoch as a float, interoperable with datetime.

    Arithmetic with plain numbers stays in float seconds.  Adding or
    subtracting a ``datetime.timedelta`` goes through a local-time
    ``datetime`` round trip (``fromtimestamp`` / ``time.mktime``).  All
    results are wrapped back into ``timestamp``.
    """

    def __new__(cls, obj=None):
        """Create from *obj*: None means "now", a datetime is converted, else float(obj).

        A datetime is converted with time.mktime, which interprets its fields
        as local time; any tzinfo it carries is not applied.
        """
        if obj is None:
            obj = time.time()
        elif isinstance(obj, datetime.datetime):
            obj = time.mktime(obj.timetuple()) + 1e-6 * obj.microsecond
        return super(timestamp, cls).__new__(cls, obj)

    def datetime(self):
        """Return the equivalent naive local-time datetime."""
        return datetime.datetime.fromtimestamp(self)

    def __add__(self, other):
        if isinstance(other, datetime.timedelta):
            return timestamp(self.datetime() + other)
        return timestamp(float(self) + other)

    def __sub__(self, other):
        if isinstance(other, datetime.timedelta):
            return timestamp(self.datetime() - other)
        return timestamp(float(self) - other)

    def __radd__(self, other):
        if isinstance(other, datetime.timedelta):
            return timestamp(self.datetime() + other)
        return timestamp(other + float(self))

    def __rsub__(self, other):
        if isinstance(other, datetime.timedelta):
            # "timedelta - timestamp" has no meaning (datetime likewise rejects
            # timedelta - datetime). Previously this silently computed
            # self - other; NotImplemented makes Python raise TypeError.
            return NotImplemented
        return timestamp(other - float(self))

    def __neg__(self):
        return timestamp(-float(self))

    def __pos__(self):
        return self

    def __abs__(self):
        return timestamp(abs(float(self)))

    def __repr__(self):
        return "timestamp(%r)" % float(self)