blob: 60c86c1e1eafcc1a3c0136253fd5ffa27114e535 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.util;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.samza.container.TaskName;
import org.apache.samza.context.Context;
import org.apache.samza.context.TaskContextImpl;
import org.apache.samza.job.model.JobModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
/**
* An embedded rate limiter that supports tags. For simple use cases where the user specifies only a single
* rate, a default tag is used.
*/
public class EmbeddedTaggedRateLimiter implements RateLimiter {
  private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddedTaggedRateLimiter.class);
  private static final String DEFAULT_TAG = "default-tag";
  // Only the key set is consulted when validating tags, so the credit value here is a placeholder.
  private static final Map<String, Integer> DEFAULT_TAG_MAP = Collections.singletonMap(DEFAULT_TAG, 0);

  /** Total target rate (credits per second) per tag, as supplied at construction time. */
  private final Map<String, Integer> tagToTargetRateMap;
  /** Per-tag Guava rate limiters; populated by {@link #init(Context)}. */
  private Map<String, com.google.common.util.concurrent.RateLimiter> tagToRateLimiterMap;
  private boolean initialized;

  /**
   * Creates a rate limiter with a single rate, registered under the default tag.
   *
   * @param creditsPerSecond total credits per second; must be non-negative
   * @throws IllegalArgumentException if {@code creditsPerSecond} is negative
   */
  public EmbeddedTaggedRateLimiter(int creditsPerSecond) {
    this(Collections.singletonMap(DEFAULT_TAG, creditsPerSecond));
  }

  /**
   * Creates a rate limiter with one rate per tag.
   *
   * <p>The map is copied, so later changes to the caller's map do not affect this instance.
   *
   * <p>NOTE(review): a rate of 0 passes the check here, but Guava's {@code RateLimiter.create}
   * is documented to reject non-positive rates, so {@link #init(Context)} would presumably fail
   * for such a tag - confirm whether 0 should be rejected up front.
   *
   * @param tagToCreditsPerSecondMap total credits per second for each tag; must not be empty and
   *                                 must not contain negative values
   * @throws IllegalArgumentException if the map is empty or contains a negative rate
   */
  public EmbeddedTaggedRateLimiter(Map<String, Integer> tagToCreditsPerSecondMap) {
    // Snapshot before validating so that the validated values are exactly the ones init() uses.
    Map<String, Integer> snapshot = new HashMap<>(tagToCreditsPerSecondMap);
    Preconditions.checkArgument(snapshot.size() > 0, "Map of tags can't be empty");
    snapshot.values().forEach(c -> Preconditions.checkArgument(c >= 0, "Credits must be non-negative"));
    this.tagToTargetRateMap = Collections.unmodifiableMap(snapshot);
  }

  /**
   * Acquires the requested number of credits for every tag in the map, one tag after another,
   * using the blocking {@code acquire} of the underlying rate limiters.
   *
   * @param tagToCreditsMap number of credits to acquire for each tag
   * @throws IllegalStateException if {@link #init(Context)} has not been called
   * @throws IllegalArgumentException if a tag is not known to this rate limiter
   */
  @Override
  public void acquire(Map<String, Integer> tagToCreditsMap) {
    ensureTagsAreValid(tagToCreditsMap);
    tagToCreditsMap.forEach((tag, numberOfCredits) -> tagToRateLimiterMap.get(tag).acquire(numberOfCredits));
  }

  /**
   * Tries to acquire the requested credits for every tag within a single overall timeout.
   *
   * <p>Tags are processed one after another; each one is given whatever is left of the timeout
   * (never less than zero) at the moment it is processed.
   *
   * @param tagToCreditsMap number of credits to acquire for each tag
   * @param timeout overall timeout for the whole call
   * @param unit unit of {@code timeout}
   * @return for each requested tag, the full requested amount if it was acquired, otherwise 0
   * @throws IllegalStateException if {@link #init(Context)} has not been called
   * @throws IllegalArgumentException if a tag is not known to this rate limiter
   */
  @Override
  public Map<String, Integer> acquire(Map<String, Integer> tagToCreditsMap, long timeout, TimeUnit unit) {
    ensureTagsAreValid(tagToCreditsMap);
    long timeoutInNanos = NANOSECONDS.convert(timeout, unit);
    Stopwatch stopwatch = Stopwatch.createStarted();
    return tagToCreditsMap.entrySet().stream()
        .map(e -> {
          String tag = e.getKey();
          int requiredCredits = e.getValue();
          // Time already spent on earlier tags counts against the shared timeout.
          long remainingTimeoutInNanos = Math.max(0L, timeoutInNanos - stopwatch.elapsed(NANOSECONDS));
          com.google.common.util.concurrent.RateLimiter rateLimiter = tagToRateLimiterMap.get(tag);
          int availableCredits = rateLimiter.tryAcquire(requiredCredits, remainingTimeoutInNanos, NANOSECONDS)
              ? requiredCredits
              : 0;
          return new ImmutablePair<>(tag, availableCredits);
        })
        .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue));
  }

  /**
   * Returns the tags this rate limiter was configured with.
   *
   * @return an unmodifiable view of the supported tags
   * @throws IllegalStateException if {@link #init(Context)} has not been called
   */
  @Override
  public Set<String> getSupportedTags() {
    // Fail with the same "Not initialized" error as the other methods instead of a bare NPE.
    ensureInitialized();
    return Collections.unmodifiableSet(tagToRateLimiterMap.keySet());
  }

  /**
   * Acquires the given number of credits for the default tag, using the blocking
   * {@code acquire} of the underlying rate limiter.
   *
   * @param numberOfCredits number of credits to acquire
   * @throws IllegalStateException if {@link #init(Context)} has not been called
   * @throws IllegalArgumentException if this instance was not configured with the default tag
   */
  @Override
  public void acquire(int numberOfCredits) {
    ensureTagsAreValid(DEFAULT_TAG_MAP);
    tagToRateLimiterMap.get(DEFAULT_TAG).acquire(numberOfCredits);
  }

  /**
   * Tries to acquire the given number of credits for the default tag within the timeout.
   *
   * @param numberOfCredit number of credits to acquire
   * @param timeout maximum time to wait
   * @param unit unit of {@code timeout}
   * @return {@code numberOfCredit} if the credits were acquired, otherwise 0
   * @throws IllegalStateException if {@link #init(Context)} has not been called
   * @throws IllegalArgumentException if this instance was not configured with the default tag
   */
  @Override
  public int acquire(int numberOfCredit, long timeout, TimeUnit unit) {
    ensureTagsAreValid(DEFAULT_TAG_MAP);
    return tagToRateLimiterMap.get(DEFAULT_TAG).tryAcquire(numberOfCredit, timeout, unit)
        ? numberOfCredit
        : 0;
  }

  /**
   * Creates one underlying rate limiter per configured tag. The effective rate of each is the
   * tag's total rate divided by the total number of tasks across all containers of the job model.
   *
   * @param context context whose task context provides the job model and the task model
   */
  @Override
  public void init(Context context) {
    // These lookups do not depend on the tag, so they are done once rather than per tag.
    JobModel jobModel = ((TaskContextImpl) context.getTaskContext()).getJobModel();
    int numTasks = jobModel.getContainers().values().stream()
        .mapToInt(cm -> cm.getTasks().size())
        .sum();
    TaskName taskName = context.getTaskContext().getTaskModel().getTaskName();

    Map<String, com.google.common.util.concurrent.RateLimiter> rateLimiters = new HashMap<>();
    tagToTargetRateMap.forEach((tag, totalRate) ->
        rateLimiters.put(tag, createRateLimiter(tag, totalRate, numTasks, taskName)));
    this.tagToRateLimiterMap = Collections.unmodifiableMap(rateLimiters);
    initialized = true;
  }

  /** Builds the rate limiter for one tag, logging its effective rate (total rate / number of tasks). */
  private static com.google.common.util.concurrent.RateLimiter createRateLimiter(
      String tag, int totalRate, int numTasks, TaskName taskName) {
    double effectiveRate = (double) totalRate / numTasks;
    LOGGER.info(String.format("Effective rate limit for task %s and tag %s is %f", taskName, tag,
        effectiveRate));
    if (effectiveRate < 1.0) {
      LOGGER.warn(String.format("Effective limit rate (%f) is very low. "
          + "Total rate limit is %d while number of tasks is %d. Consider increasing the rate limit.",
          effectiveRate, totalRate, numTasks));
    }
    return com.google.common.util.concurrent.RateLimiter.create(effectiveRate);
  }

  /** Throws {@link IllegalStateException} unless {@link #init(Context)} has completed. */
  private void ensureInitialized() {
    Preconditions.checkState(initialized, "Not initialized");
  }

  /** Checks that this instance is initialized and that every key of {@code tagMap} is a known tag. */
  private void ensureTagsAreValid(Map<String, ?> tagMap) {
    ensureInitialized();
    tagMap.keySet().forEach(tag ->
        Preconditions.checkArgument(tagToRateLimiterMap.containsKey(tag), "Invalid tag: " + tag));
  }
}