| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.tez.runtime.common.resources; |
| |
| import java.util.Collections; |
| import java.util.Iterator; |
| import java.util.LinkedList; |
| import java.util.List; |
| import java.util.Set; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.Objects; |
| |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| import org.apache.hadoop.classification.InterfaceAudience.Private; |
| import org.apache.hadoop.conf.Configuration; |
| import org.apache.tez.common.ReflectionUtils; |
| import org.apache.tez.dag.api.TezConfiguration; |
| import org.apache.tez.dag.api.EntityDescriptor; |
| import org.apache.tez.dag.api.TezException; |
| import org.apache.tez.dag.api.TezUncheckedException; |
| import org.apache.tez.runtime.api.MemoryUpdateCallback; |
| import org.apache.tez.runtime.api.InputContext; |
| import org.apache.tez.runtime.api.OutputContext; |
| import org.apache.tez.runtime.api.ProcessorContext; |
| import org.apache.tez.runtime.api.TaskContext; |
| |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.base.Function; |
| import org.apache.tez.common.Preconditions; |
| import com.google.common.collect.Iterables; |
| |
| // Not calling this a MemoryManager explicitly. Not yet anyway. |
| @Private |
| public class MemoryDistributor { |
| |
| private static final Logger LOG = LoggerFactory.getLogger(MemoryDistributor.class); |
| |
| private final int numTotalInputs; |
| private final int numTotalOutputs; |
| private final Configuration conf; |
| |
| private AtomicInteger numInputsSeen = new AtomicInteger(0); |
| private AtomicInteger numOutputsSeen = new AtomicInteger(0); |
| |
| private long totalJvmMemory; |
| private final boolean isEnabled; |
| private final boolean isInputOutputConcurrent; |
| private final String allocatorClassName; |
| private final Set<TaskContext> dupSet = Collections |
| .newSetFromMap(new ConcurrentHashMap<TaskContext, Boolean>()); |
| private final List<RequestorInfo> requestList; |
| |
| /** |
| * @param numTotalInputs |
| * total number of Inputs for the task |
| * @param numTotalOutputs |
| * total number of Outputs for the task |
| * @param conf |
| * Tez specific task configuration |
| */ |
| public MemoryDistributor(int numTotalInputs, int numTotalOutputs, Configuration conf) { |
| this.conf = conf; |
| isEnabled = conf.getBoolean(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED, |
| TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED_DEFAULT); |
| isInputOutputConcurrent = conf.getBoolean( |
| TezConfiguration.TEZ_TASK_SCALE_MEMORY_INPUT_OUTPUT_CONCURRENT, |
| TezConfiguration.TEZ_TASK_SCALE_MEMORY_INPUT_OUTPUT_CONCURRENT_DEFAULT); |
| |
| if (isEnabled) { |
| allocatorClassName = conf.get(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS, |
| TezConfiguration.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS_DEFAULT); |
| } else { |
| allocatorClassName = null; |
| } |
| |
| this.numTotalInputs = numTotalInputs; |
| this.numTotalOutputs = numTotalOutputs; |
| this.totalJvmMemory = Runtime.getRuntime().maxMemory(); |
| this.requestList = Collections.synchronizedList(new LinkedList<RequestorInfo>()); |
| LOG.info("InitialMemoryDistributor (isEnabled=" + isEnabled + ") invoked with: numInputs=" |
| + numTotalInputs + ", numOutputs=" + numTotalOutputs |
| + ", JVM.maxFree=" + totalJvmMemory |
| + ", allocatorClassName=" + allocatorClassName); |
| } |
| |
| |
| |
| /** |
| * Used by the Tez framework to request memory on behalf of user requests. |
| */ |
| public void requestMemory(long requestSize, MemoryUpdateCallback callback, |
| TaskContext taskContext, EntityDescriptor<?> descriptor) { |
| registerRequest(requestSize, callback, taskContext, descriptor); |
| } |
| |
| /** |
| * Used by the Tez framework to distribute initial memory after components |
| * have made their initial requests. |
| * @throws TezException |
| */ |
| public void makeInitialAllocations() throws TezException { |
| Preconditions.checkState(numInputsSeen.get() == numTotalInputs, "All inputs are expected to ask for memory"); |
| Preconditions.checkState(numOutputsSeen.get() == numTotalOutputs, "All outputs are expected to ask for memory"); |
| |
| logInitialRequests(requestList); |
| |
| Iterable<InitialMemoryRequestContext> requestContexts = Iterables.transform(requestList, |
| new Function<RequestorInfo, InitialMemoryRequestContext>() { |
| public InitialMemoryRequestContext apply(RequestorInfo requestInfo) { |
| return requestInfo.getRequestContext(); |
| } |
| }); |
| |
| Iterable<Long> allocations = null; |
| if (!isEnabled) { |
| allocations = Iterables.transform(requestList, new Function<RequestorInfo, Long>() { |
| public Long apply(RequestorInfo requestInfo) { |
| return requestInfo.getRequestContext().getRequestedSize(); |
| } |
| }); |
| } else { |
| InitialMemoryAllocator allocator = ReflectionUtils.createClazzInstance(allocatorClassName); |
| allocator.setConf(conf); |
| allocations = allocator.assignMemory(totalJvmMemory, numTotalInputs, numTotalOutputs, |
| Iterables.unmodifiableIterable(requestContexts)); |
| validateAllocations(allocations, requestList.size()); |
| logFinalAllocations(allocations, requestList); |
| } |
| |
| // Making the callbacks directly for now, instead of spawning threads. The |
| // callback implementors - all controlled by Tez at the moment are |
| // lightweight. |
| Iterator<Long> allocatedIter = allocations.iterator(); |
| for (RequestorInfo rInfo : requestList) { |
| long allocated = allocatedIter.next(); |
| if (LOG.isDebugEnabled()) { |
| LOG.info("Informing: " + rInfo.getRequestContext().getComponentType() + ", " |
| + rInfo.getRequestContext().getComponentVertexName() + ", " |
| + rInfo.getRequestContext().getComponentClassName() + ": requested=" |
| + rInfo.getRequestContext().getRequestedSize() + ", allocated=" + allocated); |
| } |
| rInfo.getCallback().memoryAssigned(allocated); |
| } |
| } |
| |
| |
| |
| /** |
| * Allow tests to set memory. |
| * @param size |
| */ |
| @Private |
| @VisibleForTesting |
| void setJvmMemory(long size) { |
| this.totalJvmMemory = size; |
| } |
| |
| private long registerRequest(long requestSize, MemoryUpdateCallback callback, |
| TaskContext entityContext, EntityDescriptor<?> descriptor) { |
| Preconditions.checkArgument(requestSize >= 0); |
| Objects.requireNonNull(callback); |
| Objects.requireNonNull(entityContext); |
| Objects.requireNonNull(descriptor); |
| if (!dupSet.add(entityContext)) { |
| throw new TezUncheckedException( |
| "A single entity can only make one call to request resources for now"); |
| } |
| |
| RequestorInfo requestInfo = new RequestorInfo(entityContext,requestSize, callback, descriptor); |
| switch (requestInfo.getRequestContext().getComponentType()) { |
| case INPUT: |
| numInputsSeen.incrementAndGet(); |
| Preconditions.checkState(numInputsSeen.get() <= numTotalInputs, |
| "Num Requesting Inputs higher than total # of inputs: " + numInputsSeen + ", " |
| + numTotalInputs); |
| break; |
| case OUTPUT: |
| numOutputsSeen.incrementAndGet(); |
| Preconditions.checkState(numOutputsSeen.get() <= numTotalOutputs, |
| "Num Requesting Inputs higher than total # of outputs: " + numOutputsSeen + ", " |
| + numTotalOutputs); |
| break; |
| case PROCESSOR: |
| break; |
| default: |
| break; |
| } |
| requestList.add(requestInfo); |
| return -1; |
| } |
| |
| private void validateAllocations(Iterable<Long> allocations, int numRequestors) { |
| Objects.requireNonNull(allocations); |
| long totalAllocated = 0l; |
| int numAllocations = 0; |
| for (Long l : allocations) { |
| totalAllocated += l; |
| numAllocations++; |
| } |
| Preconditions.checkState(numAllocations == numRequestors, |
| "Number of allocations must match number of requestors. Allocated=" + numAllocations |
| + ", Requests: " + numRequestors); |
| if (isInputOutputConcurrent) { |
| Preconditions.checkState(totalAllocated <= totalJvmMemory, |
| "Total allocation should be <= availableMem. TotalAllocated: " + totalAllocated |
| + ", totalJvmMemory: " + totalJvmMemory); |
| } |
| } |
| |
| |
| private static class RequestorInfo { |
| |
| private static final Logger LOG = LoggerFactory.getLogger(RequestorInfo.class); |
| |
| private final MemoryUpdateCallback callback; |
| private final InitialMemoryRequestContext requestContext; |
| |
| public RequestorInfo(TaskContext taskContext, long requestSize, |
| final MemoryUpdateCallback callback, EntityDescriptor<?> descriptor) { |
| InitialMemoryRequestContext.ComponentType type; |
| String componentVertexName; |
| if (taskContext instanceof InputContext) { |
| type = InitialMemoryRequestContext.ComponentType.INPUT; |
| componentVertexName = ((InputContext) taskContext).getSourceVertexName(); |
| } else if (taskContext instanceof OutputContext) { |
| type = InitialMemoryRequestContext.ComponentType.OUTPUT; |
| componentVertexName = ((OutputContext) taskContext).getDestinationVertexName(); |
| } else if (taskContext instanceof ProcessorContext) { |
| type = InitialMemoryRequestContext.ComponentType.PROCESSOR; |
| componentVertexName = ((ProcessorContext) taskContext).getTaskVertexName(); |
| } else { |
| throw new IllegalArgumentException("Unknown type of entityContext: " |
| + taskContext.getClass().getName()); |
| } |
| this.requestContext = new InitialMemoryRequestContext(requestSize, descriptor.getClassName(), |
| type, componentVertexName); |
| this.callback = callback; |
| } |
| |
| public MemoryUpdateCallback getCallback() { |
| return callback; |
| } |
| |
| public InitialMemoryRequestContext getRequestContext() { |
| return requestContext; |
| } |
| } |
| |
| |
| private void logInitialRequests(List<RequestorInfo> initialRequests) { |
| if (initialRequests != null && !initialRequests.isEmpty()) { |
| StringBuilder sb = new StringBuilder(); |
| for (int i = 0; i < initialRequests.size(); i++) { |
| InitialMemoryRequestContext context = initialRequests.get(i).getRequestContext(); |
| sb.append("["); |
| sb.append(context.getComponentVertexName()).append(":"); |
| sb.append(context.getComponentType()).append(":"); |
| sb.append(context.getRequestedSize()).append(":").append(context.getComponentClassName()); |
| sb.append("]"); |
| if (i < initialRequests.size() - 1) { |
| sb.append(", "); |
| } |
| } |
| LOG.info("InitialRequests=" + sb.toString()); |
| } |
| } |
| |
| private void logFinalAllocations(Iterable<Long> allocations, List<RequestorInfo> requestList) { |
| if (requestList != null && !requestList.isEmpty()) { |
| Iterator<Long> allocatedIter = allocations.iterator(); |
| StringBuilder sb = new StringBuilder(); |
| |
| for (int i = 0 ; i < requestList.size() ; i++) { |
| long allocated = allocatedIter.next(); |
| InitialMemoryRequestContext context = requestList.get(i).getRequestContext(); |
| sb.append("["); |
| sb.append(context.getComponentVertexName()).append(":"); |
| sb.append(context.getComponentClassName()).append(":"); |
| sb.append(context.getComponentType()).append(":"); |
| sb.append(context.getRequestedSize()).append(":").append(allocated); |
| sb.append("]"); |
| if (i < requestList.size() - 1) { |
| sb.append(", "); |
| } |
| } |
| LOG.info("Allocations=" + sb.toString()); |
| } |
| } |
| |
| } |