在请求方法中支持列表对象
diff --git a/README.md b/README.md
index 61f54e7..880ba55 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@
| 整型 | int, long | int |
| 浮点类型 | float, double | float |
| 字符串类型 | java.lang.String | str |
+| 列表类型 | _可迭代的对象_ | [] |
| 自定义的对象类型 | java.lang.Object | ↓ _具体使用方法如下所示_ ↓ |
##### 使用Java的对象类型
diff --git a/dubbo/codec/encoder.py b/dubbo/codec/encoder.py
index eec2c8f..1f01b74 100644
--- a/dubbo/codec/encoder.py
+++ b/dubbo/codec/encoder.py
@@ -70,6 +70,7 @@
def __init__(self, request):
self.__body = request
self.__classes = []
+ self.types = [] # 泛型
def encode(self):
"""
@@ -80,8 +81,7 @@
request_head = DEFAULT_REQUEST_META + get_request_body_length(request_body)
return bytearray(request_head + request_body)
- @staticmethod
- def _get_parameter_types(arguments):
+ def _get_parameter_types(self, arguments):
"""
针对所有的参数计算得到参数类型字符串
:param arguments:
@@ -90,25 +90,38 @@
parameter_types = ''
# 判断并得出参数的类型
for argument in arguments:
- if isinstance(argument, bool): # bool类型的判断必须放在int类型判断的前面
- parameter_types += 'Z'
- elif isinstance(argument, int):
- if MIN_INT_32 <= argument <= MAX_INT_32:
- parameter_types += 'I'
- else:
- parameter_types += 'J'
- elif isinstance(argument, float):
- parameter_types += 'D'
- elif isinstance(argument, (str, unicode)):
- parameter_types += 'Ljava/lang/String;'
- elif isinstance(argument, Object):
- path = argument.get_path()
- path = 'L' + path.replace('.', '/') + ';'
- parameter_types += path
- else:
- raise HessianTypeError('Unknown argument type: {0}'.format(argument))
+ parameter_types += self._get_class_name(argument)
return parameter_types
+ def _get_class_name(self, _class):
+ """
+ 根据一个字段的类型得到其在Java中对应类的全限定名
+ 转换规则:https://stackoverflow.com/a/3442100/4614538
+ :param _class:
+ :return:
+ """
+ if isinstance(_class, bool): # bool类型的判断必须放在int类型判断的前面
+ return 'Z'
+ elif isinstance(_class, int):
+ if MIN_INT_32 <= _class <= MAX_INT_32:
+ return 'I'
+ else:
+ return 'J'
+ elif isinstance(_class, float):
+ return 'D'
+ elif isinstance(_class, (str, unicode)):
+ return 'L' + 'java/lang/String' + ';'
+ elif isinstance(_class, Object):
+ path = _class.get_path()
+ path = 'L' + path.replace('.', '/') + ';'
+ return path
+ elif isinstance(_class, list):
+ if len(_class) == 0:
+ raise HessianTypeError('Method parameter {} is a list but length is zero'.format(_class))
+ return '[' + self._get_class_name(_class[0])
+ else:
+ raise HessianTypeError('Unknown argument type: {0}'.format(_class))
+
def _encode_request_body(self):
"""
对所有的已知的参数根据dubbo协议进行编码
@@ -279,8 +292,50 @@
for field_name in field_names:
result.extend(self._encode_single_value(value[field_name]))
return result
+ # 列表(list)类型,不可以使用tuple替代
+ elif isinstance(value, list):
+ length = len(value)
+ if length == 0:
+ # 没有值则无法判断类型,一律返回null
+ return self._encode_single_value(None)
+ if isinstance(value[0], bool):
+ _type = '[boolean'
+ elif isinstance(value[0], int):
+ _type = '[int'
+ elif isinstance(value[0], float):
+ _type = '[double'
+ elif isinstance(value[0], str):
+ _type = '[string'
+ elif isinstance(value[0], Object):
+ _type = '[object'
+ else:
+ raise HessianTypeError('Unknown list type: {}'.format(value[0]))
+ if length < 0x7:
+ result.append(0x70 + length)
+ if _type not in self.types:
+ self.types.append(_type)
+ result.extend(self._encode_single_value(_type))
+ else:
+ result.extend(self._encode_single_value(self.types.index(_type)))
+ else:
+ result.append(0x56)
+ if _type not in self.types:
+ self.types.append(_type)
+ result.extend(self._encode_single_value(_type))
+ else:
+ result.extend(self._encode_single_value(self.types.index(_type)))
+ result.extend(self._encode_single_value(length))
+ for v in value:
+ if type(value[0]) != type(v):
+ raise HessianTypeError('All elements in list must be the same type, first type'
+ ' is {0} but current type is {1}'.format(type(value[0]), type(v)))
+ result.extend(self._encode_single_value(v))
+ return result
+ elif value is None:
+ result.append(ord('N'))
+ return result
else:
- raise HessianTypeError('Unknown argument type: {0}'.format(value))
+ raise HessianTypeError('Unknown argument type: {}'.format(value))
def get_request_body_length(body):
diff --git a/tests/dubbo_test.py b/tests/dubbo_test.py
index dd5531f..56bd483 100644
--- a/tests/dubbo_test.py
+++ b/tests/dubbo_test.py
@@ -207,6 +207,29 @@
# result = dubbo.call('echo23')
pretty_print(result)
+ def test_array(self):
+ location1 = Object('me.hourui.echo.bean.Location')
+ location1['province'] = '江苏省'
+ location1['city'] = '南京市'
+ location1['street'] = '软件大道'
+ location2 = Object('me.hourui.echo.bean.Location')
+ location2['province'] = '浙江省'
+ location2['city'] = '杭州市'
+ location2['street'] = '余杭区'
+
+ user1 = Object('me.hourui.echo.bean.User1')
+ user1['name'] = '张三'
+ user2 = Object('me.hourui.echo.bean.User1')
+ user2['name'] = '李四'
+
+ array = Object('me.hourui.echo.bean.Object4Array')
+ array['locations'] = [location1, location2]
+ array['users'] = [user1, user2]
+ array['strings'] = ['这是', '一个', '不可', '重复', '重复', '重复', '重复', '的', '列表']
+
+ dubbo_cli = DubboClient('me.hourui.echo.provider.Echo', host='127.0.0.1:20880')
+ dubbo_cli.call('test4', [['你好', '我好'], [2, 3, 3, 3], array])
+
if __name__ == '__main__':
# test = TestDubbo()
diff --git a/tests/run_test.py b/tests/run_test.py
index 3fa82f4..9fa2597 100644
--- a/tests/run_test.py
+++ b/tests/run_test.py
@@ -16,20 +16,17 @@
# self.dubbo = DubboClient('com.qianmi.pc.item.api.spu', host='172.21.36.82:20880')
def test_run(self):
- # channel = Object('com.qianmi.pc.base.api.constants.ChannelEnum')
- # channel['name'] = 'D2C'
- #
- # spu_query_request = Object('com.qianmi.pc.item.api.spu.request.SpuQueryRequest')
- # spu_query_request['chainMasterId'] = 'A000000'
- # spu_query_request['channel'] = channel
- # spu_query_request['pageSize'] = 2000
- #
- # result = self.spu_query_provider.call('query', spu_query_request)
- # pretty_print(result)
- # print len(result['dataList'])
+ channel = Object('com.qianmi.pc.base.api.constants.ChannelEnum')
+ channel['name'] = 'D2C'
- dubbo_cli = DubboClient('me.hourui.echo.provider.Echo', host='127.0.0.1:20880')
- dubbo_cli.call('echo11')
+ spu_query_request = Object('com.qianmi.pc.item.api.spu.request.SpuQueryRequest')
+ spu_query_request['chainMasterId'] = 'A000000'
+ spu_query_request['channel'] = channel
+ spu_query_request['pageSize'] = 2000
+
+ result = self.spu_query_provider.call('query', spu_query_request)
+ pretty_print(result)
+ print len(result['dataList'])
def pretty_print(value):