| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.controlprogram.federated; |
| |
| import java.net.InetSocketAddress; |
| import java.util.ArrayList; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Set; |
| import java.util.concurrent.Future; |
| |
| import org.apache.commons.logging.Log; |
| import org.apache.commons.logging.LogFactory; |
| import org.apache.sysds.common.Types; |
| import org.apache.sysds.conf.DMLConfig; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; |
| |
| import io.netty.bootstrap.Bootstrap; |
| import io.netty.channel.ChannelFuture; |
| import io.netty.channel.ChannelHandlerContext; |
| import io.netty.channel.ChannelInboundHandlerAdapter; |
| import io.netty.channel.ChannelInitializer; |
| import io.netty.channel.EventLoopGroup; |
| import io.netty.channel.nio.NioEventLoopGroup; |
| import io.netty.channel.socket.SocketChannel; |
| import io.netty.channel.socket.nio.NioSocketChannel; |
| import io.netty.handler.codec.serialization.ClassResolvers; |
| import io.netty.handler.codec.serialization.ObjectDecoder; |
| import io.netty.handler.codec.serialization.ObjectEncoder; |
| import io.netty.util.concurrent.Promise; |
| |
| |
| public class FederatedData { |
| private static final Log LOG = LogFactory.getLog(FederatedData.class.getName()); |
| private static final Set<InetSocketAddress> _allFedSites = new HashSet<>(); |
| |
| private final Types.DataType _dataType; |
| private final InetSocketAddress _address; |
| private final String _filepath; |
| /** |
| * The ID of default matrix/tensor on which operations get executed if no other ID is given. |
| */ |
| private long _varID = -1; // -1 is never valid since varIDs start at 0 |
| |
| public FederatedData(Types.DataType dataType, InetSocketAddress address, String filepath) { |
| _dataType = dataType; |
| _address = address; |
| _filepath = filepath; |
| if(_address != null) |
| _allFedSites.add(_address); |
| } |
| |
| public InetSocketAddress getAddress() { |
| return _address; |
| } |
| |
| public void setVarID(long varID) { |
| _varID = varID; |
| } |
| |
| public long getVarID() { |
| return _varID; |
| } |
| |
| public String getFilepath() { |
| return _filepath; |
| } |
| |
| public Types.DataType getDataType(){ |
| return _dataType; |
| } |
| |
| public boolean isInitialized() { |
| return _varID != -1; |
| } |
| |
| boolean equalAddress(FederatedData that) { |
| return _address != null && that != null && that._address != null |
| && _address.equals(that._address); |
| } |
| |
| /** |
| * Make a copy of the <code>FederatedData</code> metadata, but use another varID (refer to another object on worker) |
| * @param varID the varID of the variable we refer to |
| * @return new <code>FederatedData</code> with different varID set |
| */ |
| public FederatedData copyWithNewID(long varID) { |
| FederatedData copy = new FederatedData(_dataType, _address, _filepath); |
| copy.setVarID(varID); |
| return copy; |
| } |
| |
| public synchronized Future<FederatedResponse> initFederatedData(long id) { |
| if(isInitialized()) |
| throw new DMLRuntimeException("Tried to init already initialized data"); |
| if(!_dataType.isMatrix() && !_dataType.isFrame()) |
| throw new DMLRuntimeException("Federated datatype \"" + _dataType.toString() + "\" is not supported."); |
| _varID = id; |
| FederatedRequest request = new FederatedRequest(RequestType.READ_VAR, id); |
| request.appendParam(_filepath); |
| request.appendParam(_dataType.name()); |
| return executeFederatedOperation(request); |
| } |
| |
| public synchronized Future<FederatedResponse> executeFederatedOperation(FederatedRequest... request) { |
| return executeFederatedOperation(_address, request); |
| } |
| |
| /** |
| * Executes an federated operation on a federated worker. |
| * |
| * @param address socket address (incl host and port) |
| * @param request the requested operation |
| * @return the response |
| */ |
| public static Future<FederatedResponse> executeFederatedOperation(InetSocketAddress address, FederatedRequest... request) { |
| // Careful with the number of threads. Each thread opens connections to multiple files making resulting in |
| // java.io.IOException: Too many open files |
| EventLoopGroup workerGroup = new NioEventLoopGroup(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS); |
| try { |
| Bootstrap b = new Bootstrap(); |
| final DataRequestHandler handler = new DataRequestHandler(workerGroup); |
| b.group(workerGroup).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() { |
| @Override |
| public void initChannel(SocketChannel ch) { |
| ch.pipeline().addLast("ObjectDecoder", |
| new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader()))) |
| .addLast("FederatedOperationHandler", handler) |
| .addLast("ObjectEncoder", new ObjectEncoder()); |
| } |
| }); |
| |
| ChannelFuture f = b.connect(address).sync(); |
| Promise<FederatedResponse> promise = f.channel().eventLoop().newPromise(); |
| handler.setPromise(promise); |
| f.channel().writeAndFlush(request); |
| return promise; |
| } |
| catch (InterruptedException e) { |
| throw new DMLRuntimeException("Could not send federated operation."); |
| } |
| catch (Exception e) { |
| throw new DMLRuntimeException(e); |
| } |
| } |
| |
| public static void clearFederatedWorkers() { |
| if( _allFedSites.isEmpty() ) |
| return; |
| |
| try { |
| //create and execute clear request on all workers |
| FederatedRequest fr = new FederatedRequest(RequestType.CLEAR); |
| List<Future<FederatedResponse>> ret = new ArrayList<>(); |
| for( InetSocketAddress address : _allFedSites ) |
| ret.add(executeFederatedOperation(address, fr)); |
| |
| //wait for successful completion |
| FederationUtils.waitFor(ret); |
| } |
| catch(Exception ex) { |
| LOG.warn("Failed to execute CLEAR request on existing federated sites.", ex); |
| } |
| finally { |
| resetFederatedSites(); |
| } |
| } |
| |
| public static void resetFederatedSites() { |
| _allFedSites.clear(); |
| } |
| |
| private static class DataRequestHandler extends ChannelInboundHandlerAdapter { |
| private Promise<FederatedResponse> _prom; |
| private EventLoopGroup _workerGroup; |
| |
| public DataRequestHandler(EventLoopGroup workerGroup) { |
| _workerGroup = workerGroup; |
| } |
| |
| public void setPromise(Promise<FederatedResponse> prom) { |
| _prom = prom; |
| } |
| |
| @Override |
| public void channelRead(ChannelHandlerContext ctx, Object msg) { |
| if (_prom == null) |
| throw new DMLRuntimeException("Read while no message was sent"); |
| _prom.setSuccess((FederatedResponse) msg); |
| ctx.close(); |
| _workerGroup.shutdownGracefully(); |
| } |
| } |
| |
| @Override |
| public String toString(){ |
| StringBuilder sb = new StringBuilder(); |
| sb.append(this.getClass().getSimpleName().toString()); |
| sb.append(" "+ _dataType); |
| sb.append(" "+_address.toString()); |
| sb.append(":" + _filepath); |
| return sb.toString(); |
| } |
| } |