blob: 1684a35a17c9112a89352431542db03864601d12 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
import sys
from collections import defaultdict
from common import ValExpr, ValExprList
# Module-level dictionary, created empty. Nothing in the visible part of this
# file reads or writes it.
module_contents = {}
class DataTypeMetaclass(type):
    '''Metaclass that orders data type classes, which is used to determine upcasting.

    Each class created through this metaclass also receives a ``type`` attribute:
    the class itself for a fixed set of class names, otherwise whatever the class's
    ``get_generic_type()`` returns.
    '''

    def __init__(cls, name, bases, dict):
        super(DataTypeMetaclass, cls).__init__(name, bases, dict)
        # Classes with these names act as their own generic type.
        own_type_names = (
            'Char', 'DataType', 'Decimal', 'Float', 'Int', 'Number', 'Timestamp')
        cls.type = cls if name in own_type_names else cls.get_generic_type()

    def __cmp__(cls, other):
        '''Compare two classes by CMP_VALUE, falling back to the class name.

        Anything that was not created by this metaclass compares as greater
        (-1 is returned).
        '''
        if isinstance(other, DataTypeMetaclass):
            own_key = getattr(cls, 'CMP_VALUE', cls.__name__)
            other_key = getattr(other, 'CMP_VALUE', other.__name__)
            return cmp(own_key, other_key)
        return -1
class DataType(ValExpr):
    '''Root of the data type hierarchy.

    Data types are modelled as classes so that relationships between them can be
    expressed through inheritance.
    '''

    __metaclass__ = DataTypeMetaclass

    def __init__(self, val):
        self.val = val

    @property
    def exact_type(self):
        '''The concrete class of this instance.'''
        return type(self)

    @staticmethod
    def group_by_type(vals):
        '''Bucket vals by their "type" attribute.

        Returns a defaultdict that maps each type to a ValExprList of the vals
        having that type.
        '''
        grouped = defaultdict(ValExprList)
        for item in vals:
            grouped[item.type].append(item)
        return grouped

    @classmethod
    def get_base_type(cls):
        '''Find the ancestor (possibly cls itself) that directly subclasses DataType.

        This is only meaningful when called on a subclass of DataType. For example
        Int and Decimal would both return Number as their base type. An Exception
        is raised if no such ancestor can be found.
        '''
        parents = cls.__bases__
        if DataType in parents:
            return cls
        for parent in parents:
            if not issubclass(parent, DataType):
                continue
            # Only the first DataType parent is followed.
            return parent.get_base_type()
        raise Exception('Unable to determine base type of %s' % cls)

    @classmethod
    def get_generic_type(cls):
        '''Return the generic type of cls; by default this is its base type.'''
        return cls.get_base_type()

    @classmethod
    def name(cls):
        '''Return the name of the class.'''
        return cls.__name__

    @classmethod
    def is_approximate(cls):
        '''Return False; subclasses may override this.'''
        return False
class Boolean(DataType):
    '''Boolean data type; adds nothing to what DataType provides.'''
class Number(DataType):
    '''Common parent of the numeric data types (Int, Decimal and Float).'''
class Int(Number):
    '''Integer type whose MIN/MAX match the signed 32-bit range.

    Int is also the generic type of every class derived from it.
    '''

    # Compared against the CMP_VALUE of other numbers to determine upcasting.
    CMP_VALUE = 2

    # Bounds used during data generation to keep vals in range.
    MIN = -(2 ** 31)
    MAX = 2 ** 31 - 1

    @classmethod
    def get_generic_type(cls):
        '''Return Int regardless of which subclass this is called on.'''
        return Int
class TinyInt(Int):
    '''Integer type whose MIN/MAX match the signed 8-bit range (-128 to 127).'''

    CMP_VALUE = 0
    MIN = -(2 ** 7)
    MAX = 2 ** 7 - 1
class SmallInt(Int):
    '''Integer type whose MIN/MAX match the signed 16-bit range.'''

    CMP_VALUE = 1
    MIN = -(2 ** 15)
    MAX = 2 ** 15 - 1
class BigInt(Int):
    '''Integer type whose MIN/MAX match the signed 64-bit range.'''

    CMP_VALUE = 3
    MIN = -(2 ** 63)
    MAX = 2 ** 63 - 1
class Decimal(Number):
    '''Decimal number type; Decimal is the generic type of all its subclasses.'''

    CMP_VALUE = 4

    # Arbitrary default values.
    MAX_DIGITS = 38
    MAX_FRACTIONAL_DIGITS = 10

    @classmethod
    def get_generic_type(cls):
        '''Return Decimal regardless of which subclass this is called on.'''
        return Decimal
class Float(Number):
    '''Floating point type; reports itself as approximate.

    Float is also the generic type of every class derived from it.
    '''

    CMP_VALUE = 5

    @classmethod
    def is_approximate(cls):
        '''Return True, overriding the DataType default of False.'''
        return True

    @classmethod
    def get_generic_type(cls):
        '''Return Float regardless of which subclass this is called on.'''
        return Float
class Double(Float):
    '''Float subclass that sorts directly after Float (CMP_VALUE 6 versus 5).'''

    CMP_VALUE = 6
class Char(DataType):
    '''Character data type and parent of VarChar.'''

    CMP_VALUE = 100

    MIN = 0
    # This is not the true max.
    MAX = 255
