blob: 0b2d717da2885acbf1cbc86df73d32b4a79f1a9d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.protobuf;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.Parser;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderProvider;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
/**
* A {@link Coder} using Google Protocol Buffers binary format. {@link ProtoCoder} supports both
* Protocol Buffers syntax versions 2 and 3.
*
* <p>To learn more about Protocol Buffers, visit: <a
* href="https://developers.google.com/protocol-buffers">https://developers.google.com/protocol-buffers</a>
*
* <p>{@link ProtoCoder} is registered in the global {@link CoderRegistry} as the default {@link
* Coder} for any {@link Message} object. Custom message extensions are also supported, but these
* extensions must be registered for a particular {@link ProtoCoder} instance and that instance must
* be registered on the {@link PCollection} that needs the extensions:
*
* <pre>{@code
* import MyProtoFile;
* import MyProtoFile.MyMessage;
*
* Coder<MyMessage> coder = ProtoCoder.of(MyMessage.class).withExtensionsFrom(MyProtoFile.class);
* PCollection<MyMessage> records = input.apply(...).setCoder(coder);
* }</pre>
*
* <h3>Versioning</h3>
*
* <p>{@link ProtoCoder} supports both versions 2 and 3 of the Protocol Buffers syntax. However, the
* Java runtime version of the <code>google.com.protobuf</code> library must match exactly the
* version of <code>protoc</code> that was used to produce the JAR files containing the compiled
* <code>.proto</code> messages.
*
* <p>For more information, see the <a
* href="https://developers.google.com/protocol-buffers/docs/proto3#using-proto2-message-types">Protocol
* Buffers documentation</a>.
*
* <h3>{@link ProtoCoder} and Determinism</h3>
*
* <p>In general, Protocol Buffers messages can be encoded deterministically within a single
* pipeline as long as:
*
* <ul>
* <li>The encoded messages (and any transitively linked messages) do not use <code>map</code>
* fields.
* <li>Every Java VM that encodes or decodes the messages use the same runtime version of the
* Protocol Buffers library and the same compiled <code>.proto</code> file JAR.
* </ul>
*
* <h3>{@link ProtoCoder} and Encoding Stability</h3>
*
* <p>When changing Protocol Buffers messages, follow the rules in the Protocol Buffers language
* guides for <a href="https://developers.google.com/protocol-buffers/docs/proto#updating">{@code
* proto2}</a> and <a
* href="https://developers.google.com/protocol-buffers/docs/proto3#updating">{@code proto3}</a>
* syntaxes, depending on your message type. Following these guidelines will ensure that the old
* encoded data can be read by new versions of the code.
*
* <p>Generally, any change to the message type, registered extensions, runtime library, or compiled
* proto JARs may change the encoding. Thus even if both the original and updated messages can be
* encoded deterministically within a single job, these deterministic encodings may not be the same
* across jobs.
*
* @param <T> the Protocol Buffers {@link Message} handled by this {@link Coder}.
*/
public class ProtoCoder<T extends Message> extends CustomCoder<T> {
public static final long serialVersionUID = -5043999806040629525L;
/** Returns a {@link ProtoCoder} for the given Protocol Buffers {@link Message}. */
public static <T extends Message> ProtoCoder<T> of(Class<T> protoMessageClass) {
return new ProtoCoder<>(protoMessageClass, ImmutableSet.of());
}
/**
* Returns a {@link ProtoCoder} for the Protocol Buffers {@link Message} indicated by the given
* {@link TypeDescriptor}.
*/
public static <T extends Message> ProtoCoder<T> of(TypeDescriptor<T> protoMessageType) {
@SuppressWarnings("unchecked")
Class<T> protoMessageClass = (Class<T>) protoMessageType.getRawType();
return of(protoMessageClass);
}
/**
* Validate that all extensionHosts are able to be registered.
*
* @param moreExtensionHosts
*/
void validateExtensions(Iterable<Class<?>> moreExtensionHosts) {
for (Class<?> extensionHost : moreExtensionHosts) {
// Attempt to access the required method, to make sure it's present.
try {
Method registerAllExtensions =
extensionHost.getDeclaredMethod("registerAllExtensions", ExtensionRegistry.class);
checkArgument(
Modifier.isStatic(registerAllExtensions.getModifiers()),
"Method registerAllExtensions() must be static");
} catch (NoSuchMethodException | SecurityException e) {
throw new IllegalArgumentException(
String.format("Unable to register extensions for %s", extensionHost.getCanonicalName()),
e);
}
}
}
/**
* Returns a {@link ProtoCoder} like this one, but with the extensions from the given classes
* registered.
*
* <p>Each of the extension host classes must be an class automatically generated by the Protocol
* Buffers compiler, {@code protoc}, that contains messages.
*
* <p>Does not modify this object.
*/
public ProtoCoder<T> withExtensionsFrom(Iterable<Class<?>> moreExtensionHosts) {
validateExtensions(moreExtensionHosts);
return new ProtoCoder<>(
protoMessageClass,
new ImmutableSet.Builder<Class<?>>()
.addAll(extensionHostClasses)
.addAll(moreExtensionHosts)
.build());
}
/**
* See {@link #withExtensionsFrom(Iterable)}.
*
* <p>Does not modify this object.
*/
public ProtoCoder<T> withExtensionsFrom(Class<?>... moreExtensionHosts) {
return withExtensionsFrom(Arrays.asList(moreExtensionHosts));
}
@Override
public void encode(T value, OutputStream outStream) throws IOException {
encode(value, outStream, Context.NESTED);
}
@Override
public void encode(T value, OutputStream outStream, Context context) throws IOException {
if (value == null) {
throw new CoderException("cannot encode a null " + protoMessageClass.getSimpleName());
}
if (context.isWholeStream) {
value.writeTo(outStream);
} else {
value.writeDelimitedTo(outStream);
}
}
@Override
public T decode(InputStream inStream) throws IOException {
return decode(inStream, Context.NESTED);
}
@Override
public T decode(InputStream inStream, Context context) throws IOException {
if (context.isWholeStream) {
return getParser().parseFrom(inStream, getExtensionRegistry());
} else {
return getParser().parseDelimitedFrom(inStream, getExtensionRegistry());
}
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
ProtoCoder<?> otherCoder = (ProtoCoder<?>) other;
return protoMessageClass.equals(otherCoder.protoMessageClass)
&& Sets.newHashSet(extensionHostClasses)
.equals(Sets.newHashSet(otherCoder.extensionHostClasses));
}
@Override
public int hashCode() {
return Objects.hash(protoMessageClass, extensionHostClasses);
}
@Override
public void verifyDeterministic() throws NonDeterministicException {
ProtobufUtil.verifyDeterministic(this);
}
/** Returns the Protocol Buffers {@link Message} type this {@link ProtoCoder} supports. */
public Class<T> getMessageType() {
return protoMessageClass;
}
public Set<Class<?>> getExtensionHosts() {
return extensionHostClasses;
}
/**
* Returns the {@link ExtensionRegistry} listing all known Protocol Buffers extension messages to
* {@code T} registered with this {@link ProtoCoder}.
*/
public ExtensionRegistry getExtensionRegistry() {
if (memoizedExtensionRegistry == null) {
ExtensionRegistry registry = ExtensionRegistry.newInstance();
for (Class<?> extensionHost : extensionHostClasses) {
try {
extensionHost
.getDeclaredMethod("registerAllExtensions", ExtensionRegistry.class)
.invoke(null, registry);
} catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
throw new IllegalStateException(e);
}
}
memoizedExtensionRegistry = registry.getUnmodifiable();
}
return memoizedExtensionRegistry;
}
////////////////////////////////////////////////////////////////////////////////////
// Private implementation details below.
/** The {@link Message} type to be coded. */
final Class<T> protoMessageClass;
/**
* All extension host classes included in this {@link ProtoCoder}. The extensions from these
* classes will be included in the {@link ExtensionRegistry} used during encoding and decoding.
*/
final Set<Class<?>> extensionHostClasses;
// Constants used to serialize and deserialize
private static final String PROTO_MESSAGE_CLASS = "proto_message_class";
private static final String PROTO_EXTENSION_HOSTS = "proto_extension_hosts";
// Transient fields that are lazy initialized and then memoized.
private transient ExtensionRegistry memoizedExtensionRegistry;
transient Parser<T> memoizedParser;
/** Private constructor. */
protected ProtoCoder(Class<T> protoMessageClass, Set<Class<?>> extensionHostClasses) {
this.protoMessageClass = protoMessageClass;
this.extensionHostClasses = extensionHostClasses;
}
/** Get the memoized {@link Parser}, possibly initializing it lazily. */
protected Parser<T> getParser() {
if (memoizedParser == null) {
try {
if (DynamicMessage.class.equals(protoMessageClass)) {
throw new IllegalArgumentException(
"DynamicMessage is not supported by the ProtoCoder, use the DynamicProtoCoder.");
} else {
@SuppressWarnings("unchecked")
T protoMessageInstance =
(T) protoMessageClass.getMethod("getDefaultInstance").invoke(null);
@SuppressWarnings("unchecked")
Parser<T> tParser = (Parser<T>) protoMessageInstance.getParserForType();
memoizedParser = tParser;
}
} catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
throw new IllegalArgumentException(e);
}
}
return memoizedParser;
}
/**
* Returns a {@link CoderProvider} which uses the {@link ProtoCoder} for {@link Message proto
* messages}.
*
* <p>This method is invoked reflectively from {@link DefaultCoder}.
*/
public static CoderProvider getCoderProvider() {
return new ProtoCoderProvider();
}
static final TypeDescriptor<Message> MESSAGE_TYPE = new TypeDescriptor<Message>() {};
/** A {@link CoderProvider} for {@link Message proto messages}. */
private static class ProtoCoderProvider extends CoderProvider {
@Override
public <T> Coder<T> coderFor(
TypeDescriptor<T> typeDescriptor, List<? extends Coder<?>> componentCoders)
throws CannotProvideCoderException {
if (!typeDescriptor.isSubtypeOf(MESSAGE_TYPE)) {
throw new CannotProvideCoderException(
String.format(
"Cannot provide %s because %s is not a subclass of %s",
ProtoCoder.class.getSimpleName(), typeDescriptor, Message.class.getName()));
}
@SuppressWarnings("unchecked")
TypeDescriptor<? extends Message> messageType =
(TypeDescriptor<? extends Message>) typeDescriptor;
try {
@SuppressWarnings("unchecked")
Coder<T> coder = (Coder<T>) ProtoCoder.of(messageType);
return coder;
} catch (IllegalArgumentException e) {
throw new CannotProvideCoderException(e);
}
}
}
}