// blob: efe1acf22963df2c4044d57d289d49b4ab8e943f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.table.remote;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadPoolExecutor;
import org.apache.samza.container.SamzaContainerContext;
import org.apache.samza.container.TaskName;
import org.apache.samza.metrics.Counter;
import org.apache.samza.metrics.MetricsRegistry;
import org.apache.samza.metrics.Timer;
import org.apache.samza.table.Table;
import org.apache.samza.table.TableSpec;
import org.apache.samza.table.retry.RetriableReadFunction;
import org.apache.samza.table.retry.RetriableWriteFunction;
import org.apache.samza.table.retry.TableRetryPolicy;
import org.apache.samza.task.TaskContext;
import org.apache.samza.util.EmbeddedTaggedRateLimiter;
import org.apache.samza.util.RateLimiter;
import org.junit.Assert;
import org.junit.Test;
import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG;
import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
/**
 * Unit tests for {@link RemoteTableDescriptor}.
 *
 * <p>The "serialize" tests build a descriptor and inspect the keys present in the
 * configuration of the {@link TableSpec} it produces. The "deserialize" tests feed that
 * spec to a {@link RemoteTableProvider} and inspect the table the provider returns.
 */
public class TestRemoteTableDescriptor {

  /**
   * Builds a descriptor with mocked read/write functions and checks which rate-limiting
   * keys end up in the generated table spec.
   *
   * @param rateLimiter custom rate limiter to install; when {@code null} the descriptor is
   *                    configured with plain read/write rates (100 and 200) instead
   * @param readCredFn  read credit function handed to the descriptor, may be {@code null}
   * @param writeCredFn write credit function handed to the descriptor, may be {@code null}
   */
  private void doTestSerialize(RateLimiter rateLimiter,
      TableRateLimiter.CreditFunction readCredFn,
      TableRateLimiter.CreditFunction writeCredFn) {
    RemoteTableDescriptor desc = new RemoteTableDescriptor("1");
    desc.withReadFunction(mock(TableReadFunction.class));
    desc.withWriteFunction(mock(TableWriteFunction.class));
    if (rateLimiter == null) {
      desc.withReadRateLimit(100);
      desc.withWriteRateLimit(200);
    } else {
      desc.withRateLimiter(rateLimiter, readCredFn, writeCredFn);
    }

    Map<String, String> config = desc.getTableSpec().getConfig();
    // A rate limiter entry is expected in both configurations.
    Assert.assertTrue(config.containsKey(RemoteTableProvider.RATE_LIMITER));
    // A credit function key must be present exactly when the function was supplied.
    Assert.assertEquals(readCredFn != null, config.containsKey(RemoteTableProvider.READ_CREDIT_FN));
    Assert.assertEquals(writeCredFn != null, config.containsKey(RemoteTableProvider.WRITE_CREDIT_FN));
  }

  /** Plain read/write rates, no custom limiter and no credit functions. */
  @Test
  public void testSerializeSimple() {
    doTestSerialize(null, null, null);
  }

  /** Custom limiter without credit functions. */
  @Test
  public void testSerializeWithLimiter() {
    doTestSerialize(mock(RateLimiter.class), null, null);
  }

  /** Custom limiter with a read credit function only. */
  @Test
  public void testSerializeWithLimiterAndReadCredFn() {
    doTestSerialize(mock(RateLimiter.class), (k, v) -> 1, null);
  }

  /** Custom limiter with a write credit function only. */
  @Test
  public void testSerializeWithLimiterAndWriteCredFn() {
    doTestSerialize(mock(RateLimiter.class), null, (k, v) -> 1);
  }

  /** Custom limiter with both read and write credit functions. */
  @Test
  public void testSerializeWithLimiterAndReadWriteCredFns() {
    doTestSerialize(mock(RateLimiter.class), (key, value) -> 1, (key, value) -> 1);
  }

  /** A descriptor without a write function yields a spec with READ_FN but no WRITE_FN. */
  @Test
  public void testSerializeNullWriteFunction() {
    RemoteTableDescriptor desc = new RemoteTableDescriptor("1");
    desc.withReadFunction(mock(TableReadFunction.class));
    TableSpec spec = desc.getTableSpec();
    Assert.assertTrue(spec.getConfig().containsKey(RemoteTableProvider.READ_FN));
    Assert.assertFalse(spec.getConfig().containsKey(RemoteTableProvider.WRITE_FN));
  }

  /** A descriptor without a read function is expected to fail with a NullPointerException. */
  @Test(expected = NullPointerException.class)
  public void testSerializeNullReadFunction() {
    RemoteTableDescriptor desc = new RemoteTableDescriptor("1");
    TableSpec spec = desc.getTableSpec();
    Assert.assertTrue(spec.getConfig().containsKey(RemoteTableProvider.READ_FN));
  }

  /** Configuring both a plain rate and a custom rate limiter is expected to be rejected. */
  @Test(expected = IllegalArgumentException.class)
  public void testSpecifyBothRateAndRateLimiter() {
    RemoteTableDescriptor desc = new RemoteTableDescriptor("1");
    desc.withReadFunction(mock(TableReadFunction.class));
    desc.withReadRateLimit(100);
    desc.withRateLimiter(mock(RateLimiter.class), null, null);
    desc.getTableSpec();
  }

