blob: d0a58fef38fbd6d5691ee0429027b3d7ab2c9e91 [file] [log] [blame]
# -*- coding: utf-8 -*-
import struct
from common.constants import DEFAULT_REQUEST_META, INT_DIRECT_MAX, INT_DIRECT_MIN, BC_INT_ZERO, INT_BYTE_MAX, \
INT_BYTE_MIN, BC_INT_BYTE_ZERO, INT_SHORT_MIN, INT_SHORT_MAX, BC_INT_SHORT_ZERO, BC_DOUBLE_ZERO, BC_DOUBLE_ONE, \
BC_DOUBLE_BYTE, BC_DOUBLE_MILL, STRING_DIRECT_MAX, BC_STRING_DIRECT, STRING_SHORT_MAX, BC_STRING_SHORT, \
BC_DOUBLE_SHORT
from common.exceptions import HessianTypeError
from common.util import double_to_long_bits, num_2_byte_list
def encode(request):
"""
把请求序列化为字节数组
:param request:
:return:
"""
request_body = _encode_request_body(request)
request_head = DEFAULT_REQUEST_META + _get_request_body_length(request_body)
return bytearray(request_head + request_body)
def _encode_request_body(body):
"""
对所有的已知的参数根据dubbo协议进行编码
:param body:
:return:
"""
dubbo_version = body['dubbo_version']
path = body['path']
version = body['version']
method = body['method']
arguments = body['arguments']
parameter_types = ''
# 判断并得出参数的类型
for argument in arguments:
if isinstance(argument, bool): # bool类型的判断必须放在int类型判断的前面
parameter_types += 'Z'
elif isinstance(argument, int):
if -2147483648 <= argument <= 2147483647:
parameter_types += 'I'
else:
parameter_types += 'J'
elif isinstance(argument, float):
parameter_types += 'D'
elif isinstance(argument, str):
parameter_types += 'Ljava/lang/String;'
else:
raise HessianTypeError('Unknown argument type: {0}'.format(argument))
body = []
body.extend(_encode_single_value(dubbo_version))
body.extend(_encode_single_value(path))
body.extend(_encode_single_value(version))
body.extend(_encode_single_value(method))
body.extend(_encode_single_value(parameter_types))
for argument in arguments:
body.extend(_encode_single_value(argument))
attachments = {
'path': path,
'interface': path,
'version': version
}
# attachments参数以H开头,以Z结尾
body.append(ord('H'))
for key in attachments.keys():
value = attachments[key]
body.extend(_encode_single_value(key))
body.extend(_encode_single_value(value))
body.append(ord('Z'))
# 因为在上面的逻辑中没有对byte大小进行检测,所以在这里进行统一的处理
for i in range(len(body)):
body[i] = body[i] & 0xff
return body
def _encode_single_value(value):
"""
根据hessian协议对单个变量进行编码
:param value:
:return:
"""
result = []
if isinstance(value, bool):
if value:
result.append(ord('T'))
else:
result.append(ord('F'))
return result
elif isinstance(value, int):
if value > 2147483647 or value < -2147483648:
result.append(ord('L'))
result.extend(list(bytearray(struct.pack('>q', value))))
return result
if INT_DIRECT_MIN <= value <= INT_DIRECT_MAX:
result.append(value + BC_INT_ZERO)
elif INT_BYTE_MIN <= value <= INT_BYTE_MAX:
result.append(BC_INT_BYTE_ZERO + (value >> 8))
result.append(value)
elif INT_SHORT_MIN <= value <= INT_SHORT_MAX:
result.append(BC_INT_SHORT_ZERO + (value >> 16))
result.append(value >> 8)
result.append(value)
else:
result.append(ord('I'))
result.append(value >> 24)
result.append(value >> 16)
result.append(value >> 8)
result.append(value)
return result
elif isinstance(value, float):
int_value = int(value)
if int_value == value:
if int_value == 0:
result.append(BC_DOUBLE_ZERO)
return result
elif int_value == 1:
result.append(BC_DOUBLE_ONE)
return result
elif -0x80 <= int_value < 0x80:
result.append(BC_DOUBLE_BYTE)
result.append(int_value)
return result
elif -0x8000 <= int_value < 0x8000:
result.append(BC_DOUBLE_SHORT)
result.append(int_value >> 8)
result.append(int_value)
return result
mills = int(value * 1000)
if 0.001 * mills == value and -2147483648 <= mills <= 2147483647:
result.append(BC_DOUBLE_MILL)
result.append(mills >> 24)
result.append(mills >> 16)
result.append(mills >> 8)
result.append(mills)
return result
bits = double_to_long_bits(value)
result.append(ord('D'))
result.append(bits >> 56)
result.append(bits >> 48)
result.append(bits >> 40)
result.append(bits >> 32)
result.append(bits >> 24)
result.append(bits >> 16)
result.append(bits >> 8)
result.append(bits)
return result
elif isinstance(value, str):
# 根据hessian协议这里的长度必须是字符串长度而不是字节长度,所以需要Unicode类型
length = len(value.decode('utf-8'))
if length <= STRING_DIRECT_MAX:
result.append(BC_STRING_DIRECT + length)
elif length <= STRING_SHORT_MAX:
result.append(BC_STRING_SHORT + (length >> 8))
result.append(length)
else:
result.append(ord('S'))
result.append(length >> 8)
result.append(length)
result.extend(list(bytearray(value))) # 加上变量数组
return result
else:
raise HessianTypeError('Unknown argument type: {0}'.format(value))
def _get_request_body_length(body):
"""
获取body的长度,并将其转为长度为4个字节的字节数组
:param body:
:return:
"""
request_body_length = num_2_byte_list(len(body))
# 用4个字节表示请求body的长度
while len(request_body_length) < 4:
request_body_length = [0] + request_body_length
return request_body_length