| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * <p> |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * <p> |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package com.twitter.distributedlog.service.placement; |
| |
| |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertTrue; |
| import static org.junit.Assert.fail; |
| import static org.mockito.Matchers.anyString; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.verify; |
| import static org.mockito.Mockito.when; |
| |
| import com.twitter.distributedlog.client.routing.RoutingService; |
| import com.twitter.distributedlog.namespace.DistributedLogNamespace; |
| import com.twitter.util.Await; |
| import com.twitter.util.Duration; |
| import com.twitter.util.Future; |
| import java.io.IOException; |
| import java.net.InetSocketAddress; |
| import java.net.SocketAddress; |
| import java.util.LinkedHashSet; |
| import java.util.Random; |
| import java.util.Set; |
| import java.util.TreeSet; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import org.apache.bookkeeper.stats.NullStatsLogger; |
| import org.junit.Test; |
| import org.mockito.ArgumentCaptor; |
| import org.mockito.invocation.InvocationOnMock; |
| import org.mockito.stubbing.Answer; |
| |
| /** |
| * Test Case for {@link LeastLoadPlacementPolicy}. |
| */ |
| public class TestLeastLoadPlacementPolicy { |
| |
| @Test(timeout = 10000) |
| public void testCalculateBalances() throws Exception { |
| int numSevers = new Random().nextInt(20) + 1; |
| int numStreams = new Random().nextInt(200) + 1; |
| RoutingService mockRoutingService = mock(RoutingService.class); |
| DistributedLogNamespace mockNamespace = mock(DistributedLogNamespace.class); |
| LeastLoadPlacementPolicy leastLoadPlacementPolicy = new LeastLoadPlacementPolicy( |
| new EqualLoadAppraiser(), |
| mockRoutingService, |
| mockNamespace, |
| null, |
| Duration.fromSeconds(600), |
| new NullStatsLogger()); |
| TreeSet<ServerLoad> serverLoads = |
| Await.result(leastLoadPlacementPolicy.calculate(generateServers(numSevers), generateStreams(numStreams))); |
| long lowLoadPerServer = numStreams / numSevers; |
| long highLoadPerServer = lowLoadPerServer + 1; |
| for (ServerLoad serverLoad : serverLoads) { |
| long load = serverLoad.getLoad(); |
| assertEquals(load, serverLoad.getStreamLoads().size()); |
| assertTrue(String.format("Load %d is not between %d and %d", |
| load, lowLoadPerServer, highLoadPerServer), load == lowLoadPerServer || load == highLoadPerServer); |
| } |
| } |
| |
| @Test(timeout = 10000) |
| public void testRefreshAndPlaceStream() throws Exception { |
| int numSevers = new Random().nextInt(20) + 1; |
| int numStreams = new Random().nextInt(200) + 1; |
| RoutingService mockRoutingService = mock(RoutingService.class); |
| when(mockRoutingService.getHosts()).thenReturn(generateSocketAddresses(numSevers)); |
| DistributedLogNamespace mockNamespace = mock(DistributedLogNamespace.class); |
| try { |
| when(mockNamespace.getLogs()).thenReturn(generateStreams(numStreams).iterator()); |
| } catch (IOException e) { |
| fail(); |
| } |
| PlacementStateManager mockPlacementStateManager = mock(PlacementStateManager.class); |
| LeastLoadPlacementPolicy leastLoadPlacementPolicy = new LeastLoadPlacementPolicy( |
| new EqualLoadAppraiser(), |
| mockRoutingService, |
| mockNamespace, |
| mockPlacementStateManager, |
| Duration.fromSeconds(600), |
| new NullStatsLogger()); |
| leastLoadPlacementPolicy.refresh(); |
| |
| final ArgumentCaptor<TreeSet> captor = ArgumentCaptor.forClass(TreeSet.class); |
| verify(mockPlacementStateManager).saveOwnership(captor.capture()); |
| TreeSet<ServerLoad> serverLoads = (TreeSet<ServerLoad>) captor.getValue(); |
| ServerLoad next = serverLoads.first(); |
| String serverPlacement = Await.result(leastLoadPlacementPolicy.placeStream("newstream1")); |
| assertEquals(next.getServer(), serverPlacement); |
| } |
| |
| @Test(timeout = 10000) |
| public void testCalculateUnequalWeight() throws Exception { |
| int numSevers = new Random().nextInt(20) + 1; |
| int numStreams = new Random().nextInt(200) + 1; |
| /* use AtomicInteger to have a final object in answer method */ |
| final AtomicInteger maxLoad = new AtomicInteger(Integer.MIN_VALUE); |
| RoutingService mockRoutingService = mock(RoutingService.class); |
| DistributedLogNamespace mockNamespace = mock(DistributedLogNamespace.class); |
| LoadAppraiser mockLoadAppraiser = mock(LoadAppraiser.class); |
| when(mockLoadAppraiser.getStreamLoad(anyString())).then(new Answer<Future<StreamLoad>>() { |
| @Override |
| public Future<StreamLoad> answer(InvocationOnMock invocationOnMock) throws Throwable { |
| int load = new Random().nextInt(100000); |
| if (load > maxLoad.get()) { |
| maxLoad.set(load); |
| } |
| return Future.value(new StreamLoad(invocationOnMock.getArguments()[0].toString(), load)); |
| } |
| }); |
| LeastLoadPlacementPolicy leastLoadPlacementPolicy = new LeastLoadPlacementPolicy( |
| mockLoadAppraiser, |
| mockRoutingService, |
| mockNamespace, |
| null, |
| Duration.fromSeconds(600), |
| new NullStatsLogger()); |
| TreeSet<ServerLoad> serverLoads = |
| Await.result(leastLoadPlacementPolicy.calculate(generateServers(numSevers), generateStreams(numStreams))); |
| long highestLoadSeen = Long.MIN_VALUE; |
| long lowestLoadSeen = Long.MAX_VALUE; |
| for (ServerLoad serverLoad : serverLoads) { |
| long load = serverLoad.getLoad(); |
| if (load < lowestLoadSeen) { |
| lowestLoadSeen = load; |
| } |
| if (load > highestLoadSeen) { |
| highestLoadSeen = load; |
| } |
| } |
| assertTrue("Unexpected placement for " + numStreams + " streams to " |
| + numSevers + " servers : highest load = " + highestLoadSeen |
| + ", lowest load = " + lowestLoadSeen + ", max stream load = " + maxLoad.get(), |
| highestLoadSeen - lowestLoadSeen < maxLoad.get()); |
| } |
| |
| private Set<SocketAddress> generateSocketAddresses(int num) { |
| LinkedHashSet<SocketAddress> socketAddresses = new LinkedHashSet<SocketAddress>(); |
| for (int i = 0; i < num; i++) { |
| socketAddresses.add(new InetSocketAddress(i)); |
| } |
| return socketAddresses; |
| } |
| |
| private Set<String> generateStreams(int num) { |
| LinkedHashSet<String> streams = new LinkedHashSet<String>(); |
| for (int i = 0; i < num; i++) { |
| streams.add("stream_" + i); |
| } |
| return streams; |
| } |
| |
| private Set<String> generateServers(int num) { |
| LinkedHashSet<String> servers = new LinkedHashSet<String>(); |
| for (int i = 0; i < num; i++) { |
| servers.add("server_" + i); |
| } |
| return servers; |
| } |
| } |