blob: dc002bc6726051d7e5f152333d043d7e63f08c66 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.statefun.flink.core.translation;
import com.google.protobuf.Any;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.statefun.flink.core.StatefulFunctionsConfig;
import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverse;
import org.apache.flink.statefun.flink.core.StatefulFunctionsUniverses;
import org.apache.flink.statefun.flink.core.message.Message;
import org.apache.flink.statefun.flink.core.message.MessageFactory;
import org.apache.flink.statefun.flink.core.message.MessageFactoryKey;
import org.apache.flink.statefun.flink.core.message.MessageFactoryType;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.io.IngressIdentifier;
import org.apache.flink.statefun.sdk.io.Router;
import org.apache.flink.statefun.sdk.io.Router.Downstream;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public final class IngressRouterOperator<T> extends AbstractStreamOperator<Message>
implements OneInputStreamOperator<T, Message> {
private static final Logger LOG = LoggerFactory.getLogger(IngressRouterOperator.class);
private static final long serialVersionUID = 1;
private final StatefulFunctionsConfig configuration;
private final IngressIdentifier<T> id;
private transient List<Router<T>> routers;
private transient DownstreamCollector<T> downstream;
IngressRouterOperator(StatefulFunctionsConfig configuration, IngressIdentifier<T> id) {
this.configuration = configuration;
this.id = Objects.requireNonNull(id);
this.chainingStrategy = ChainingStrategy.ALWAYS;
}
@Override
public void open() throws Exception {
super.open();
StatefulFunctionsUniverse universe =
StatefulFunctionsUniverses.get(
Thread.currentThread().getContextClassLoader(), configuration);
LOG.info("Using message factory key " + universe.messageFactoryKey());
this.downstream = new DownstreamCollector<>(universe.messageFactoryKey(), output);
this.routers = loadRoutersAttachedToIngress(id, universe.routers());
}
@Override
public void processElement(StreamRecord<T> element) {
final T value = element.getValue();
for (Router<T> router : routers) {
router.route(value, downstream);
}
}
@SuppressWarnings("unchecked")
private static <T> List<Router<T>> loadRoutersAttachedToIngress(
IngressIdentifier<T> id, Map<IngressIdentifier<?>, List<Router<?>>> definedRouters) {
List<Router<?>> routerList = definedRouters.get(id);
Preconditions.checkState(routerList != null, "unable to find a router for ingress " + id);
return (List<Router<T>>) (List<?>) routerList;
}
@VisibleForTesting
static final class DownstreamCollector<T> implements Downstream<T> {
private final MessageFactory factory;
private final boolean multiLanguagePayloads;
private final StreamRecord<Message> reuse = new StreamRecord<>(null);
private final Output<StreamRecord<Message>> output;
DownstreamCollector(MessageFactoryKey messageFactoryKey, Output<StreamRecord<Message>> output) {
this.factory = MessageFactory.forKey(messageFactoryKey);
this.output = Objects.requireNonNull(output);
this.multiLanguagePayloads =
messageFactoryKey.getType() == MessageFactoryType.WITH_PROTOBUF_PAYLOADS_MULTILANG;
}
@Override
public void forward(Address to, Object message) {
if (to == null) {
throw new NullPointerException("Unable to send a message downstream without an address.");
}
if (message == null) {
throw new NullPointerException("message is mandatory parameter and can not be NULL.");
}
// create an envelope out of the source, destination addresses, and the payload.
// This is the first instance where a user supplied payload that comes off a source
// is wrapped into a (statefun) Message envelope.
if (multiLanguagePayloads) {
// in the multi language case, the payloads are always of type com.google.protobuf.Any
// therefore we first pack the whatever protobuf Message supplied from the user
// into a Protobuf Any.
message = wrapAsProtobufAny(message);
}
Message envelope = factory.from(null, to, message);
output.collect(reuse.replace(envelope));
}
private Object wrapAsProtobufAny(Object message) {
if (message instanceof Any) {
return message;
}
return Any.pack((com.google.protobuf.Message) message);
}
}
}