| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.nemo.compiler.frontend.spark.core; |
| |
| import org.apache.nemo.client.JobLauncher; |
| import org.apache.nemo.common.dag.DAG; |
| import org.apache.nemo.common.dag.DAGBuilder; |
| import org.apache.nemo.common.ir.IRDAG; |
| import org.apache.nemo.common.ir.edge.IREdge; |
| import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty; |
| import org.apache.nemo.common.ir.edge.executionproperty.DecoderProperty; |
| import org.apache.nemo.common.ir.edge.executionproperty.EncoderProperty; |
| import org.apache.nemo.common.ir.edge.executionproperty.KeyExtractorProperty; |
| import org.apache.nemo.common.ir.vertex.IRVertex; |
| import org.apache.nemo.common.ir.vertex.LoopVertex; |
| import org.apache.nemo.common.ir.vertex.OperatorVertex; |
| import org.apache.nemo.compiler.frontend.spark.SparkBroadcastVariables; |
| import org.apache.nemo.compiler.frontend.spark.SparkKeyExtractor; |
| import org.apache.nemo.compiler.frontend.spark.coder.SparkDecoderFactory; |
| import org.apache.nemo.compiler.frontend.spark.coder.SparkEncoderFactory; |
| import org.apache.nemo.compiler.frontend.spark.transform.CollectTransform; |
| import org.apache.nemo.compiler.frontend.spark.transform.GroupByKeyTransform; |
| import org.apache.nemo.compiler.frontend.spark.transform.ReduceByKeyTransform; |
| import org.apache.spark.api.java.function.FlatMapFunction; |
| import org.apache.spark.api.java.function.Function; |
| import org.apache.spark.api.java.function.Function2; |
| import org.apache.spark.api.java.function.PairFunction; |
| import org.apache.spark.serializer.JavaSerializer; |
| import org.apache.spark.serializer.KryoSerializer; |
| import org.apache.spark.serializer.Serializer; |
| import org.apache.spark.serializer.SerializerInstance; |
| import scala.Function1; |
| import scala.Tuple2; |
| import scala.collection.JavaConverters; |
| import scala.collection.TraversableOnce; |
| import scala.reflect.ClassTag; |
| import scala.reflect.ClassTag$; |
| |
| import java.nio.ByteBuffer; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Stack; |
| |
| /** |
| * Utility class for RDDs. |
| */ |
| public final class SparkFrontendUtils { |
| private static final KeyExtractorProperty SPARK_KEY_EXTRACTOR_PROP = KeyExtractorProperty.of(new SparkKeyExtractor()); |
| |
| /** |
| * Private constructor. |
| */ |
| private SparkFrontendUtils() { |
| } |
| |
| /** |
| * Derive Spark serializer from a spark context. |
| * |
| * @param sparkContext spark context to derive the serializer from. |
| * @return the serializer. |
| */ |
| public static Serializer deriveSerializerFrom(final org.apache.spark.SparkContext sparkContext) { |
| if (sparkContext.conf().get("spark.serializer", "") |
| .equals("org.apache.spark.serializer.KryoSerializer")) { |
| return new KryoSerializer(sparkContext.conf()); |
| } else { |
| return new JavaSerializer(sparkContext.conf()); |
| } |
| } |
| |
| /** |
| * Collect data by running the DAG. |
| * |
| * @param dag the DAG to execute. |
| * @param loopVertexStack loop vertex stack. |
| * @param lastVertex last vertex added to the dag. |
| * @param serializer serializer for the edges. |
| * @param <T> type of the return data. |
| * @return the data collected. |
| */ |
| public static <T> List<T> collect(final DAG<IRVertex, IREdge> dag, |
| final Stack<LoopVertex> loopVertexStack, |
| final IRVertex lastVertex, |
| final Serializer serializer) { |
| final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>(dag); |
| |
| final IRVertex collectVertex = new OperatorVertex(new CollectTransform<>()); |
| builder.addVertex(collectVertex, loopVertexStack); |
| |
| final IREdge newEdge = new IREdge(getEdgeCommunicationPattern(lastVertex, collectVertex), |
| lastVertex, collectVertex); |
| newEdge.setProperty(EncoderProperty.of(new SparkEncoderFactory(serializer))); |
| newEdge.setProperty(DecoderProperty.of(new SparkDecoderFactory(serializer))); |
| newEdge.setProperty(SPARK_KEY_EXTRACTOR_PROP); |
| builder.connectVertices(newEdge); |
| |
| // launch DAG |
| JobLauncher.launchDAG(new IRDAG(builder.build()), SparkBroadcastVariables.getAll(), ""); |
| |
| return JobLauncher.getCollectedData(); |
| } |
| |
| /** |
| * Retrieve communication pattern of the edge. |
| * |
| * @param src source vertex. |
| * @param dst destination vertex. |
| * @return the communication pattern. |
| */ |
| public static CommunicationPatternProperty.Value getEdgeCommunicationPattern(final IRVertex src, |
| final IRVertex dst) { |
| if (dst instanceof OperatorVertex |
| && (((OperatorVertex) dst).getTransform() instanceof ReduceByKeyTransform |
| || ((OperatorVertex) dst).getTransform() instanceof GroupByKeyTransform)) { |
| return CommunicationPatternProperty.Value.SHUFFLE; |
| } else { |
| return CommunicationPatternProperty.Value.ONE_TO_ONE; |
| } |
| } |
| |
| /** |
| * Converts a {@link Function1} to a corresponding {@link Function}. |
| * <p> |
| * Here, we use the Spark 'JavaSerializer' to facilitate debugging in the future. |
| * TODO #205: RDD Closure with Broadcast Variables Serialization Bug |
| * |
| * @param scalaFunction the scala function to convert. |
| * @param <I> the type of input. |
| * @param <O> the type of output. |
| * @return the converted Java function. |
| */ |
| public static <I, O> Function<I, O> toJavaFunction(final Function1<I, O> scalaFunction) { |
| // This 'JavaSerializer' from Spark provides a human-readable NotSerializableException stack traces, |
| // which can be useful when addressing this problem. |
| // Other toJavaFunction can also use this serializer when debugging. |
| final ClassTag<Function1<I, O>> classTag = ClassTag$.MODULE$.apply(scalaFunction.getClass()); |
| final byte[] serializedFunction = new JavaSerializer().newInstance().serialize(scalaFunction, classTag).array(); |
| |
| return new Function<I, O>() { |
| private Function1<I, O> deserializedFunction; |
| |
| @Override |
| public O call(final I v1) throws Exception { |
| if (deserializedFunction == null) { |
| // TODO #205: RDD Closure with Broadcast Variables Serialization Bug |
| final SerializerInstance js = new JavaSerializer().newInstance(); |
| deserializedFunction = js.deserialize(ByteBuffer.wrap(serializedFunction), classTag); |
| } |
| return deserializedFunction.apply(v1); |
| } |
| }; |
| } |
| |
| /** |
| * Converts a {@link scala.Function2} to a corresponding {@link org.apache.spark.api.java.function.Function2}. |
| * |
| * @param scalaFunction the scala function to convert. |
| * @param <I1> the type of first input. |
| * @param <I2> the type of second input. |
| * @param <O> the type of output. |
| * @return the converted Java function. |
| */ |
| public static <I1, I2, O> Function2<I1, I2, O> toJavaFunction(final scala.Function2<I1, I2, O> scalaFunction) { |
| return new Function2<I1, I2, O>() { |
| @Override |
| public O call(final I1 v1, final I2 v2) throws Exception { |
| return scalaFunction.apply(v1, v2); |
| } |
| }; |
| } |
| |
| /** |
| * Converts a {@link Function1} to a corresponding {@link FlatMapFunction}. |
| * |
| * @param scalaFunction the scala function to convert. |
| * @param <I> the type of input. |
| * @param <O> the type of output. |
| * @return the converted Java function. |
| */ |
| public static <I, O> FlatMapFunction<I, O> toJavaFlatMapFunction( |
| final Function1<I, TraversableOnce<O>> scalaFunction) { |
| return new FlatMapFunction<I, O>() { |
| @Override |
| public Iterator<O> call(final I i) throws Exception { |
| return JavaConverters.asJavaIteratorConverter(scalaFunction.apply(i).toIterator()).asJava(); |
| } |
| }; |
| } |
| |
| /** |
| * Converts a {@link PairFunction} to a plain map {@link Function}. |
| * |
| * @param pairFunction the pair function to convert. |
| * @param <T> the type of original element. |
| * @param <K> the type of converted key. |
| * @param <V> the type of converted value. |
| * @return the converted map function. |
| */ |
| public static <T, K, V> Function<T, Tuple2<K, V>> pairFunctionToPlainFunction( |
| final PairFunction<T, K, V> pairFunction) { |
| return new Function<T, Tuple2<K, V>>() { |
| @Override |
| public Tuple2<K, V> call(final T elem) throws Exception { |
| return pairFunction.call(elem); |
| } |
| }; |
| } |
| } |