| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.samza.test.table; |
| |
| import com.google.common.cache.CacheBuilder; |
| |
| import java.io.IOException; |
| import java.io.ObjectInputStream; |
| import java.time.Duration; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.concurrent.CompletableFuture; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.function.Function; |
| import java.util.stream.Collectors; |
| |
| import org.apache.samza.SamzaException; |
| import org.apache.samza.application.StreamApplication; |
| import org.apache.samza.application.descriptors.StreamApplicationDescriptor; |
| import org.apache.samza.config.Config; |
| import org.apache.samza.config.MapConfig; |
| import org.apache.samza.context.Context; |
| import org.apache.samza.context.MockContext; |
| import org.apache.samza.operators.TableImpl; |
| import org.apache.samza.operators.functions.MapFunction; |
| import org.apache.samza.system.descriptors.GenericInputDescriptor; |
| import org.apache.samza.metrics.Counter; |
| import org.apache.samza.metrics.MetricsRegistry; |
| import org.apache.samza.metrics.Timer; |
| import org.apache.samza.operators.KV; |
| import org.apache.samza.table.ReadWriteTable; |
| import org.apache.samza.table.descriptors.TableDescriptor; |
| import org.apache.samza.system.descriptors.DelegatingSystemDescriptor; |
| import org.apache.samza.runtime.LocalApplicationRunner; |
| import org.apache.samza.serializers.NoOpSerde; |
| import org.apache.samza.table.Table; |
| import org.apache.samza.table.descriptors.CachingTableDescriptor; |
| import org.apache.samza.table.descriptors.GuavaCacheTableDescriptor; |
| import org.apache.samza.table.remote.BaseTableFunction; |
| import org.apache.samza.table.remote.NoOpTableReadFunction; |
| import org.apache.samza.table.remote.RemoteTable; |
| import org.apache.samza.table.descriptors.RemoteTableDescriptor; |
| import org.apache.samza.table.remote.TableRateLimiter; |
| import org.apache.samza.table.remote.TableReadFunction; |
| import org.apache.samza.table.remote.TableWriteFunction; |
| import org.apache.samza.test.harness.IntegrationTestHarness; |
| import org.apache.samza.test.util.Base64Serializer; |
| import org.apache.samza.util.RateLimiter; |
| |
| import org.junit.Assert; |
| import org.junit.Test; |
| |
| import static org.apache.samza.test.table.TestTableData.EnrichedPageView; |
| import static org.apache.samza.test.table.TestTableData.PageView; |
| import static org.apache.samza.test.table.TestTableData.Profile; |
| import static org.apache.samza.test.table.TestTableData.generatePageViews; |
| import static org.apache.samza.test.table.TestTableData.generateProfiles; |
| |
| import static org.mockito.Matchers.any; |
| import static org.mockito.Matchers.anyString; |
| import static org.mockito.Mockito.*; |
| |
| |
| public class TestRemoteTableEndToEnd extends IntegrationTestHarness { |
| |
| static Map<String, AtomicInteger> counters = new HashMap<>(); |
| static Map<String, List<EnrichedPageView>> writtenRecords = new HashMap<>(); |
| |
| static class InMemoryProfileReadFunction extends BaseTableFunction |
| implements TableReadFunction<Integer, Profile> { |
| |
| private final String serializedProfiles; |
| private transient Map<Integer, Profile> profileMap; |
| |
| private InMemoryProfileReadFunction(String profiles) { |
| this.serializedProfiles = profiles; |
| } |
| |
| private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { |
| in.defaultReadObject(); |
| Profile[] profiles = Base64Serializer.deserialize(this.serializedProfiles, Profile[].class); |
| this.profileMap = Arrays.stream(profiles).collect(Collectors.toMap(p -> p.getMemberId(), Function.identity())); |
| } |
| |
| @Override |
| public CompletableFuture<Profile> getAsync(Integer key) { |
| return CompletableFuture.completedFuture(profileMap.get(key)); |
| } |
| |
| @Override |
| public CompletableFuture<Profile> getAsync(Integer key, Object ... args) { |
| Profile profile = profileMap.get(key); |
| boolean append = (boolean) args[0]; |
| if (append) { |
| profile = new Profile(profile.memberId, profile.company + "-r"); |
| } |
| return CompletableFuture.completedFuture(profile); |
| } |
| |
| @Override |
| public boolean isRetriable(Throwable exception) { |
| return false; |
| } |
| |
| static InMemoryProfileReadFunction getInMemoryReadFunction(String testName, String serializedProfiles) { |
| return new InMemoryProfileReadFunction(serializedProfiles); |
| } |
| } |
| |
| static class InMemoryEnrichedPageViewWriteFunction extends BaseTableFunction |
| implements TableWriteFunction<Integer, EnrichedPageView> { |
| |
| private String testName; |
| private transient List<EnrichedPageView> records; |
| |
| public InMemoryEnrichedPageViewWriteFunction(String testName) { |
| this.testName = testName; |
| } |
| |
| // Verify serializable functionality |
| private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { |
| in.defaultReadObject(); |
| |
| // Write to the global list for verification |
| records = writtenRecords.get(testName); |
| } |
| |
| @Override |
| public CompletableFuture<Void> putAsync(Integer key, EnrichedPageView record) { |
| System.out.println("==> " + testName + " writing " + record.getPageKey()); |
| records.add(record); |
| return CompletableFuture.completedFuture(null); |
| } |
| |
| @Override |
| public CompletableFuture<Void> putAsync(Integer key, EnrichedPageView record, Object ... args) { |
| boolean append = (boolean) args[0]; |
| if (append) { |
| record = new EnrichedPageView(record.pageKey, record.memberId, record.company + "-w"); |
| } |
| records.add(record); |
| return CompletableFuture.completedFuture(null); |
| } |
| |
| @Override |
| public CompletableFuture<Void> deleteAsync(Integer key) { |
| records.remove(key); |
| return CompletableFuture.completedFuture(null); |
| } |
| |
| @Override |
| public boolean isRetriable(Throwable exception) { |
| return false; |
| } |
| } |
| |
| static class InMemoryCounterReadFunction extends BaseTableFunction |
| implements TableReadFunction { |
| |
| private final String testName; |
| |
| private InMemoryCounterReadFunction(String testName) { |
| this.testName = testName; |
| } |
| |
| @Override |
| public CompletableFuture readAsync(int opId, Object... args) { |
| if (1 == opId) { |
| boolean shouldReturnValue = (boolean) args[0]; |
| AtomicInteger counter = counters.get(testName); |
| Integer result = shouldReturnValue ? counter.get() : null; |
| return CompletableFuture.completedFuture(result); |
| } else { |
| throw new SamzaException("Invalid opId: " + opId); |
| } |
| } |
| |
| @Override |
| public CompletableFuture getAsync(Object key) { |
| throw new SamzaException("Not supported"); |
| } |
| |
| @Override |
| public boolean isRetriable(Throwable exception) { |
| return false; |
| } |
| } |
| |
| static class InMemoryCounterWriteFunction extends BaseTableFunction |
| implements TableWriteFunction { |
| |
| private final String testName; |
| |
| private InMemoryCounterWriteFunction(String testName) { |
| this.testName = testName; |
| } |
| |
| @Override |
| public CompletableFuture<Void> putAsync(Object key, Object record) { |
| throw new SamzaException("Not supported"); |
| } |
| |
| @Override |
| public CompletableFuture<Void> deleteAsync(Object key) { |
| throw new SamzaException("Not supported"); |
| } |
| |
| @Override |
| public CompletableFuture writeAsync(int opId, Object... args) { |
| Integer result; |
| AtomicInteger counter = counters.get(testName); |
| boolean shouldModify = (boolean) args[0]; |
| switch (opId) { |
| case 1: |
| result = shouldModify ? counter.incrementAndGet() : counter.get(); |
| break; |
| case 2: |
| result = shouldModify ? counter.decrementAndGet() : counter.get(); |
| break; |
| default: |
| throw new SamzaException("Invalid opId: " + opId); |
| } |
| return CompletableFuture.completedFuture(result); |
| } |
| |
| @Override |
| public boolean isRetriable(Throwable exception) { |
| return false; |
| } |
| } |
| |
| static private class TestReadWriteMapFunction implements MapFunction<PageView, PageView> { |
| |
| private final String counterTableName; |
| private ReadWriteTable counterTable; |
| |
| private TestReadWriteMapFunction(String counterTableName) { |
| this.counterTableName = counterTableName; |
| } |
| |
| @Override |
| public void init(Context context) { |
| counterTable = context.getTaskContext().getTable(counterTableName); |
| } |
| |
| @Override |
| public PageView apply(PageView pageView) { |
| try { |
| // Counter manipulation |
| badOpId(); |
| Assert.assertNull(getCounterValue(false)); |
| Integer beforeValue = getCounterValue(true); |
| Assert.assertEquals(beforeValue, incCounterValue(false)); |
| Assert.assertEquals(beforeValue, decCounterValue(false)); |
| Assert.assertEquals(Integer.valueOf(beforeValue + 1), incCounterValue(true)); |
| Assert.assertEquals(beforeValue, decCounterValue(true)); |
| Assert.assertEquals(beforeValue, getCounterValue(true)); |
| incCounterValue(true); |
| return pageView; |
| } catch (Exception ex) { |
| throw new SamzaException(ex); |
| } |
| } |
| |
| private Integer getCounterValue(boolean shouldReturn) { |
| return (Integer) counterTable.readAsync(1, shouldReturn).join(); |
| } |
| |
| private Integer incCounterValue(boolean shouldModifdy) { |
| return (Integer) counterTable.writeAsync(1, shouldModifdy).join(); |
| } |
| |
| private Integer decCounterValue(boolean shouldModifdy) { |
| return (Integer) counterTable.writeAsync(2, shouldModifdy).join(); |
| } |
| |
| private void badOpId() { |
| try { |
| counterTable.readAsync(0).join(); |
| Assert.fail("Shouldn't reach here"); |
| } catch (SamzaException ex) { |
| // Expected exception |
| } |
| } |
| } |
| |
| private <K, V> Table<KV<K, V>> getCachingTable(TableDescriptor<K, V, ?> actualTableDesc, boolean defaultCache, |
| StreamApplicationDescriptor appDesc) { |
| String id = actualTableDesc.getTableId(); |
| CachingTableDescriptor<K, V> cachingDesc; |
| if (defaultCache) { |
| cachingDesc = new CachingTableDescriptor<>("caching-table-" + id, actualTableDesc); |
| cachingDesc.withReadTtl(Duration.ofMinutes(5)); |
| cachingDesc.withWriteTtl(Duration.ofMinutes(5)); |
| } else { |
| GuavaCacheTableDescriptor<K, V> guavaTableDesc = new GuavaCacheTableDescriptor<>("guava-table-" + id); |
| guavaTableDesc.withCache(CacheBuilder.newBuilder().expireAfterAccess(5, TimeUnit.MINUTES).build()); |
| cachingDesc = new CachingTableDescriptor<>("caching-table-" + id, actualTableDesc, guavaTableDesc); |
| } |
| |
| return appDesc.getTable(cachingDesc); |
| } |
| |
| private void doTestStreamTableJoinRemoteTable(boolean withCache, boolean defaultCache, boolean withArgs, String testName) |
| throws Exception { |
| |
| writtenRecords.put(testName, new ArrayList<>()); |
| |
| int count = 10; |
| final PageView[] pageViews = generatePageViews(count); |
| final String profiles = Base64Serializer.serialize(generateProfiles(count)); |
| |
| final int partitionCount = 4; |
| final Map<String, String> configs = TestLocalTableEndToEnd.getBaseJobConfig(bootstrapUrl(), zkConnect()); |
| configs.put("streams.PageView.samza.system", "test"); |
| configs.put("streams.PageView.source", Base64Serializer.serialize(pageViews)); |
| configs.put("streams.PageView.partitionCount", String.valueOf(partitionCount)); |
| |
| final RateLimiter readRateLimiter = mock(RateLimiter.class, withSettings().serializable()); |
| final TableRateLimiter.CreditFunction creditFunction = (k, v, args) -> 1; |
| final StreamApplication app = appDesc -> { |
| |
| final RemoteTableDescriptor joinTableDesc = |
| new RemoteTableDescriptor<Integer, TestTableData.Profile>("profile-table-1") |
| .withReadFunction(InMemoryProfileReadFunction.getInMemoryReadFunction(testName, profiles)) |
| .withRateLimiter(readRateLimiter, creditFunction, null); |
| |
| final RemoteTableDescriptor outputTableDesc = |
| new RemoteTableDescriptor<Integer, EnrichedPageView>("enriched-page-view-table-1") |
| .withReadFunction(new NoOpTableReadFunction<>()) |
| .withReadRateLimiterDisabled() |
| .withWriteFunction(new InMemoryEnrichedPageViewWriteFunction(testName)) |
| .withWriteRateLimit(1000); |
| |
| final Table<KV<Integer, EnrichedPageView>> outputTable = withCache |
| ? getCachingTable(outputTableDesc, defaultCache, appDesc) |
| : appDesc.getTable(outputTableDesc); |
| |
| final Table<KV<Integer, Profile>> joinTable = withCache |
| ? getCachingTable(joinTableDesc, defaultCache, appDesc) |
| : appDesc.getTable(joinTableDesc); |
| |
| final DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor("test"); |
| final GenericInputDescriptor<PageView> isd = ksd.getInputDescriptor("PageView", new NoOpSerde<>()); |
| |
| if (!withArgs) { |
| appDesc.getInputStream(isd) |
| .map(pv -> new KV<>(pv.getMemberId(), pv)) |
| .join(joinTable, new PageViewToProfileJoinFunction()) |
| .map(m -> new KV(m.getMemberId(), m)) |
| .sendTo(outputTable); |
| |
| } else { |
| counters.put(testName, new AtomicInteger()); |
| |
| final RemoteTableDescriptor counterTableDesc = |
| new RemoteTableDescriptor("counter-table-1") |
| .withReadFunction(new InMemoryCounterReadFunction(testName)) |
| .withWriteFunction(new InMemoryCounterWriteFunction(testName)) |
| .withRateLimiterDisabled(); |
| |
| final Table counterTable = withCache |
| ? getCachingTable(counterTableDesc, defaultCache, appDesc) |
| : appDesc.getTable(counterTableDesc); |
| |
| final String counterTableName = ((TableImpl) counterTable).getTableId(); |
| appDesc.getInputStream(isd) |
| .map(new TestReadWriteMapFunction(counterTableName)) |
| .map(pv -> new KV<>(pv.getMemberId(), pv)) |
| .join(joinTable, new PageViewToProfileJoinFunction(), true) |
| .map(m -> new KV(m.getMemberId(), m)) |
| .sendTo(outputTable, true); |
| } |
| }; |
| |
| final Config config = new MapConfig(configs); |
| final LocalApplicationRunner runner = new LocalApplicationRunner(app, config); |
| executeRun(runner, config); |
| runner.waitForFinish(); |
| |
| final int numExpected = count * partitionCount; |
| Assert.assertEquals(numExpected, writtenRecords.get(testName).size()); |
| Assert.assertTrue(writtenRecords.get(testName).get(0) instanceof EnrichedPageView); |
| if (!withArgs) { |
| writtenRecords.get(testName).forEach(epv -> Assert.assertFalse(epv.company.contains("-"))); |
| } else { |
| writtenRecords.get(testName).forEach(epv -> Assert.assertTrue(epv.company.endsWith("-r-w"))); |
| Assert.assertEquals(numExpected, counters.get(testName).get()); |
| } |
| } |
| |
| @Test |
| public void testStreamTableJoinRemoteTable() throws Exception { |
| doTestStreamTableJoinRemoteTable(false, false, false, "testStreamTableJoinRemoteTable"); |
| } |
| |
| @Test |
| public void testStreamTableJoinRemoteTableWithCache() throws Exception { |
| doTestStreamTableJoinRemoteTable(true, false, false, "testStreamTableJoinRemoteTableWithCache"); |
| } |
| |
| @Test |
| public void testStreamTableJoinRemoteTableWithDefaultCache() throws Exception { |
| doTestStreamTableJoinRemoteTable(true, true, false, "testStreamTableJoinRemoteTableWithDefaultCache"); |
| } |
| |
| @Test |
| public void testStreamTableJoinRemoteTableWithArgs() throws Exception { |
| doTestStreamTableJoinRemoteTable(false, false, true, "testStreamTableJoinRemoteTableWithArgs"); |
| } |
| |
| @Test |
| public void testStreamTableJoinRemoteTableWithCacheWithArgs() throws Exception { |
| doTestStreamTableJoinRemoteTable(true, false, true, "testStreamTableJoinRemoteTableWithCacheWithArgs"); |
| } |
| |
| @Test |
| public void testStreamTableJoinRemoteTableWithDefaultCacheWithArgs() throws Exception { |
| doTestStreamTableJoinRemoteTable(true, true, true, "testStreamTableJoinRemoteTableWithDefaultCacheWithArgs"); |
| } |
| |
| private Context createMockContext() { |
| MetricsRegistry metricsRegistry = mock(MetricsRegistry.class); |
| doReturn(new Counter("")).when(metricsRegistry).newCounter(anyString(), anyString()); |
| doReturn(new Timer("")).when(metricsRegistry).newTimer(anyString(), anyString()); |
| Context context = new MockContext(); |
| doReturn(new MapConfig()).when(context.getJobContext()).getConfig(); |
| doReturn(metricsRegistry).when(context.getContainerContext()).getContainerMetricsRegistry(); |
| return context; |
| } |
| |
| @Test(expected = SamzaException.class) |
| public void testCatchReaderException() { |
| TableReadFunction<String, ?> reader = mock(TableReadFunction.class); |
| CompletableFuture<String> future = new CompletableFuture<>(); |
| future.completeExceptionally(new RuntimeException("Expected test exception")); |
| doReturn(future).when(reader).getAsync(anyString()); |
| TableRateLimiter rateLimitHelper = mock(TableRateLimiter.class); |
| RemoteTable<String, String> table = new RemoteTable<>("table1", reader, null, |
| rateLimitHelper, null, Executors.newSingleThreadExecutor(), |
| null, null, null, |
| null, null, null); |
| table.init(createMockContext()); |
| table.get("abc"); |
| } |
| |
| @Test(expected = SamzaException.class) |
| public void testCatchWriterException() { |
| TableReadFunction<String, String> reader = mock(TableReadFunction.class); |
| TableWriteFunction<String, String> writer = mock(TableWriteFunction.class); |
| CompletableFuture<String> future = new CompletableFuture<>(); |
| future.completeExceptionally(new RuntimeException("Expected test exception")); |
| doReturn(future).when(writer).putAsync(anyString(), any()); |
| TableRateLimiter rateLimitHelper = mock(TableRateLimiter.class); |
| RemoteTable<String, String> table = new RemoteTable<String, String>("table1", reader, writer, |
| rateLimitHelper, rateLimitHelper, Executors.newSingleThreadExecutor(), |
| null, null, null, |
| null, null, null); |
| table.init(createMockContext()); |
| table.put("abc", "efg"); |
| } |
| |
| @Test |
| public void testUninitializedWriter() { |
| TableReadFunction<String, String> reader = mock(TableReadFunction.class); |
| TableRateLimiter rateLimitHelper = mock(TableRateLimiter.class); |
| RemoteTable<String, String> table = new RemoteTable<String, String>("table1", reader, null, |
| rateLimitHelper, null, Executors.newSingleThreadExecutor(), |
| null, null, null, |
| null, null, null); |
| table.init(createMockContext()); |
| try { |
| table.put("abc", "efg"); |
| Assert.fail(); |
| } catch (SamzaException ex) { |
| // Ignore |
| } |
| try { |
| table.delete("abc"); |
| Assert.fail(); |
| } catch (SamzaException ex) { |
| // Ignore |
| } |
| table.flush(); |
| table.close(); |
| } |
| } |