  /**
   * Creates a mocked {@link TaskContext} whose metrics registry hands out mocked timers and
   * counters, and whose container context holds a single task named "MyTask".
   *
   * @return the mocked task context
   */
  private TaskContext createMockTaskContext() {
    MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
    doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString());
    doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString());

    SamzaContainerContext containerCtx = new SamzaContainerContext(
        "1", null, Collections.singleton(new TaskName("MyTask")), null);

    TaskContext taskContext = mock(TaskContext.class);
    doReturn(metricsRegistry).when(taskContext).getMetricsRegistry();
    doReturn(containerCtx).when(taskContext).getSamzaContainerContext();
    return taskContext;
  }

  /**
   * Credit function that always charges one credit and counts how often it was invoked.
   *
   * @param <K> key type
   * @param <V> value type
   */
  static class CountingCreditFunction<K, V> implements TableRateLimiter.CreditFunction<K, V> {
    int numCalls = 0;

    @Override
    public int getCredits(K key, V value) {
      numCalls++;
      return 1;
    }
  }

  /**
   * Builds a descriptor with a retry-enabled read function, a plain write function and the
   * requested rate-limiting setup, then checks the table created by {@link RemoteTableProvider}.
   *
   * @param rateOnly when {@code true}, rate limiting is configured through plain read/write
   *                 rates; otherwise through an {@link EmbeddedTaggedRateLimiter} with credit
   *                 functions
   * @param rlGets   whether reads are rate limited
   * @param rlPuts   whether writes are rate limited
   */
  private void doTestDeserializeReadFunctionAndLimiter(boolean rateOnly, boolean rlGets, boolean rlPuts) {
    RemoteTableDescriptor<String, String> desc = new RemoteTableDescriptor<>("1");

    // Only the read function gets a retry policy; the write function is registered without one.
    TableRetryPolicy retryPolicy = new TableRetryPolicy();
    retryPolicy.withRetryPredicate((ex) -> false);
    desc.withReadFunction(mock(TableReadFunction.class), retryPolicy);
    desc.withWriteFunction(mock(TableWriteFunction.class));
    desc.withAsyncCallbackExecutorPoolSize(10);

    if (rateOnly) {
      if (rlGets) {
        desc.withReadRateLimit(1000);
      }
      if (rlPuts) {
        desc.withWriteRateLimit(2000);
      }
    } else if (rlGets || rlPuts) {
      Map<String, Integer> tagCredits = new HashMap<>();
      if (rlGets) {
        tagCredits.put(RL_READ_TAG, 1000);
      }
      if (rlPuts) {
        tagCredits.put(RL_WRITE_TAG, 2000);
      }
      // The limiter is wrapped in a spy; no invocation counts are verified in this test.
      RateLimiter rateLimiter = spy(new EmbeddedTaggedRateLimiter(tagCredits));
      desc.withRateLimiter(rateLimiter, new CountingCreditFunction<>(), new CountingCreditFunction<>());
    }

    TableSpec spec = desc.getTableSpec();
    RemoteTableProvider provider = new RemoteTableProvider(spec);
    provider.init(mock(SamzaContainerContext.class), createMockTaskContext());
    Table table = provider.getTable();
    Assert.assertTrue(table instanceof RemoteReadWriteTable);
    RemoteReadWriteTable rwTable = (RemoteReadWriteTable) table;

    // A limiter must exist for every operation kind that was rate limited.
    if (rlGets) {
      Assert.assertNotNull(rwTable.readRateLimiter);
    }
    if (rlPuts) {
      Assert.assertNotNull(rwTable.writeRateLimiter);
    }

    ThreadPoolExecutor callbackExecutor = (ThreadPoolExecutor) rwTable.callbackExecutor;
    Assert.assertEquals(10, callbackExecutor.getCorePoolSize());

    // These used to be assertNotNull(<boolean>), which autoboxes to a non-null Boolean and
    // therefore could never fail. Assert the actual expectation: the read function, which was
    // given a retry policy, is retriable, while the write function is not.
    Assert.assertTrue(rwTable.readFn instanceof RetriableReadFunction);
    Assert.assertFalse(rwTable.writeFn instanceof RetriableWriteFunction);
  }

  /** Custom-limiter mode with neither reads nor writes limited, i.e. no limiter installed. */
  @Test
  public void testDeserializeReadFunctionNoRateLimit() {
    doTestDeserializeReadFunctionAndLimiter(false, false, false);
  }

  /** Custom limiter, writes only. */
  @Test
  public void testDeserializeReadFunctionAndLimiterWrite() {
    doTestDeserializeReadFunctionAndLimiter(false, false, true);
  }

  /** Custom limiter, reads only. */
  @Test
  public void testDeserializeReadFunctionAndLimiterRead() {
    doTestDeserializeReadFunctionAndLimiter(false, true, false);
  }

  /** Custom limiter, reads and writes. */
  @Test
  public void testDeserializeReadFunctionAndLimiterReadWrite() {
    doTestDeserializeReadFunctionAndLimiter(false, true, true);
  }

  /** Plain rates, writes only. */
  @Test
  public void testDeserializeReadFunctionAndLimiterRateOnlyWrite() {
    doTestDeserializeReadFunctionAndLimiter(true, false, true);
  }

  /** Plain rates, reads only. */
  @Test
  public void testDeserializeReadFunctionAndLimiterRateOnlyRead() {
    doTestDeserializeReadFunctionAndLimiter(true, true, false);
  }

  /** Plain rates, reads and writes. */
  @Test
  public void testDeserializeReadFunctionAndLimiterRateOnlyReadWrite() {
    doTestDeserializeReadFunctionAndLimiter(true, true, true);
  }
}