blob: 8428e57831f68829fbcc0bf228f01a44fe03a355 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.protobuf;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor.Syntax;
import com.google.protobuf.Descriptors.GenericDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.ExtensionRegistry.ExtensionInfo;
import com.google.protobuf.Message;
import java.lang.reflect.InvocationTargetException;
import java.util.HashSet;
import java.util.Set;
import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
/**
* Utility functions for reflecting and analyzing Protocol Buffers classes.
*
* <p>Used by {@link ProtoCoder}, but in a separate file for testing and isolation.
*/
class ProtobufUtil {

  /** Not instantiable: this class only hosts static helpers. */
  private ProtobufUtil() {}

  /**
   * Looks up the {@link Descriptor} of a generated Protocol Buffers {@link Message} class by
   * reflectively calling its static {@code getDescriptor()} method.
   *
   * @param clazz the message class to inspect.
   * @return the value returned by {@code clazz.getDescriptor()}.
   * @throws IllegalArgumentException if the method is missing, inaccessible, or itself throws; the
   *     reflection failure is attached as the cause.
   */
  static Descriptor getDescriptorForClass(Class<? extends Message> clazz) {
    try {
      // The method is invoked with a null receiver, i.e. as a static method.
      Object descriptor = clazz.getMethod("getDescriptor").invoke(null);
      return (Descriptor) descriptor;
    } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
      throw new IllegalArgumentException(e);
    }
  }

  /**
   * Collects the {@link Descriptor} of the given {@link Message} class together with the
   * descriptors of every message type reachable from it through its fields, its declared
   * extensions, and the extensions registered for it in {@code registry}.
   *
   * @param clazz the root message class.
   * @param registry the registry consulted for extensions of each visited message.
   * @return a new set holding the root descriptor and all transitively reachable descriptors.
   * @throws IllegalArgumentException if there is an error in Java reflection.
   */
  static Set<Descriptor> getRecursiveDescriptorsForClass(
      Class<? extends Message> clazz, ExtensionRegistry registry) {
    Set<Descriptor> visited = new HashSet<>();
    recursivelyAddDescriptors(getDescriptorForClass(clazz), visited, registry);
    return visited;
  }

  /**
   * Verifies that the given {@link Message} class, and every message type transitively reachable
   * from it, is defined in a file that uses the Protocol Buffers proto2 syntax.
   *
   * @param clazz the root message class.
   * @param registry the registry consulted for extensions while walking the message graph.
   * @throws IllegalArgumentException if any reachable descriptor comes from a file whose syntax is
   *     not proto2, or if there is an error in Java reflection.
   */
  static void checkProto2Syntax(Class<? extends Message> clazz, ExtensionRegistry registry) {
    Set<Descriptor> reachable = getRecursiveDescriptorsForClass(clazz, registry);
    for (GenericDescriptor descriptor : reachable) {
      boolean usesProto2 = descriptor.getFile().getSyntax() == Syntax.PROTO2;
      checkArgument(
          usesProto2,
          "Message %s or one of its dependencies does not use proto2 syntax: %s in file %s",
          clazz.getName(),
          descriptor.getFullName(),
          descriptor.getFile().getName());
    }
  }

  /**
   * Checks that the message type handled by {@code coder} does not, directly or transitively,
   * contain a Protocol Buffers map field.
   *
   * @param coder the coder whose message type and extension registry are inspected.
   * @throws NonDeterministicException if a map field is found among the reachable messages.
   */
  static void verifyDeterministic(ProtoCoder<?> coder) throws NonDeterministicException {
    Class<? extends Message> messageType = coder.getMessageType();
    ExtensionRegistry extensions = coder.getExtensionRegistry();
    for (Descriptor reachable : getRecursiveDescriptorsForClass(messageType, extensions)) {
      for (FieldDescriptor field : reachable.getFields()) {
        if (!field.isMapField()) {
          continue;
        }
        // A single reachable map field is enough to rule out deterministic encoding.
        throw new NonDeterministicException(
            coder,
            String.format(
                "Protocol Buffers message %s transitively includes Map field %s (from file %s)."
                    + " Maps cannot be deterministically encoded.",
                messageType.getName(), field.getFullName(), field.getFile().getFullName()));
      }
    }
  }

  ////////////////////////////////////////////////////////////////////////////////////////////////
  // Traversal helpers

  /**
   * Adds {@code message} to {@code descriptors} and then follows its fields, its declared
   * extensions, and the immutable and mutable extensions registered for its full name. A message
   * that is already in the set is not visited again, which also terminates cyclic references.
   */
  private static void recursivelyAddDescriptors(
      Descriptor message, Set<Descriptor> descriptors, ExtensionRegistry registry) {
    // Set#add reports false when the element was already present, i.e. already visited.
    if (!descriptors.add(message)) {
      return;
    }

    for (FieldDescriptor declaredField : message.getFields()) {
      recursivelyAddDescriptors(declaredField, descriptors, registry);
    }
    for (FieldDescriptor declaredExtension : message.getExtensions()) {
      recursivelyAddDescriptors(declaredExtension, descriptors, registry);
    }

    String extendedType = message.getFullName();
    for (ExtensionInfo immutableExtension :
        registry.getAllImmutableExtensionsByExtendedType(extendedType)) {
      recursivelyAddDescriptors(immutableExtension.descriptor, descriptors, registry);
    }
    for (ExtensionInfo mutableExtension :
        registry.getAllMutableExtensionsByExtendedType(extendedType)) {
      recursivelyAddDescriptors(mutableExtension.descriptor, descriptors, registry);
    }
  }

  /**
   * Continues the traversal through a single field: message and group fields recurse into their
   * message type, all other known field types contribute nothing.
   *
   * @throws UnsupportedOperationException if the field type is not one of the handled types.
   */
  private static void recursivelyAddDescriptors(
      FieldDescriptor field, Set<Descriptor> descriptors, ExtensionRegistry registry) {
    switch (field.getType()) {
      case MESSAGE:
      case GROUP:
        // Nested message: walk into its descriptor.
        recursivelyAddDescriptors(field.getMessageType(), descriptors, registry);
        return;

      case BOOL:
      case BYTES:
      case STRING:
      case ENUM:
      case DOUBLE:
      case FLOAT:
      case INT32:
      case INT64:
      case UINT32:
      case UINT64:
      case SINT32:
      case SINT64:
      case FIXED32:
      case FIXED64:
      case SFIXED32:
      case SFIXED64:
        // Primitive types do not reference any other descriptor.
        return;

      default:
        throw new UnsupportedOperationException(
            "Unexpected Protocol Buffers field type: " + field.getType());
    }
  }
}