| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.extensions.protobuf; |
| |
| import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument; |
| |
| import com.google.protobuf.Descriptors.Descriptor; |
| import com.google.protobuf.Descriptors.FieldDescriptor; |
| import com.google.protobuf.Descriptors.FileDescriptor.Syntax; |
| import com.google.protobuf.Descriptors.GenericDescriptor; |
| import com.google.protobuf.ExtensionRegistry; |
| import com.google.protobuf.ExtensionRegistry.ExtensionInfo; |
| import com.google.protobuf.Message; |
| import java.lang.reflect.InvocationTargetException; |
| import java.util.HashSet; |
| import java.util.Set; |
| import org.apache.beam.sdk.coders.Coder.NonDeterministicException; |
| |
| /** |
| * Utility functions for reflecting and analyzing Protocol Buffers classes. |
| * |
| * <p>Used by {@link ProtoCoder}, but in a separate file for testing and isolation. |
| */ |
| class ProtobufUtil { |
| /** |
| * Returns the {@link Descriptor} for the given Protocol Buffers {@link Message}. |
| * |
| * @throws IllegalArgumentException if there is an error in Java reflection. |
| */ |
| static Descriptor getDescriptorForClass(Class<? extends Message> clazz) { |
| try { |
| return (Descriptor) clazz.getMethod("getDescriptor").invoke(null); |
| } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { |
| throw new IllegalArgumentException(e); |
| } |
| } |
| |
| /** |
| * Returns the {@link Descriptor} for the given Protocol Buffers {@link Message} as well as every |
| * class it can include transitively. |
| * |
| * @throws IllegalArgumentException if there is an error in Java reflection. |
| */ |
| static Set<Descriptor> getRecursiveDescriptorsForClass( |
| Class<? extends Message> clazz, ExtensionRegistry registry) { |
| Descriptor root = getDescriptorForClass(clazz); |
| Set<Descriptor> descriptors = new HashSet<>(); |
| recursivelyAddDescriptors(root, descriptors, registry); |
| return descriptors; |
| } |
| |
| /** |
| * Recursively walks the given {@link Message} class and verifies that every field or message |
| * linked in uses the Protocol Buffers proto2 syntax. |
| */ |
| static void checkProto2Syntax(Class<? extends Message> clazz, ExtensionRegistry registry) { |
| for (GenericDescriptor d : getRecursiveDescriptorsForClass(clazz, registry)) { |
| Syntax s = d.getFile().getSyntax(); |
| checkArgument( |
| s == Syntax.PROTO2, |
| "Message %s or one of its dependencies does not use proto2 syntax: %s in file %s", |
| clazz.getName(), |
| d.getFullName(), |
| d.getFile().getName()); |
| } |
| } |
| |
| /** |
| * Recursively checks whether the specified class uses any Protocol Buffers fields that cannot be |
| * deterministically encoded. |
| * |
| * @throws NonDeterministicException if the object cannot be encoded deterministically. |
| */ |
| static void verifyDeterministic(ProtoCoder<?> coder) throws NonDeterministicException { |
| Class<? extends Message> message = coder.getMessageType(); |
| ExtensionRegistry registry = coder.getExtensionRegistry(); |
| Set<Descriptor> descriptors = getRecursiveDescriptorsForClass(message, registry); |
| for (Descriptor d : descriptors) { |
| for (FieldDescriptor fd : d.getFields()) { |
| // If there is a transitively reachable Protocol Buffers map field, then this object cannot |
| // be encoded deterministically. |
| if (fd.isMapField()) { |
| String reason = |
| String.format( |
| "Protocol Buffers message %s transitively includes Map field %s (from file %s)." |
| + " Maps cannot be deterministically encoded.", |
| message.getName(), fd.getFullName(), fd.getFile().getFullName()); |
| throw new NonDeterministicException(coder, reason); |
| } |
| } |
| } |
| } |
| |
| //////////////////////////////////////////////////////////////////////////////////////////////// |
| // Disable construction of utility class |
| private ProtobufUtil() {} |
| |
| private static void recursivelyAddDescriptors( |
| Descriptor message, Set<Descriptor> descriptors, ExtensionRegistry registry) { |
| if (descriptors.contains(message)) { |
| return; |
| } |
| descriptors.add(message); |
| |
| for (FieldDescriptor f : message.getFields()) { |
| recursivelyAddDescriptors(f, descriptors, registry); |
| } |
| for (FieldDescriptor f : message.getExtensions()) { |
| recursivelyAddDescriptors(f, descriptors, registry); |
| } |
| for (ExtensionInfo info : |
| registry.getAllImmutableExtensionsByExtendedType(message.getFullName())) { |
| recursivelyAddDescriptors(info.descriptor, descriptors, registry); |
| } |
| for (ExtensionInfo info : |
| registry.getAllMutableExtensionsByExtendedType(message.getFullName())) { |
| recursivelyAddDescriptors(info.descriptor, descriptors, registry); |
| } |
| } |
| |
| private static void recursivelyAddDescriptors( |
| FieldDescriptor field, Set<Descriptor> descriptors, ExtensionRegistry registry) { |
| switch (field.getType()) { |
| case BOOL: |
| case BYTES: |
| case DOUBLE: |
| case ENUM: |
| case FIXED32: |
| case FIXED64: |
| case FLOAT: |
| case INT32: |
| case INT64: |
| case SFIXED32: |
| case SFIXED64: |
| case SINT32: |
| case SINT64: |
| case STRING: |
| case UINT32: |
| case UINT64: |
| // Primitive types do not transitively access anything else. |
| break; |
| |
| case GROUP: |
| case MESSAGE: |
| // Recursively adds all the fields from this nested Message. |
| recursivelyAddDescriptors(field.getMessageType(), descriptors, registry); |
| break; |
| |
| default: |
| throw new UnsupportedOperationException( |
| "Unexpected Protocol Buffers field type: " + field.getType()); |
| } |
| } |
| } |