class VarChar(Char):
    '''Char subclass that sorts directly after Char (CMP_VALUE 101 versus 100).'''

    CMP_VALUE = 101

    # Not a true max. This is used to differentiate between VarChar and String.
    MAX = 255
class String(VarChar):
    '''VarChar subclass whose MIN starts just above VarChar.MAX.'''

    CMP_VALUE = 102

    # Starting above VarChar.MAX is what differentiates String from VarChar.
    MIN = VarChar.MAX + 1
    # This is not the true max.
    MAX = 1000
class Timestamp(DataType):
    '''Timestamp data type; adds nothing to what DataType provides.'''
# The data type classes that are listed as exact types, in alphabetical order.
EXACT_TYPES = [
    BigInt, Boolean, Char, Decimal, Double, Float,
    Int, SmallInt, String, Timestamp, TinyInt, VarChar,
]

# NOTE(review): presumably the types that may be used in join conditions --
# confirm against the callers.
JOINABLE_TYPES = (Char, Decimal, Int, Timestamp)

# The distinct "type" attributes of the exact types. The order of this tuple
# comes from set iteration and should not be relied upon.
TYPES = tuple(set(exact.type for exact in EXACT_TYPES))
# Generated Decimal subclasses, keyed by (total_digits, fractional_digits).
__DECIMAL_TYPE_CACHE = {}


def get_decimal_class(total_digits, fractional_digits):
    '''Return the Decimal subclass for the given precision, creating it if needed.

    The class is named "Decimal" followed by both arguments zero-padded to two
    digits, and carries them as MAX_DIGITS and MAX_FRACTIONAL_DIGITS. Repeated
    calls with the same arguments return the same class object.
    '''
    key = (total_digits, fractional_digits)
    decimal_class = __DECIMAL_TYPE_CACHE.get(key)
    if decimal_class is None:
        attrs = {
            'MAX_DIGITS': total_digits,
            'MAX_FRACTIONAL_DIGITS': fractional_digits,
        }
        decimal_class = type('Decimal%02d%02d' % key, (Decimal, ), attrs)
        __DECIMAL_TYPE_CACHE[key] = decimal_class
    return decimal_class
# Generated Char subclasses, keyed by length.
__CHAR_TYPE_CACHE = {}


def get_char_class(length):
    '''Return the Char subclass whose MAX is length, creating it if needed.

    The class is named "Char" followed by the length zero-padded to four digits.
    Repeated calls with the same length return the same class object.
    '''
    char_class = __CHAR_TYPE_CACHE.get(length)
    if char_class is None:
        char_class = type('Char%04d' % length, (Char, ), {'MAX': length})
        __CHAR_TYPE_CACHE[length] = char_class
    return char_class
# Generated VarChar subclasses, keyed by length.
__VARCHAR_TYPE_CACHE = {}


def get_varchar_class(length):
    '''Return the VarChar subclass whose MAX is length, creating it if needed.

    The class is named "VarChar" followed by the length zero-padded to four
    digits. Repeated calls with the same length return the same class object.
    '''
    varchar_class = __VARCHAR_TYPE_CACHE.get(length)
    if varchar_class is None:
        varchar_class = type('VarChar%04d' % length, (VarChar, ), {'MAX': length})
        __VARCHAR_TYPE_CACHE[length] = varchar_class
    return varchar_class
class ModuleWrapper(object):
    '''Proxy around a module that creates sized type classes on attribute access.

    Attribute names of the form DecimalPPSS (two digits of precision, two digits
    of scale), CharN and VarCharN (N being one or more digits) are resolved by
    calling the wrapped module's get_decimal_class(), get_char_class() and
    get_varchar_class() functions. Every other name is looked up on the wrapped
    module itself.
    '''

    def __init__(self, module):
        '''Wrap module, which must provide the three get_*_class() functions.'''
        self.module = module
        # The patterns are anchored with \Z so that only names that consist of
        # exactly the prefix plus digits are treated as generated classes. Without
        # the anchor a name such as "Char10Helper" would be hijacked instead of
        # being looked up on the module.
        self.decimal_class_pattern = re.compile(r"Decimal(\d{2})(\d{2})\Z")
        self.char_class_pattern = re.compile(r"Char(\d+)\Z")
        self.varchar_class_pattern = re.compile(r"VarChar(\d+)\Z")

    def __getattr__(self, name):
        '''Resolve name as a generated class or as an attribute of the module.

        This is only invoked for names that normal attribute lookup on the
        wrapper did not find. An AttributeError is raised if the wrapped module
        does not have the attribute either.
        '''
        module = self.module
        match = self.decimal_class_pattern.match(name)
        if match:
            total_digits = int(match.group(1))
            fractional_digits = int(match.group(2))
            return module.get_decimal_class(total_digits, fractional_digits)
        match = self.char_class_pattern.match(name)
        if match:
            return module.get_char_class(int(match.group(1)))
        match = self.varchar_class_pattern.match(name)
        if match:
            return module.get_varchar_class(int(match.group(1)))
        return getattr(module, name)
# Replace this module's entry in sys.modules with a ModuleWrapper around the
# original module object. Attribute access on the registered module then goes
# through ModuleWrapper.__getattr__, which falls back to the original module
# (kept referenced by the wrapper) for names it does not handle itself.
sys.modules[__name__] = ModuleWrapper(sys.modules[__name__])