blob: ad743d67df07ff5336fbad5acfd6e11c541b022c [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.blur.spark;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.UUID;
import org.apache.blur.command.stream.StreamClient;
import org.apache.blur.command.stream.StreamFunction;
import org.apache.blur.thirdparty.thrift_0_9_0.TException;
import org.apache.blur.thrift.BlurClient;
import org.apache.blur.thrift.generated.Blur.Iface;
import org.apache.blur.thrift.generated.BlurException;
import org.apache.blur.user.User;
import org.apache.blur.user.UserContext;
import org.apache.blur.utils.BlurConstants;
import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import com.google.common.base.Splitter;
import com.google.common.io.Closer;
@SuppressWarnings("serial")
public class BlurRDD implements Serializable {
private final static String CLASS_LOADER_ID = UUID.randomUUID().toString();
private final transient Iface _client;
private final transient SparkConf _sparkConf;
private final List<String> _jars = new ArrayList<String>();
private final int _timeout = 60000;
public BlurRDD(String zooKeeperConnectionStr, SparkConf sparkConf) throws IOException {
this(BlurClient.getClientFromZooKeeperConnectionStr(zooKeeperConnectionStr), sparkConf);
}
public BlurRDD(Iface client, SparkConf sparkConf) throws IOException {
_client = client;
_sparkConf = sparkConf;
if (_sparkConf.contains("spark.jars")) {
String jars = _sparkConf.get("spark.jars");
for (String jar : Splitter.on(',').split(jars)) {
_jars.add(jar);
}
}
}
public <T> JavaRDD<T> executeStream(JavaSparkContext context, String table, StreamFunction<T> streamFunction) {
User user = UserContext.getUser();
List<BlurSparkSplit> splits = getSplits(table, user, CLASS_LOADER_ID);
return context.parallelize(splits).flatMap(new FlatMapFunction<BlurSparkSplit, T>() {
@Override
public Iterable<T> call(BlurSparkSplit t) throws Exception {
return new Iterable<T>() {
@Override
public Iterator<T> iterator() {
Closer closer = Closer.create();
try {
String host = t.getHost();
int port = t.getPort();
int timeout = t.getTimeout();
StreamClient streamClient = closer.register(new StreamClient(host, port, timeout));
String classLoaderId = t.getClassLoaderId();
if (!streamClient.isClassLoaderAvailable(classLoaderId)) {
streamClient.loadJars(classLoaderId, _jars);
}
return wrapClose(closer, streamClient.executeStream(t, streamFunction).iterator());
} catch (IOException e) {
IOUtils.closeQuietly(closer);
throw new RuntimeException(e);
}
}
};
}
});
}
private static <T> Iterator<T> wrapClose(Closeable c, Iterator<T> t) {
return new Iterator<T>() {
@Override
public boolean hasNext() {
try {
boolean hasNext = t.hasNext();
if (!hasNext) {
IOUtils.closeQuietly(c);
}
return hasNext;
} catch (Throwable t) {
IOUtils.closeQuietly(c);
if (t instanceof RuntimeException) {
throw (RuntimeException) t;
} else {
throw new RuntimeException(t);
}
}
}
@Override
public T next() {
return t.next();
}
};
}
private List<BlurSparkSplit> getSplits(String table, User user, String classLoaderId) {
try {
Map<String, String> shardServerLayout = _client.shardServerLayout(table);
List<BlurSparkSplit> splits = new ArrayList<BlurSparkSplit>();
for (Entry<String, String> e : shardServerLayout.entrySet()) {
String shard = e.getKey();
String shardServerWithThriftPort = e.getValue();
String host = getHost(shardServerWithThriftPort);
int port = getStreamPort(shardServerWithThriftPort);
String username = user.getUsername();
Map<String, String> attributes = user.getAttributes();
splits.add(new BlurSparkSplit(host, port, _timeout, table, shard, classLoaderId, username, attributes));
}
return splits;
} catch (BlurException e) {
throw new RuntimeException(e);
} catch (TException e) {
throw new RuntimeException(e);
}
}
private String getHost(String shardServerWithThriftPort) {
return shardServerWithThriftPort.substring(0, shardServerWithThriftPort.indexOf(':'));
}
private int getStreamPort(String shardServerWithThriftPort) throws BlurException, TException {
String port = _client.configurationPerServer(shardServerWithThriftPort,
BlurConstants.BLUR_STREAM_SERVER_RUNNING_PORT);
return Integer.parseInt(port);
}
public static byte[] toBytes(Serializable s) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
try {
ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream);
objectOutputStream.writeObject(s);
objectOutputStream.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
return outputStream.toByteArray();
}
}