| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.tez.runtime.common.resources; |
| |
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.ReflectionUtils;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.dag.api.TezEntityDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.runtime.api.MemoryUpdateCallback;
import org.apache.tez.runtime.api.TezInputContext;
import org.apache.tez.runtime.api.TezOutputContext;
import org.apache.tez.runtime.api.TezProcessorContext;
import org.apache.tez.runtime.api.TezTaskContext;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
| |
| // Not calling this a MemoryManager explicitly. Not yet anyway. |
| @Private |
| public class MemoryDistributor { |
| |
| private static final Log LOG = LogFactory.getLog(MemoryDistributor.class); |
| |
| private final int numTotalInputs; |
| private final int numTotalOutputs; |
| private final Configuration conf; |
| |
| private AtomicInteger numInputsSeen = new AtomicInteger(0); |
| private AtomicInteger numOutputsSeen = new AtomicInteger(0); |
| |
| private long totalJvmMemory; |
| private final boolean isEnabled; |
| private final Set<TezTaskContext> dupSet = Collections |
| .newSetFromMap(new ConcurrentHashMap<TezTaskContext, Boolean>()); |
| private final List<RequestorInfo> requestList; |
| |
| /** |
| * @param numTotalInputs |
| * total number of Inputs for the task |
| * @param numTotalOutputs |
| * total number of Outputs for the task |
| * @param conf |
| * Tez specific task configuration |
| */ |
| public MemoryDistributor(int numTotalInputs, int numTotalOutputs, Configuration conf) { |
| this.conf = conf; |
| isEnabled = conf.getBoolean(TezJobConfig.TEZ_RUNTIME_SCALE_TASK_MEMORY_ENABLED, |
| TezJobConfig.TEZ_RUNTIME_SCALE_TASK_MEMORY_ENABLED_DEFAULT); |
| |
| |
| this.numTotalInputs = numTotalInputs; |
| this.numTotalOutputs = numTotalOutputs; |
| this.totalJvmMemory = Runtime.getRuntime().maxMemory(); |
| this.requestList = Collections.synchronizedList(new LinkedList<RequestorInfo>()); |
| LOG.info("InitialMemoryDistributor (isEnabled=" + isEnabled + ") invoked with: numInputs=" |
| + numTotalInputs + ", numOutputs=" + numTotalOutputs |
| + ", JVM.maxFree=" + totalJvmMemory); |
| } |
| |
| |
| |
| /** |
| * Used by the Tez framework to request memory on behalf of user requests. |
| */ |
| public void requestMemory(long requestSize, MemoryUpdateCallback callback, |
| TezTaskContext taskContext, TezEntityDescriptor descriptor) { |
| registerRequest(requestSize, callback, taskContext, descriptor); |
| } |
| |
| /** |
| * Used by the Tez framework to distribute initial memory after components |
| * have made their initial requests. |
| */ |
| public void makeInitialAllocations() { |
| Preconditions.checkState(numInputsSeen.get() == numTotalInputs, "All inputs are expected to ask for memory"); |
| Preconditions.checkState(numOutputsSeen.get() == numTotalOutputs, "All outputs are expected to ask for memory"); |
| Iterable<InitialMemoryRequestContext> requestContexts = Iterables.transform(requestList, |
| new Function<RequestorInfo, InitialMemoryRequestContext>() { |
| public InitialMemoryRequestContext apply(RequestorInfo requestInfo) { |
| return requestInfo.getRequestContext(); |
| } |
| }); |
| |
| Iterable<Long> allocations = null; |
| if (!isEnabled) { |
| allocations = Iterables.transform(requestList, new Function<RequestorInfo, Long>() { |
| public Long apply(RequestorInfo requestInfo) { |
| return requestInfo.getRequestContext().getRequestedSize(); |
| } |
| }); |
| } else { |
| String allocatorClassName = conf.get(TezJobConfig.TEZ_RUNTIME_SCALE_TASK_MEMORY_ALLOCATOR_CLASS, |
| TezJobConfig.TEZ_RUNTIME_SCALE_TASK_MEMORY_ALLOCATOR_CLASS_DEFAULT); |
| LOG.info("Using Allocator class: " + allocatorClassName); |
| InitialMemoryAllocator allocator = ReflectionUtils.createClazzInstance(allocatorClassName); |
| allocator.setConf(conf); |
| allocations = allocator.assignMemory(totalJvmMemory, numTotalInputs, numTotalOutputs, |
| Iterables.unmodifiableIterable(requestContexts)); |
| validateAllocations(allocations, requestList.size()); |
| } |
| |
| // Making the callbacks directly for now, instead of spawning threads. The |
| // callback implementors - all controlled by Tez at the moment are |
| // lightweight. |
| Iterator<Long> allocatedIter = allocations.iterator(); |
| for (RequestorInfo rInfo : requestList) { |
| long allocated = allocatedIter.next(); |
| LOG.info("Informing: " + rInfo.getRequestContext().getComponentType() + ", " |
| + rInfo.getRequestContext().getComponentVertexName() + ", " |
| + rInfo.getRequestContext().getComponentClassName() + ": requested=" |
| + rInfo.getRequestContext().getRequestedSize() + ", allocated=" + allocated); |
| rInfo.getCallback().memoryAssigned(allocated); |
| } |
| } |
| |
| /** |
| * Allow tests to set memory. |
| * @param size |
| */ |
| @Private |
| @VisibleForTesting |
| void setJvmMemory(long size) { |
| this.totalJvmMemory = size; |
| } |
| |
| private long registerRequest(long requestSize, MemoryUpdateCallback callback, |
| TezTaskContext entityContext, TezEntityDescriptor descriptor) { |
| Preconditions.checkArgument(requestSize >= 0); |
| Preconditions.checkNotNull(callback); |
| Preconditions.checkNotNull(entityContext); |
| Preconditions.checkNotNull(descriptor); |
| if (!dupSet.add(entityContext)) { |
| throw new TezUncheckedException( |
| "A single entity can only make one call to request resources for now"); |
| } |
| |
| RequestorInfo requestInfo = new RequestorInfo(entityContext,requestSize, callback, descriptor); |
| switch (requestInfo.getRequestContext().getComponentType()) { |
| case INPUT: |
| numInputsSeen.incrementAndGet(); |
| Preconditions.checkState(numInputsSeen.get() <= numTotalInputs, |
| "Num Requesting Inputs higher than total # of inputs: " + numInputsSeen + ", " |
| + numTotalInputs); |
| break; |
| case OUTPUT: |
| numOutputsSeen.incrementAndGet(); |
| Preconditions.checkState(numOutputsSeen.get() <= numTotalOutputs, |
| "Num Requesting Inputs higher than total # of outputs: " + numOutputsSeen + ", " |
| + numTotalOutputs); |
| case PROCESSOR: |
| break; |
| default: |
| break; |
| } |
| requestList.add(requestInfo); |
| return -1; |
| } |
| |
| private void validateAllocations(Iterable<Long> allocations, int numRequestors) { |
| Preconditions.checkNotNull(allocations); |
| long totalAllocated = 0l; |
| int numAllocations = 0; |
| for (Long l : allocations) { |
| totalAllocated += l; |
| numAllocations++; |
| } |
| Preconditions.checkState(numAllocations == numRequestors, |
| "Number of allocations must match number of requestors. Allocated=" + numAllocations |
| + ", Requests: " + numRequestors); |
| Preconditions.checkState(totalAllocated <= totalJvmMemory, |
| "Total allocation should be <= availableMem. TotalAllocated: " + totalAllocated |
| + ", totalJvmMemory: " + totalJvmMemory); |
| } |
| |
| |
| private static class RequestorInfo { |
| |
| private static final Log LOG = LogFactory.getLog(RequestorInfo.class); |
| |
| private final MemoryUpdateCallback callback; |
| private final InitialMemoryRequestContext requestContext; |
| |
| public RequestorInfo(TezTaskContext taskContext, long requestSize, |
| final MemoryUpdateCallback callback, TezEntityDescriptor descriptor) { |
| InitialMemoryRequestContext.ComponentType type; |
| String componentVertexName; |
| if (taskContext instanceof TezInputContext) { |
| type = InitialMemoryRequestContext.ComponentType.INPUT; |
| componentVertexName = ((TezInputContext) taskContext).getSourceVertexName(); |
| } else if (taskContext instanceof TezOutputContext) { |
| type = InitialMemoryRequestContext.ComponentType.OUTPUT; |
| componentVertexName = ((TezOutputContext) taskContext).getDestinationVertexName(); |
| } else if (taskContext instanceof TezProcessorContext) { |
| type = InitialMemoryRequestContext.ComponentType.PROCESSOR; |
| componentVertexName = ((TezProcessorContext) taskContext).getTaskVertexName(); |
| } else { |
| throw new IllegalArgumentException("Unknown type of entityContext: " |
| + taskContext.getClass().getName()); |
| } |
| this.requestContext = new InitialMemoryRequestContext(requestSize, descriptor.getClassName(), |
| type, componentVertexName); |
| this.callback = callback; |
| LOG.info("Received request: " + requestSize + ", type: " + type + ", componentVertexName: " |
| + componentVertexName); |
| } |
| |
| public MemoryUpdateCallback getCallback() { |
| return callback; |
| } |
| |
| public InitialMemoryRequestContext getRequestContext() { |
| return requestContext; |
| } |
| } |
| |
| } |