blob: f0bb24f4ceb2a01b3a5f5df3f67df5e0d620f402 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Coder for AvroRecord serialization/deserialization."""
from __future__ import absolute_import
import json
from io import BytesIO
from fastavro import parse_schema
from fastavro import schemaless_reader
from fastavro import schemaless_writer
from apache_beam.coders.coder_impl import SimpleCoderImpl
from apache_beam.coders.coders import Coder
from apache_beam.coders.coders import FastCoder
AVRO_CODER_URN = "beam:coder:avro:v1"
__all__ = ['AvroCoder', 'AvroRecord']
class AvroCoder(FastCoder):
"""A coder used for AvroRecord values."""
def __init__(self, schema):
self.schema = schema
def _create_impl(self):
return AvroCoderImpl(self.schema)
def is_deterministic(self):
# TODO: need to confirm if it's deterministic
return False
def __eq__(self, other):
return (type(self) == type(other)
and self.schema == other.schema)
def __hash__(self):
return hash(self.schema)
def to_type_hint(self):
return AvroRecord
def to_runner_api_parameter(self, context):
return AVRO_CODER_URN, self.schema, ()
@Coder.register_urn(AVRO_CODER_URN, bytes)
def from_runner_api_parameter(payload, unused_components, unused_context):
return AvroCoder(payload)
class AvroCoderImpl(SimpleCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
def __init__(self, schema):
self.parsed_schema = parse_schema(json.loads(schema))
def encode(self, value):
assert issubclass(type(value), AvroRecord)
with BytesIO() as buf:
schemaless_writer(buf, self.parsed_schema, value.record)
return buf.getvalue()
def decode(self, encoded):
with BytesIO(encoded) as buf:
return AvroRecord(schemaless_reader(buf, self.parsed_schema))
class AvroRecord(object):
"""Simple wrapper class for dictionary records."""
def __init__(self, value):
self.record = value
def __eq__(self, other):
return (
issubclass(type(other), AvroRecord) and
self.record == other.record
)
def __hash__(self):
return hash(self.record)