| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.samza.util; |
| |
| import org.apache.samza.container.TaskName; |
| import org.apache.samza.context.Context; |
| import org.apache.samza.context.TaskContextImpl; |
| import org.apache.samza.job.model.ContainerModel; |
| import org.apache.samza.job.model.JobModel; |
| import org.apache.samza.job.model.TaskModel; |
| import org.junit.Assert; |
| import org.junit.Ignore; |
| import org.junit.Test; |
| |
| import java.util.HashMap; |
| import java.util.Map; |
| import java.util.function.Function; |
| import java.util.stream.Collectors; |
| import java.util.stream.IntStream; |
| |
| import static java.util.concurrent.TimeUnit.MILLISECONDS; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.when; |
| |
| |
| public class TestEmbeddedTaggedRateLimiter { |
| |
| final static private int TEST_INTERVAL = 200; // ms |
| final static private int NUMBER_OF_TASKS = 2; |
| final static private int TARGET_RATE_RED = 1000; |
| final static private int TARGET_RATE_PER_TASK_RED = TARGET_RATE_RED / NUMBER_OF_TASKS; |
| final static private int TARGET_RATE_GREEN = 2000; |
| final static private int INCREMENT = 2; |
| |
| final static private int TARGET_RATE = 4000; |
| final static private int TARGET_RATE_PER_TASK = TARGET_RATE / NUMBER_OF_TASKS; |
| |
| @Test |
| @Ignore("Flaky Test: Test fails in travis.") |
| public void testAcquire() { |
| RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(TARGET_RATE); |
| initRateLimiter(rateLimiter); |
| |
| int count = 0; |
| long start = System.currentTimeMillis(); |
| while (System.currentTimeMillis() - start < TEST_INTERVAL) { |
| rateLimiter.acquire(INCREMENT); |
| count += INCREMENT; |
| } |
| |
| long rate = count * 1000 / TEST_INTERVAL; |
| verifyRate(rate); |
| } |
| |
| @Test |
| @Ignore("Flaky Test.") |
| public void testAcquireWithTimeout() { |
| RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(TARGET_RATE); |
| initRateLimiter(rateLimiter); |
| |
| boolean hasSeenZeros = false; |
| |
| int count = 0; |
| int callCount = 0; |
| long start = System.currentTimeMillis(); |
| while (System.currentTimeMillis() - start < TEST_INTERVAL) { |
| ++callCount; |
| int availableCredits = rateLimiter.acquire(INCREMENT, 20, MILLISECONDS); |
| if (availableCredits <= 0) { |
| hasSeenZeros = true; |
| } else { |
| count += INCREMENT; |
| } |
| } |
| |
| long rate = count * 1000 / TEST_INTERVAL; |
| verifyRate(rate); |
| junit.framework.Assert.assertTrue(Math.abs(callCount - TARGET_RATE_PER_TASK * TEST_INTERVAL / 1000 / INCREMENT) <= 2); |
| junit.framework.Assert.assertFalse(hasSeenZeros); |
| } |
| |
| @Test(expected = IllegalStateException.class) |
| public void testFailsWhenUninitialized() { |
| new EmbeddedTaggedRateLimiter(100).acquire(1); |
| } |
| |
| @Test(expected = IllegalArgumentException.class) |
| public void testFailsWhenUsingTags() { |
| RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(10); |
| initRateLimiter(rateLimiter); |
| Map<String, Integer> tagToCredits = new HashMap<>(); |
| tagToCredits.put("red", 1); |
| tagToCredits.put("green", 1); |
| rateLimiter.acquire(tagToCredits); |
| } |
| |
| @Test |
| public void testAcquireTagged() { |
| RateLimiter rateLimiter = createRateLimiter(); |
| |
| Map<String, Integer> tagToCount = new HashMap<>(); |
| tagToCount.put("red", 0); |
| tagToCount.put("green", 0); |
| |
| Map<String, Integer> tagToCredits = new HashMap<>(); |
| tagToCredits.put("red", INCREMENT); |
| tagToCredits.put("green", INCREMENT); |
| |
| long start = System.currentTimeMillis(); |
| while (System.currentTimeMillis() - start < TEST_INTERVAL) { |
| rateLimiter.acquire(tagToCredits); |
| tagToCount.put("red", tagToCount.get("red") + INCREMENT); |
| tagToCount.put("green", tagToCount.get("green") + INCREMENT); |
| } |
| |
| { |
| long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL; |
| verifyRate(rate, TARGET_RATE_PER_TASK_RED); |
| } { |
| // Note: due to blocking, green is capped at red's QPS |
| long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL; |
| verifyRate(rate, TARGET_RATE_PER_TASK_RED); |
| } |
| } |
| |
| @Test |
| public void testAcquireWithTimeoutTagged() { |
| |
| RateLimiter rateLimiter = createRateLimiter(); |
| |
| Map<String, Integer> tagToCount = new HashMap<>(); |
| tagToCount.put("red", 0); |
| tagToCount.put("green", 0); |
| |
| Map<String, Integer> tagToCredits = new HashMap<>(); |
| tagToCredits.put("red", INCREMENT); |
| tagToCredits.put("green", INCREMENT); |
| |
| long start = System.currentTimeMillis(); |
| while (System.currentTimeMillis() - start < TEST_INTERVAL) { |
| Map<String, Integer> resultMap = rateLimiter.acquire(tagToCredits, 20, MILLISECONDS); |
| tagToCount.put("red", tagToCount.get("red") + resultMap.get("red")); |
| tagToCount.put("green", tagToCount.get("green") + resultMap.get("green")); |
| } |
| |
| { |
| long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL; |
| verifyRate(rate, TARGET_RATE_PER_TASK_RED); |
| } { |
| // Note: due to blocking, green is capped at red's QPS |
| long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL; |
| verifyRate(rate, TARGET_RATE_PER_TASK_RED); |
| } |
| } |
| |
| @Test(expected = IllegalStateException.class) |
| public void testFailsWhenUninitializedTagged() { |
| Map<String, Integer> tagToTargetRateMap = new HashMap<>(); |
| tagToTargetRateMap.put("red", 1000); |
| tagToTargetRateMap.put("green", 2000); |
| new EmbeddedTaggedRateLimiter(tagToTargetRateMap).acquire(tagToTargetRateMap); |
| } |
| |
| @Test(expected = IllegalArgumentException.class) |
| public void testFailsWhenNotUsingTags() { |
| Map<String, Integer> tagToCredits = new HashMap<>(); |
| tagToCredits.put("red", 1); |
| tagToCredits.put("green", 1); |
| RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToCredits); |
| initRateLimiter(rateLimiter); |
| rateLimiter.acquire(1); |
| } |
| |
| private void verifyRate(long rate, long targetRate) { |
| // As the actual rate would likely not be exactly the same as target rate, the calculation below |
| // verifies the actual rate is within 10% of the target rate per task |
| Assert.assertTrue(Math.abs(rate - targetRate) <= targetRate * 10 / 100); |
| } |
| |
| private RateLimiter createRateLimiter() { |
| Map<String, Integer> tagToTargetRateMap = new HashMap<>(); |
| tagToTargetRateMap.put("red", TARGET_RATE_RED); |
| tagToTargetRateMap.put("green", TARGET_RATE_GREEN); |
| RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToTargetRateMap); |
| initRateLimiter(rateLimiter); |
| return rateLimiter; |
| } |
| |
| private void verifyRate(long rate) { |
| // As the actual rate would likely not be exactly the same as target rate, the calculation below |
| // verifies the actual rate is within 5% of the target rate per task |
| junit.framework.Assert.assertTrue(Math.abs(rate - TARGET_RATE_PER_TASK) <= TARGET_RATE_PER_TASK * 5 / 100); |
| } |
| |
| static void initRateLimiter(RateLimiter rateLimiter) { |
| Map<TaskName, TaskModel> tasks = IntStream.range(0, NUMBER_OF_TASKS) |
| .mapToObj(i -> new TaskName("task-" + i)) |
| .collect(Collectors.toMap(Function.identity(), x -> mock(TaskModel.class))); |
| ContainerModel containerModel = mock(ContainerModel.class); |
| when(containerModel.getTasks()).thenReturn(tasks); |
| JobModel jobModel = mock(JobModel.class); |
| Map<String, ContainerModel> containerModelMap = new HashMap<>(); |
| containerModelMap.put("container-1", containerModel); |
| when(jobModel.getContainers()).thenReturn(containerModelMap); |
| Context context = mock(Context.class); |
| TaskContextImpl taskContext = mock(TaskContextImpl.class); |
| when(context.getTaskContext()).thenReturn(taskContext); |
| when(taskContext.getJobModel()).thenReturn(jobModel); |
| when(context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class)); |
| rateLimiter.init(context); |
| } |
| } |