blob: 87aafaf04c791c5e923a1d840799bf65672efd19 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.tensorflow.cluster.util;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import org.apache.ignite.Ignite;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cluster.ClusterNode;
import org.apache.ignite.tensorflow.cluster.spec.TensorFlowClusterSpec;
import org.apache.ignite.tensorflow.cluster.spec.TensorFlowServerAddressSpec;
/**
 * TensorFlow cluster resolver based on Ignite Cache affinity.
 *
 * <p>Builds a {@link TensorFlowClusterSpec} that contains one {@link #WORKER_JOB_NAME} task per distinct node
 * the upstream cache affinity maps a partition to, plus a single {@link #CHIEF_JOB_NAME} task on the local node.
 * Every task gets a port acquired from the {@link ClusterPortManager}; those ports have to be given back with
 * {@link #releasePorts(TensorFlowClusterSpec)} once the cluster is no longer needed.
 */
public class TensorFlowClusterResolver {
    /** TensorFlow worker job name. */
    public static final String WORKER_JOB_NAME = "worker";

    /** TensorFlow chief job name. */
    public static final String CHIEF_JOB_NAME = "chief";

    /** Ignite instance. */
    private final Ignite ignite;

    /** Cluster port manager. */
    private final ClusterPortManager portMgr;

    /**
     * Constructs a new instance of TensorFlow cluster resolver.
     *
     * @param ignite Ignite instance.
     * @param portPoolName Port pool name, handed over to the cluster port manager.
     * @param portFrom Lower bound of the port range, must not be negative.
     * @param portCnt Number of ports in the range, must not be negative and {@code portFrom + portCnt} must not
     *      exceed 65535.
     */
    public TensorFlowClusterResolver(Ignite ignite, String portPoolName, int portFrom, int portCnt) {
        assert ignite != null : "Ignite instance should not be null";
        assert portPoolName != null : "Port pool name should not be null";
        assert portFrom >= 0 : "Port from should not be negative";
        // Compare against the remaining room instead of summing, so that large values cannot overflow int
        // and slip through the check.
        assert portCnt >= 0 && portCnt <= 0xFFFF - portFrom : "Port range should be between 0 and 65535";

        this.ignite = ignite;
        this.portMgr = new ClusterPortManager(ignite, portPoolName, portFrom, portCnt);
    }

    /**
     * Resolves TensorFlow cluster and acquires required ports.
     *
     * <p>If resolving fails after some ports have already been acquired, those ports are released before the
     * failure is propagated, because the partially filled specification never reaches the caller.
     *
     * @param upstreamCacheName Upstream cache name.
     * @return TensorFlow cluster specification.
     */
    public TensorFlowClusterSpec resolveAndAcquirePorts(String upstreamCacheName) {
        TensorFlowClusterSpec spec = new TensorFlowClusterSpec();

        try {
            resolveAndAcquirePortsForWorkers(spec, upstreamCacheName);
            resolveAndAcquirePortsForChief(spec);
        }
        catch (RuntimeException | Error e) {
            // The spec only lists tasks whose ports were successfully acquired, so it is safe to release them all.
            try {
                releasePorts(spec);
            }
            catch (RuntimeException releaseErr) {
                // Keep the original failure as the primary one.
                e.addSuppressed(releaseErr);
            }

            throw e;
        }

        return spec;
    }

    /**
     * Releases ports acquired for the given cluster specification.
     *
     * @param spec TensorFlow cluster specification.
     */
    public void releasePorts(TensorFlowClusterSpec spec) {
        for (List<TensorFlowServerAddressSpec> addresses : spec.getJobs().values()) {
            for (TensorFlowServerAddressSpec address : addresses)
                portMgr.releasePort(address.getNodeId(), address.getPort());
        }
    }

    /** Destroys TensorFlow cluster resolver. */
    public void destroy() {
        portMgr.destroy();
    }

    /**
     * Resolves TensorFlow cluster worker jobs and acquires ports.
     *
     * <p>One worker task is added per distinct node returned by the affinity function over all partitions of
     * the upstream cache.
     *
     * @param spec TensorFlow cluster specification.
     * @param upstreamCacheName Upstream cache name.
     */
    private void resolveAndAcquirePortsForWorkers(TensorFlowClusterSpec spec, String upstreamCacheName) {
        Affinity<?> affinity = ignite.affinity(upstreamCacheName);
        int parts = affinity.partitions();

        Set<UUID> distinctNodeIds = new HashSet<>();

        for (int part = 0; part < parts; part++) {
            ClusterNode node = affinity.mapPartitionToNode(part);

            distinctNodeIds.add(node.id());
        }

        // Sorting makes the order of worker tasks independent of the hash set iteration order.
        List<UUID> nodeIds = new ArrayList<>(distinctNodeIds);
        Collections.sort(nodeIds);

        for (UUID nodeId : nodeIds) {
            int port = portMgr.acquirePort(nodeId);

            spec.addTask(WORKER_JOB_NAME, nodeId, port);
        }
    }

    /**
     * Resolves TensorFlow cluster chief job and acquires ports.
     *
     * <p>The chief task is always placed on the local node and gets its own port, separate from any worker
     * port acquired for the same node.
     *
     * @param spec TensorFlow cluster specification.
     */
    private void resolveAndAcquirePortsForChief(TensorFlowClusterSpec spec) {
        ClusterNode chiefNode = ignite.cluster().localNode();
        UUID chiefNodeId = chiefNode.id();

        int chiefPort = portMgr.acquirePort(chiefNodeId);

        spec.addTask(CHIEF_JOB_NAME, chiefNodeId, chiefPort);
    }
}