/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.nemo.compiler.frontend.spark.source;

import org.apache.nemo.common.ir.BoundedIteratorReadable;
import org.apache.nemo.common.ir.Readable;
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.compiler.frontend.spark.sql.Dataset;
import org.apache.nemo.compiler.frontend.spark.sql.SparkSession;
import org.apache.spark.Partition;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import scala.collection.JavaConverters;

import javax.naming.OperationNotSupportedException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Bounded source vertex for Spark Dataset.
 *
 * @param <T> type of data to read.
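 *
 * <p>A minimal usage sketch (assuming an existing {@code SparkSession} named {@code spark} and a
 * {@code Dataset<String>} named {@code lines}; both names are illustrative):</p>
 * <pre>{@code
 * final SparkDatasetBoundedSourceVertex<String> source =
 *     new SparkDatasetBoundedSourceVertex<>(spark, lines);
 * final List<Readable<String>> readables = source.getReadables(1); // one readable per RDD partition
 * }</pre>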
 */
public final class SparkDatasetBoundedSourceVertex<T> extends SourceVertex<T> {
  private List<Readable<T>> readables;
  private long estimatedSizeBytes;

  /**
   * Constructor.
   *
   * @param sparkSession the Spark session to recreate on each executor.
   * @param dataset      the Dataset to read data from.
   */
  public SparkDatasetBoundedSourceVertex(final SparkSession sparkSession, final Dataset<T> dataset) {
    this.readables = new ArrayList<>();
    final RDD<T> rdd = dataset.sparkRDD();
    final Partition[] partitions = rdd.getPartitions();
    for (int i = 0; i < partitions.length; i++) {
      readables.add(new SparkDatasetBoundedSourceReadable<T>(
        partitions[i],
        sparkSession.getDatasetCommandsList(),
        sparkSession.getInitialConf(),
        i));
    }
    // Estimate the dataset size by summing the UTF-8 byte length of each element's string form.
    // Note that this triggers a Spark job over the entire dataset.
    this.estimatedSizeBytes = dataset.javaRDD()
      .map(o -> (long) o.toString().getBytes(StandardCharsets.UTF_8).length)
      .fold(0L, (a, b) -> a + b); // fold, unlike reduce, also handles an empty dataset
  }

  /**
   * Copy constructor for SparkDatasetBoundedSourceVertex.
   *
   * @param that the source object for copying.
   */
  private SparkDatasetBoundedSourceVertex(final SparkDatasetBoundedSourceVertex<T> that) {
    super(that);
    this.readables = new ArrayList<>(that.readables);
  }

  @Override
  public SparkDatasetBoundedSourceVertex<T> getClone() {
    return new SparkDatasetBoundedSourceVertex<>(this);
  }

  @Override
  public boolean isBounded() {
    return true;
  }

  @Override
  public List<Readable<T>> getReadables(final int desiredNumOfSplits) {
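    // desiredNumOfSplits is ignored: this vertex always exposes one readable per RDD partition.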
    return readables;
  }

  @Override
  public long getEstimatedSizeBytes() {
    return this.estimatedSizeBytes;
  }

  @Override
  public void clearInternalStates() {
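    // Release the readables, which hold the recorded session commands and configuration.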
    readables = null;
  }

  /**
   * A Readable wrapper for Spark Dataset.
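   *
   * <p>Instances are sent to executors, which rebuild the dataset from the recorded session
   * commands and configuration rather than receiving the data itself.</p>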
   */
  private static final class SparkDatasetBoundedSourceReadable<T> extends BoundedIteratorReadable<T> {
    private final LinkedHashMap<String, Object[]> commands;
    private final Map<String, String> sessionInitialConf;
    private final int partitionIndex;
    private final List<String> locations;

    /**
     * Constructor.
     *
     * @param partition the partition to wrap.
     * @param commands the list of commands needed to build the dataset.
     * @param sessionInitialConf the Spark session's initial configuration.
     * @param partitionIndex the index of the partition for this readable.
     */
    private SparkDatasetBoundedSourceReadable(final Partition partition,
                                              final LinkedHashMap<String, Object[]> commands,
                                              final Map<String, String> sessionInitialConf,
                                              final int partitionIndex) {
      this.commands = commands;
      this.sessionInitialConf = sessionInitialConf;
      this.partitionIndex = partitionIndex;
      this.locations = SparkSourceUtil.getPartitionLocation(partition);
    }

    @Override
    protected Iterator<T> initializeIterator() {
      // Recreate the Spark session on the executor, so the dataset can be rebuilt in the
      // same environment that was used on the driver.
      final SparkSession spark = SparkSession.builder()
        .config(sessionInitialConf)
        .getOrCreate();
      final Dataset<T> dataset;

      try {
        dataset = SparkSession.initializeDataset(spark, commands);
      } catch (final OperationNotSupportedException e) {
        throw new IllegalStateException(e);
      }

      // Spark does lazy evaluation: it doesn't load the full dataset, but only the partition it is asked for.
      final RDD<T> rdd = dataset.sparkRDD();
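      // TaskContext$.MODULE$.empty() supplies a placeholder task context, since this iterator
      // is consumed by a Nemo executor rather than inside a regular Spark task.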
      final Iterable<T> iterable = () -> JavaConverters.asJavaIteratorConverter(
        rdd.iterator(rdd.getPartitions()[partitionIndex], TaskContext$.MODULE$.empty())).asJava();
      return iterable.iterator();
    }

    @Override
    public long readWatermark() {
      throw new UnsupportedOperationException("This bounded source does not emit watermarks");
    }

    @Override
    public List<String> getLocations() {
      if (locations.isEmpty()) {
        throw new UnsupportedOperationException("No location information for this partition");
      } else {
        return locations;
      }
    }

    @Override
    public void close() throws IOException {
      // Nothing to close: the iterator reads from an in-memory Spark partition.
    }
  }
}