| #!/usr/bin/env python |
| |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| from ThriftTest.ttypes import Bonk, VersioningTestV1, VersioningTestV2 |
| from thrift.protocol import TJSONProtocol |
| from thrift.transport import TTransport |
| |
| import json |
| import unittest |
| |
| |
class SimpleJSONProtocolTest(unittest.TestCase):
    """Tests for TSimpleJSONProtocol, a write-only JSON encoding of structs."""

    protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()

    def _assertDictEqual(self, a, b, msg=None):
        """Assert that dicts *a* and *b* contain the same keys and values.

        Delegates to ``TestCase.assertDictEqual`` when the running unittest
        provides it (it was added in Python 2.7). Otherwise it falls back to
        a simpler key-by-key comparison.
        """
        if hasattr(self, 'assertDictEqual'):
            self.assertDictEqual(a, b, msg)
            return

        # Substitute implementation; its failure messages are less detailed
        # than the unittest library's. It uses assertEqual and dict.items(),
        # which exist on every Python version. (assertEquals is a deprecated
        # alias removed in Python 3.12; iteritems() is Python 2 only.)
        self.assertEqual(len(a), len(b), msg)
        for k, v in a.items():
            self.assertTrue(k in b, msg)
            self.assertEqual(b.get(k), v, msg)

    def _serialize(self, obj):
        """Write *obj* with the protocol under test and return the raw bytes."""
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _deserialize(self, objtype, data):
        """Read an instance of *objtype* from *data* using the protocol."""
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testWriteOnly(self):
        """Reading through the simple JSON protocol raises NotImplementedError."""
        self.assertRaises(NotImplementedError,
                          self._deserialize, VersioningTestV1, b'{}')

    def testSimpleMessage(self):
        """A flat struct serializes to a JSON object keyed by field name."""
        v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321)
        expected = dict(begin_in_both=v1obj.begin_in_both,
                        old_string=v1obj.old_string,
                        end_in_both=v1obj.end_in_both)
        actual = json.loads(self._serialize(v1obj).decode('ascii'))

        self._assertDictEqual(expected, actual)

    def testComplicated(self):
        """Nested structs and containers serialize to equivalent JSON values."""
        v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=set([42, 1, 8]),
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321)
        expected = dict(begin_in_both=v2obj.begin_in_both,
                        newint=v2obj.newint,
                        newbyte=v2obj.newbyte,
                        newshort=v2obj.newshort,
                        newlong=v2obj.newlong,
                        newdouble=v2obj.newdouble,
                        newstruct=dict(message=v2obj.newstruct.message,
                                       type=v2obj.newstruct.type),
                        newlist=v2obj.newlist,
                        # NOTE(review): assumes the serializer emits set
                        # elements in the set's own iteration order, which
                        # list() preserves here.
                        newset=list(v2obj.newset),
                        newmap=v2obj.newmap,
                        newstring=v2obj.newstring,
                        end_in_both=v2obj.end_in_both)

        # Round-trip through json so non-string map keys become JSON string
        # keys, matching what the protocol writes.
        expected = json.loads(json.dumps(expected))
        actual = json.loads(self._serialize(v2obj).decode('ascii'))
        self._assertDictEqual(expected, actual)
| |
| |
if __name__ == '__main__':
    # When run as a script, discover and run the TestCases defined in this
    # module; module='__main__' is unittest.main's default, spelled out here.
    unittest.main(module='__main__')