// blob: 0eee49a5d47454cdfaa6a449cfaade0e03923807 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.avro;
import java.io.File;
import java.io.IOException;
import java.util.*;
import org.apache.avro.file.DataFileReader;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.file.FileReader;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericData.Record;
import org.apache.avro.generic.IndexedRecord;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.util.Utf8;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
public class TestCircularReferences {
public TemporaryFolder temp = new TemporaryFolder();
public static class Reference extends LogicalType {
private static final String REFERENCE = "reference";
private static final String REF_FIELD_NAME = "ref-field-name";
private final String refFieldName;
public Reference(String refFieldName) {
this.refFieldName = refFieldName;
public Reference(Schema schema) {
this.refFieldName = schema.getProp(REF_FIELD_NAME);
public Schema addToSchema(Schema schema) {
schema.addProp(REF_FIELD_NAME, refFieldName);
return schema;
public String getName() {
public String getRefFieldName() {
return refFieldName;
public void validate(Schema schema) {
if (schema.getField(refFieldName) == null) {
throw new IllegalArgumentException("Invalid field name for reference field: " + refFieldName);
public static class Referenceable extends LogicalType {
private static final String REFERENCEABLE = "referenceable";
private static final String ID_FIELD_NAME = "id-field-name";
private final String idFieldName;
public Referenceable(String idFieldName) {
this.idFieldName = idFieldName;
public Referenceable(Schema schema) {
this.idFieldName = schema.getProp(ID_FIELD_NAME);
public Schema addToSchema(Schema schema) {
schema.addProp(ID_FIELD_NAME, idFieldName);
return schema;
public String getName() {
public String getIdFieldName() {
return idFieldName;
public void validate(Schema schema) {
Schema.Field idField = schema.getField(idFieldName);
if (idField == null || idField.schema().getType() != Schema.Type.LONG) {
throw new IllegalArgumentException("Invalid ID field: " + idFieldName + ": " + idField);
public static void addReferenceTypes() {
LogicalTypes.register(Referenceable.REFERENCEABLE, Referenceable::new);
LogicalTypes.register(Reference.REFERENCE, Reference::new);
public static class ReferenceManager {
private interface Callback {
void set(Object referenceable);
private final Map<Long, Object> references = new HashMap<>();
private final Map<Object, Long> ids = new IdentityHashMap<>();
private final Map<Long, List<Callback>> callbacksById = new HashMap<>();
private final ReferenceableTracker tracker = new ReferenceableTracker();
private final ReferenceHandler handler = new ReferenceHandler();
public ReferenceableTracker getTracker() {
return tracker;
public ReferenceHandler getHandler() {
return handler;
public class ReferenceableTracker extends Conversion<IndexedRecord> {
public Class<IndexedRecord> getConvertedType() {
return (Class) Record.class;
public String getLogicalTypeName() {
return Referenceable.REFERENCEABLE;
public IndexedRecord fromRecord(IndexedRecord value, Schema schema, LogicalType type) {
// read side
long id = getId(value, schema);
// keep track of this for later references
references.put(id, value);
// call any callbacks waiting to resolve this id
List<Callback> callbacks = callbacksById.get(id);
for (Callback callback : callbacks) {
return value;
public IndexedRecord toRecord(IndexedRecord value, Schema schema, LogicalType type) {
// write side
long id = getId(value, schema);
// keep track of this for later references
// references.put(id, value);
ids.put(value, id);
return value;
private long getId(IndexedRecord referenceable, Schema schema) {
Referenceable info = (Referenceable) schema.getLogicalType();
int idField = schema.getField(info.getIdFieldName()).pos();
return (Long) referenceable.get(idField);
public class ReferenceHandler extends Conversion<IndexedRecord> {
public Class<IndexedRecord> getConvertedType() {
return (Class) Record.class;
public String getLogicalTypeName() {
return Reference.REFERENCE;
public IndexedRecord fromRecord(final IndexedRecord record, Schema schema, LogicalType type) {
// read side: resolve the record or save a callback
final Schema.Field refField = schema.getField(((Reference) type).getRefFieldName());
Long id = (Long) record.get(refField.pos());
if (id != null) {
if (references.containsKey(id)) {
record.put(refField.pos(), references.get(id));
} else {
List<Callback> callbacks = callbacksById.computeIfAbsent(id, k -> new ArrayList<>());
// add a callback to resolve this reference when the id is available
callbacks.add(referenceable -> record.put(refField.pos(), referenceable));
return record;
public IndexedRecord toRecord(IndexedRecord record, Schema schema, LogicalType type) {
// write side: replace a referenced field with its id
Schema.Field refField = schema.getField(((Reference) type).getRefFieldName());
IndexedRecord referenced = (IndexedRecord) record.get(refField.pos());
if (referenced == null) {
return record;
// hijack the field to return the id instead of the ref
return new HijackingIndexedRecord(record, refField.pos(), ids.get(referenced));
private static class HijackingIndexedRecord implements IndexedRecord {
private final IndexedRecord wrapped;
private final int index;
private final Object data;
public HijackingIndexedRecord(IndexedRecord wrapped, int index, Object data) {
this.wrapped = wrapped;
this.index = index; = data;
public void put(int i, Object v) {
throw new RuntimeException("[BUG] This is a read-only class.");
public Object get(int i) {
if (i == index) {
return data;
return wrapped.get(i);
public Schema getSchema() {
return wrapped.getSchema();
public void test() throws IOException {
ReferenceManager manager = new ReferenceManager();
GenericData model = new GenericData();
Schema parentSchema = Schema.createRecord("Parent", null, null, false);
Schema parentRefSchema = Schema.createUnion(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.LONG),
Reference parentRef = new Reference("parent");
List<Schema.Field> childFields = new ArrayList<>();
childFields.add(new Schema.Field("c", Schema.create(Schema.Type.STRING)));
childFields.add(new Schema.Field("parent", parentRefSchema));
Schema childSchema = parentRef.addToSchema(Schema.createRecord("Child", null, null, false, childFields));
List<Schema.Field> parentFields = new ArrayList<>();
parentFields.add(new Schema.Field("id", Schema.create(Schema.Type.LONG)));
parentFields.add(new Schema.Field("p", Schema.create(Schema.Type.STRING)));
parentFields.add(new Schema.Field("child", childSchema));
Referenceable idRef = new Referenceable("id");
Schema schema = idRef.addToSchema(parentSchema);
System.out.println("Schema: " + schema.toString(true));
Record parent = new Record(schema);
parent.put("id", 1L);
parent.put("p", "parent data!");
Record child = new Record(childSchema);
child.put("c", "child data!");
child.put("parent", parent);
parent.put("child", child);
// serialization round trip
File data = write(model, schema, parent);
List<Record> records = read(model, schema, data);
Record actual = records.get(0);
// because the record is a recursive structure, equals won't work
Assert.assertEquals("Should correctly read back the parent id", 1L, actual.get("id"));
Assert.assertEquals("Should correctly read back the parent data", new Utf8("parent data!"), actual.get("p"));
Record actualChild = (Record) actual.get("child");
Assert.assertEquals("Should correctly read back the child data", new Utf8("child data!"), actualChild.get("c"));
Object childParent = actualChild.get("parent");
Assert.assertTrue("Should have a parent Record object", childParent instanceof Record);
Record childParentRecord = (Record) actualChild.get("parent");
Assert.assertEquals("Should have the right parent id", 1L, childParentRecord.get("id"));
Assert.assertEquals("Should have the right parent data", new Utf8("parent data!"), childParentRecord.get("p"));
private <D> List<D> read(GenericData model, Schema schema, File file) throws IOException {
DatumReader<D> reader = newReader(model, schema);
List<D> data = new ArrayList<>();
try (FileReader<D> fileReader = new DataFileReader<>(file, reader)) {
for (D datum : fileReader) {
return data;
private <D> DatumReader<D> newReader(GenericData model, Schema schema) {
return model.createDatumReader(schema);
private <D> File write(GenericData model, Schema schema, D... data) throws IOException {
File file = temp.newFile();
DatumWriter<D> writer = model.createDatumWriter(schema);
try (DataFileWriter<D> fileWriter = new DataFileWriter<>(writer)) {
fileWriter.create(schema, file);
for (D datum : data) {
return file;