| #!/usr/bin/env python |
| |
| ## |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Test the protocol parsing logic.""" |
| |
| from __future__ import absolute_import, division, print_function |
| |
| import json |
| import unittest |
| |
| import avro.protocol |
| import avro.schema |
| |
| try: |
| unicode |
| except NameError: |
| unicode = str |
| |
| try: |
| basestring # type: ignore |
| except NameError: |
| basestring = (bytes, unicode) |
| |
| |
| class TestProtocol(object): |
| """A proxy for a protocol string that provides useful test metadata.""" |
| |
| def __init__(self, data, name='', comment=''): |
| if not isinstance(data, basestring): |
| data = json.dumps(data) |
| self.data = data |
| self.name = name or data |
| self.comment = comment |
| |
| def parse(self): |
| return avro.protocol.parse(str(self)) |
| |
| def __str__(self): |
| return str(self.data) |
| |
| |
| class ValidTestProtocol(TestProtocol): |
| """A proxy for a valid protocol string that provides useful test metadata.""" |
| valid = True |
| |
| |
| class InvalidTestProtocol(TestProtocol): |
| """A proxy for an invalid protocol string that provides useful test metadata.""" |
| valid = False |
| |
| |
| HELLO_WORLD = ValidTestProtocol({ |
| "namespace": "com.acme", |
| "protocol": "HelloWorld", |
| "types": [ |
| {"name": "Greeting", "type": "record", "fields": [ |
| {"name": "message", "type": "string"}]}, |
| {"name": "Curse", "type": "error", "fields": [ |
| {"name": "message", "type": "string"}]} |
| ], |
| "messages": { |
| "hello": { |
| "request": [{"name": "greeting", "type": "Greeting"}], |
| "response": "Greeting", |
| "errors": ["Curse"] |
| } |
| } |
| }) |
| EXAMPLES = [HELLO_WORLD, ValidTestProtocol({ |
| "namespace": "org.apache.avro.test", |
| "protocol": "Simple", |
| "types": [ |
| {"name": "Kind", "type": "enum", "symbols": ["FOO", "BAR", "BAZ"]}, |
| {"name": "MD5", "type": "fixed", "size": 16}, |
| {"name": "TestRecord", "type": "record", "fields": [ |
| {"name": "name", "type": "string", "order": "ignore"}, |
| {"name": "kind", "type": "Kind", "order": "descending"}, |
| {"name": "hash", "type": "MD5"} |
| ]}, |
| {"name": "TestError", "type": "error", "fields": [{"name": "message", "type": "string"}]} |
| ], |
| "messages": { |
| "hello": { |
| "request": [{"name": "greeting", "type": "string"}], |
| "response": "string" |
| }, "echo": { |
| "request": [{"name": "record", "type": "TestRecord"}], |
| "response": "TestRecord" |
| }, "add": { |
| "request": [{"name": "arg1", "type": "int"}, {"name": "arg2", "type": "int"}], |
| "response": "int" |
| }, "echoBytes": { |
| "request": [{"name": "data", "type": "bytes"}], |
| "response": "bytes" |
| }, "error": { |
| "request": [], |
| "response": "null", |
| "errors": ["TestError"] |
| } |
| } |
| }), ValidTestProtocol({ |
| "namespace": "org.apache.avro.test.namespace", |
| "protocol": "TestNamespace", |
| "types": [ |
| {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16}, |
| {"name": "TestRecord", "type": "record", "fields": [ |
| {"name": "hash", "type": "org.apache.avro.test.util.MD5"} |
| ]}, |
| {"name": "TestError", "namespace": "org.apache.avro.test.errors", "type": "error", |
| "fields": [{"name": "message", "type": "string"}]} |
| ], |
| "messages": { |
| "echo": { |
| "request": [{"name": "record", "type": "TestRecord"}], |
| "response": "TestRecord" |
| }, "error": { |
| "request": [], |
| "response": "null", |
| "errors": ["org.apache.avro.test.errors.TestError"] |
| } |
| } |
| }), ValidTestProtocol({ |
| "namespace": "org.apache.avro.test.namespace", |
| "protocol": "TestImplicitNamespace", |
| "types": [ |
| {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16}, |
| {"name": "ReferencedRecord", "type": "record", |
| "fields": [{"name": "foo", "type": "string"}]}, |
| {"name": "TestRecord", "type": "record", |
| "fields": [{"name": "hash", "type": "org.apache.avro.test.util.MD5"}, |
| {"name": "unqualified", "type": "ReferencedRecord"}] |
| }, |
| {"name": "TestError", "type": "error", "fields": [{"name": "message", "type": "string"}]} |
| ], |
| "messages": { |
| "echo": { |
| "request": [{"name": "qualified", "type": "org.apache.avro.test.namespace.TestRecord"}], |
| "response": "TestRecord" |
| }, "error": { |
| "request": [], |
| "response": "null", |
| "errors": ["org.apache.avro.test.namespace.TestError"] |
| } |
| } |
| }), ValidTestProtocol({ |
| "namespace": "org.apache.avro.test.namespace", |
| "protocol": "TestNamespaceTwo", |
| "types": [ |
| {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16}, |
| {"name": "ReferencedRecord", "type": "record", |
| "namespace": "org.apache.avro.other.namespace", |
| "fields": [{"name": "foo", "type": "string"}]}, |
| {"name": "TestRecord", "type": "record", |
| "fields": [{"name": "hash", "type": "org.apache.avro.test.util.MD5"}, |
| {"name": "qualified", |
| "type": "org.apache.avro.other.namespace.ReferencedRecord"}] |
| }, |
| {"name": "TestError", |
| "type": "error", "fields": [{"name": "message", "type": "string"}]}], |
| "messages": { |
| "echo": { |
| "request": [{"name": "qualified", "type": "org.apache.avro.test.namespace.TestRecord"}], |
| "response": "TestRecord" |
| }, "error": { |
| "request": [], |
| "response": "null", |
| "errors": ["org.apache.avro.test.namespace.TestError"] |
| } |
| } |
| }), ValidTestProtocol({ |
| "namespace": "org.apache.avro.test.namespace", |
| "protocol": "TestValidRepeatedName", |
| "types": [ |
| {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16}, |
| {"name": "ReferencedRecord", "type": "record", |
| "namespace": "org.apache.avro.other.namespace", |
| "fields": [{"name": "foo", "type": "string"}]}, |
| {"name": "ReferencedRecord", "type": "record", |
| "fields": [{"name": "bar", "type": "double"}]}, |
| {"name": "TestError", |
| "type": "error", "fields": [{"name": "message", "type": "string"}]}], |
| "messages": { |
| "echo": { |
| "request": [{"name": "qualified", "type": "ReferencedRecord"}], |
| "response": "org.apache.avro.other.namespace.ReferencedRecord"}, |
| "error": { |
| "request": [], |
| "response": "null", |
| "errors": ["org.apache.avro.test.namespace.TestError"]} |
| } |
| }), InvalidTestProtocol({ |
| "namespace": "org.apache.avro.test.namespace", |
| "protocol": "TestInvalidRepeatedName", |
| "types": [ |
| {"name": "org.apache.avro.test.util.MD5", "type": "fixed", "size": 16}, |
| {"name": "ReferencedRecord", "type": "record", |
| "fields": [{"name": "foo", "type": "string"}]}, |
| {"name": "ReferencedRecord", "type": "record", |
| "fields": [{"name": "bar", "type": "double"}]}, |
| {"name": "TestError", |
| "type": "error", "fields": [{"name": "message", "type": "string"}]}], |
| "messages": { |
| "echo": { |
| "request": [{"name": "qualified", "type": "ReferencedRecord"}], |
| "response": "org.apache.avro.other.namespace.ReferencedRecord" |
| }, "error": { |
| "request": [], |
| "response": "null", |
| "errors": ["org.apache.avro.test.namespace.TestError"] |
| } |
| } |
| }), |
| ValidTestProtocol({ |
| "namespace": "org.apache.avro.test", |
| "protocol": "BulkData", |
| "types": [], |
| "messages": { |
| "read": { |
| "request": [], |
| "response": "bytes" |
| }, "write": { |
| "request": [{"name": "data", "type": "bytes"}], |
| "response": "null" |
| } |
| } |
| }), ValidTestProtocol({ |
| "protocol": "API", |
| "namespace": "xyz.api", |
| "types": [{ |
| "type": "enum", |
| "name": "Symbology", |
| "namespace": "xyz.api.product", |
| "symbols": ["OPRA", "CUSIP", "ISIN", "SEDOL"] |
| }, { |
| "type": "record", |
| "name": "Symbol", |
| "namespace": "xyz.api.product", |
| "fields": [{"name": "symbology", "type": "xyz.api.product.Symbology"}, |
| {"name": "symbol", "type": "string"}] |
| }, { |
| "type": "record", |
| "name": "MultiSymbol", |
| "namespace": "xyz.api.product", |
| "fields": [{"name": "symbols", |
| "type": {"type": "map", "values": "xyz.api.product.Symbol"}}] |
| }], |
| "messages": {} |
| }), |
| ] |
| |
| VALID_EXAMPLES = [e for e in EXAMPLES if getattr(e, "valid", False)] |
| |
| |
| class TestMisc(unittest.TestCase): |
| def test_inner_namespace_set(self): |
| print('') |
| print('TEST INNER NAMESPACE') |
| print('===================') |
| print('') |
| proto = HELLO_WORLD.parse() |
| self.assertEqual(proto.namespace, "com.acme") |
| self.assertEqual(proto.fullname, "com.acme.HelloWorld") |
| greeting_type = proto.types_dict['Greeting'] |
| self.assertEqual(greeting_type.namespace, 'com.acme') |
| |
| def test_inner_namespace_not_rendered(self): |
| proto = HELLO_WORLD.parse() |
| self.assertEqual('com.acme.Greeting', proto.types[0].fullname) |
| self.assertEqual('Greeting', proto.types[0].name) |
| # but there shouldn't be 'namespace' rendered to json on the inner type |
| self.assertFalse('namespace' in proto.to_json()['types'][0]) |
| |
| |
| class ProtocolParseTestCase(unittest.TestCase): |
| """Enable generating parse test cases over all the valid and invalid example protocols.""" |
| |
| def __init__(self, test_proto): |
| """Ignore the normal signature for unittest.TestCase because we are generating |
| many test cases from this one class. This is safe as long as the autoloader |
| ignores this class. The autoloader will ignore this class as long as it has |
| no methods starting with `test_`. |
| """ |
| super(ProtocolParseTestCase, self).__init__( |
| 'parse_valid' if test_proto.valid else 'parse_invalid') |
| self.test_proto = test_proto |
| |
| def parse_valid(self): |
| """Parsing a valid protocol should not error.""" |
| try: |
| self.test_proto.parse() |
| except avro.protocol.ProtocolParseException: |
| self.fail("Valid protocol failed to parse: {!s}".format(self.test_proto)) |
| |
| def parse_invalid(self): |
| """Parsing an invalid schema should error.""" |
| try: |
| self.test_proto.parse() |
| except (avro.protocol.ProtocolParseException, avro.schema.SchemaParseException): |
| pass |
| else: |
| self.fail("Invalid protocol should not have parsed: {!s}".format(self.test_proto)) |
| |
| |
| class ErrorSchemaTestCase(unittest.TestCase): |
| """Enable generating error schema test cases across all the valid test protocols.""" |
| |
| def __init__(self, test_proto): |
| """Ignore the normal signature for unittest.TestCase because we are generating |
| many test cases from this one class. This is safe as long as the autoloader |
| ignores this class. The autoloader will ignore this class as long as it has |
| no methods starting with `test_`. |
| """ |
| super(ErrorSchemaTestCase, self).__init__('check_error_schema_exists') |
| self.test_proto = test_proto |
| |
| def check_error_schema_exists(self): |
| """Protocol messages should always have at least a string error schema.""" |
| p = self.test_proto.parse() |
| for k, m in p.messages.items(): |
| self.assertIsNotNone(m.errors, "Message {} did not have the expected implicit " |
| "string error schema.".format(k)) |
| |
| |
| class RoundTripParseTestCase(unittest.TestCase): |
| """Enable generating round-trip parse test cases over all the valid test protocols.""" |
| |
| def __init__(self, test_proto): |
| """Ignore the normal signature for unittest.TestCase because we are generating |
| many test cases from this one class. This is safe as long as the autoloader |
| ignores this class. The autoloader will ignore this class as long as it has |
| no methods starting with `test_`. |
| """ |
| super(RoundTripParseTestCase, self).__init__('parse_round_trip') |
| self.test_proto = test_proto |
| |
| def parse_round_trip(self): |
| """The string of a Schema should be parseable to the same Schema.""" |
| parsed = self.test_proto.parse() |
| round_trip = avro.protocol.parse(str(parsed)) |
| self.assertEqual(parsed, round_trip) |
| |
| |
| def load_tests(loader, default_tests, pattern): |
| """Generate test cases across many test schema.""" |
| suite = unittest.TestSuite() |
| suite.addTests(loader.loadTestsFromTestCase(TestMisc)) |
| suite.addTests(ProtocolParseTestCase(ex) for ex in EXAMPLES) |
| suite.addTests(RoundTripParseTestCase(ex) for ex in VALID_EXAMPLES) |
| return suite |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |