| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.seatunnel.translation.source; |
| |
| import org.apache.seatunnel.api.serialization.Serializer; |
| import org.apache.seatunnel.api.source.Collector; |
| import org.apache.seatunnel.api.source.SeaTunnelSource; |
| import org.apache.seatunnel.api.source.SourceEvent; |
| import org.apache.seatunnel.api.source.SourceReader; |
| import org.apache.seatunnel.api.source.SourceSplit; |
| import org.apache.seatunnel.api.source.SourceSplitEnumerator; |
| import org.apache.seatunnel.translation.util.ThreadPoolExecutorFactory; |
| |
| import java.io.IOException; |
| import java.io.Serializable; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.ScheduledThreadPoolExecutor; |
| import java.util.concurrent.atomic.AtomicBoolean; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.stream.Collectors; |
| |
| public class CoordinatedSource<T, SplitT extends SourceSplit, StateT extends Serializable> implements BaseSourceFunction<T> { |
| protected static final long SLEEP_TIME_INTERVAL = 5L; |
| protected final SeaTunnelSource<T, SplitT, StateT> source; |
| protected final Map<Integer, List<byte[]>> restoredState; |
| protected final Integer parallelism; |
| |
| protected final Serializer<SplitT> splitSerializer; |
| protected final Serializer<StateT> enumeratorStateSerializer; |
| |
| protected final CoordinatedEnumeratorContext<SplitT> coordinatedEnumeratorContext; |
| protected final Map<Integer, CoordinatedReaderContext> readerContextMap; |
| protected final Map<Integer, List<SplitT>> restoredSplitStateMap = new HashMap<>(); |
| |
| protected transient volatile SourceSplitEnumerator<SplitT, StateT> splitEnumerator; |
| protected transient Map<Integer, SourceReader<T, SplitT>> readerMap = new ConcurrentHashMap<>(); |
| protected final Map<Integer, AtomicBoolean> readerRunningMap; |
| protected final AtomicInteger completedReader = new AtomicInteger(0); |
| protected transient volatile ScheduledThreadPoolExecutor executorService; |
| |
| /** |
| * Flag indicating whether the consumer is still running. |
| */ |
| protected volatile boolean running = true; |
| |
| public CoordinatedSource(SeaTunnelSource<T, SplitT, StateT> source, |
| Map<Integer, List<byte[]>> restoredState, |
| int parallelism) { |
| this.source = source; |
| this.restoredState = restoredState; |
| this.parallelism = parallelism; |
| this.splitSerializer = source.getSplitSerializer(); |
| this.enumeratorStateSerializer = source.getEnumeratorStateSerializer(); |
| |
| this.coordinatedEnumeratorContext = new CoordinatedEnumeratorContext<>(this); |
| this.readerContextMap = new ConcurrentHashMap<>(parallelism); |
| this.readerRunningMap = new ConcurrentHashMap<>(parallelism); |
| try { |
| createSplitEnumerator(); |
| createReaders(); |
| } catch (Exception e) { |
| e.printStackTrace(); |
| } |
| } |
| |
| private void createSplitEnumerator() throws Exception { |
| if (restoredState != null && restoredState.size() > 0) { |
| StateT restoredEnumeratorState = enumeratorStateSerializer.deserialize(restoredState.get(-1).get(0)); |
| splitEnumerator = source.restoreEnumerator(coordinatedEnumeratorContext, restoredEnumeratorState); |
| restoredState.forEach((subtaskId, splitBytes) -> { |
| if (subtaskId == -1) { |
| return; |
| } |
| List<SplitT> restoredSplitState = new ArrayList<>(splitBytes.size()); |
| for (byte[] splitByte : splitBytes) { |
| try { |
| restoredSplitState.add(splitSerializer.deserialize(splitByte)); |
| } catch (IOException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| restoredSplitStateMap.put(subtaskId, restoredSplitState); |
| }); |
| } else { |
| splitEnumerator = source.createEnumerator(coordinatedEnumeratorContext); |
| } |
| } |
| |
| private void createReaders() throws Exception { |
| for (int subtaskId = 0; subtaskId < this.parallelism; subtaskId++) { |
| CoordinatedReaderContext readerContext = new CoordinatedReaderContext(this, source.getBoundedness(), subtaskId); |
| readerContextMap.put(subtaskId, readerContext); |
| readerRunningMap.put(subtaskId, new AtomicBoolean(true)); |
| SourceReader<T, SplitT> reader = source.createReader(readerContext); |
| readerMap.put(subtaskId, reader); |
| } |
| } |
| |
| @Override |
| public void open() throws Exception { |
| executorService = ThreadPoolExecutorFactory.createScheduledThreadPoolExecutor(parallelism, "parallel-split-enumerator-executor"); |
| splitEnumerator.open(); |
| restoredSplitStateMap.forEach((subtaskId, splits) -> { |
| splitEnumerator.addSplitsBack(splits, subtaskId); |
| }); |
| readerMap.entrySet().parallelStream().forEach(entry -> { |
| try { |
| entry.getValue().open(); |
| splitEnumerator.registerReader(entry.getKey()); |
| } catch (Exception e) { |
| throw new RuntimeException(e); |
| } |
| }); |
| } |
| |
| @Override |
| public void run(Collector<T> collector) throws Exception { |
| readerMap.entrySet().parallelStream().forEach(entry -> { |
| final AtomicBoolean flag = readerRunningMap.get(entry.getKey()); |
| final SourceReader<T, SplitT> reader = entry.getValue(); |
| executorService.execute(() -> { |
| while (flag.get()) { |
| try { |
| reader.pollNext(collector); |
| Thread.sleep(SLEEP_TIME_INTERVAL); |
| } catch (Exception e) { |
| running = false; |
| flag.set(false); |
| throw new RuntimeException(e); |
| } |
| } |
| }); |
| }); |
| splitEnumerator.run(); |
| while (running) { |
| Thread.sleep(SLEEP_TIME_INTERVAL); |
| } |
| } |
| |
| @Override |
| public void close() throws IOException { |
| running = false; |
| |
| for (Map.Entry<Integer, SourceReader<T, SplitT>> entry : readerMap.entrySet()) { |
| readerRunningMap.get(entry.getKey()).set(false); |
| entry.getValue().close(); |
| } |
| |
| if (executorService != null) { |
| executorService.shutdown(); |
| } |
| |
| try (SourceSplitEnumerator<SplitT, StateT> closed = splitEnumerator) { |
| // just close the resources |
| } |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| // Checkpoint & state |
| // -------------------------------------------------------------------------------------------- |
| |
| @Override |
| public Map<Integer, List<byte[]>> snapshotState(long checkpointId) throws Exception { |
| StateT enumeratorState = splitEnumerator.snapshotState(checkpointId); |
| byte[] enumeratorStateBytes = enumeratorStateSerializer.serialize(enumeratorState); |
| Map<Integer, List<byte[]>> allStates = readerMap.entrySet() |
| .parallelStream() |
| .collect(Collectors.toMap( |
| Map.Entry<Integer, SourceReader<T, SplitT>>::getKey, |
| readerEntry -> { |
| try { |
| List<SplitT> splitStates = readerEntry.getValue().snapshotState(checkpointId); |
| final List<byte[]> rawValues = new ArrayList<>(splitStates.size()); |
| for (SplitT splitState : splitStates) { |
| rawValues.add(splitSerializer.serialize(splitState)); |
| } |
| return rawValues; |
| } catch (Exception e) { |
| throw new RuntimeException(e); |
| } |
| })); |
| allStates.put(-1, Collections.singletonList(enumeratorStateBytes)); |
| return allStates; |
| } |
| |
| @Override |
| public void notifyCheckpointComplete(long checkpointId) throws Exception { |
| splitEnumerator.notifyCheckpointComplete(checkpointId); |
| readerMap.values().parallelStream().forEach(reader -> { |
| try { |
| reader.notifyCheckpointComplete(checkpointId); |
| } catch (Exception e) { |
| throw new RuntimeException(e); |
| } |
| }); |
| } |
| |
| @Override |
| public void notifyCheckpointAborted(long checkpointId) throws Exception { |
| splitEnumerator.notifyCheckpointAborted(checkpointId); |
| readerMap.values().parallelStream().forEach(reader -> { |
| try { |
| reader.notifyCheckpointAborted(checkpointId); |
| } catch (Exception e) { |
| throw new RuntimeException(e); |
| } |
| }); |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| // Reader context methods |
| // -------------------------------------------------------------------------------------------- |
| |
| protected void handleNoMoreElement(int subtaskId) { |
| readerRunningMap.get(subtaskId).set(false); |
| readerContextMap.remove(subtaskId); |
| if (completedReader.incrementAndGet() == this.parallelism) { |
| this.running = false; |
| } |
| } |
| |
| protected void handleSplitRequest(int subtaskId) { |
| splitEnumerator.handleSplitRequest(subtaskId); |
| } |
| |
| protected void handleReaderEvent(int subtaskId, SourceEvent event) { |
| splitEnumerator.handleSourceEvent(subtaskId, event); |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| // Enumerator context methods |
| // -------------------------------------------------------------------------------------------- |
| |
| public int currentReaderCount() { |
| return readerContextMap.size(); |
| } |
| |
| public Set<Integer> registeredReaders() { |
| return readerMap.keySet(); |
| } |
| |
| protected void addSplits(int subtaskId, List<SplitT> splits) { |
| readerMap.get(subtaskId).addSplits(splits); |
| } |
| |
| protected void handleNoMoreSplits(int subtaskId) { |
| readerMap.get(subtaskId).handleNoMoreSplits(); |
| } |
| |
| protected void handleEnumeratorEvent(int subtaskId, SourceEvent event) { |
| readerMap.get(subtaskId).handleSourceEvent(event); |
| } |
| } |