| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.dataflow.worker; |
| |
| import static org.apache.beam.sdk.testing.SystemNanoTimeSleeper.sleepMillis; |
| import static org.hamcrest.Matchers.allOf; |
| import static org.hamcrest.Matchers.equalTo; |
| import static org.hamcrest.Matchers.greaterThan; |
| import static org.hamcrest.Matchers.greaterThanOrEqualTo; |
| import static org.hamcrest.Matchers.lessThan; |
| import static org.hamcrest.Matchers.not; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertFalse; |
| |
| import java.io.IOException; |
| import java.util.*; |
| import java.util.concurrent.*; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.function.Consumer; |
| import java.util.function.Function; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationCommitWorkRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.net.HostAndPort; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.Uninterruptibles; |
| import org.joda.time.Instant; |
| import org.junit.rules.ErrorCollector; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** An in-memory Windmill server that offers provided work and data. */ |
| class FakeWindmillServer extends WindmillServerStub { |
| private static final Logger LOG = LoggerFactory.getLogger(FakeWindmillServer.class); |
| |
| private final Queue<Windmill.GetWorkResponse> workToOffer; |
| private final Queue<Function<GetDataRequest, GetDataResponse>> dataToOffer; |
| // Keys are work tokens. |
| private final Map<Long, WorkItemCommitRequest> commitsReceived; |
| private final ArrayList<Windmill.ReportStatsRequest> statsReceived; |
| private final LinkedBlockingQueue<Windmill.Exception> exceptions; |
| private int commitsRequested = 0; |
| private int numGetDataRequests = 0; |
| private final AtomicInteger expectedExceptionCount; |
| private final ErrorCollector errorCollector; |
| private boolean isReady = true; |
| |
| public FakeWindmillServer(ErrorCollector errorCollector) { |
| workToOffer = new ConcurrentLinkedQueue<>(); |
| dataToOffer = new ConcurrentLinkedQueue<>(); |
| commitsReceived = new ConcurrentHashMap<>(); |
| exceptions = new LinkedBlockingQueue<>(); |
| expectedExceptionCount = new AtomicInteger(); |
| this.errorCollector = errorCollector; |
| statsReceived = new ArrayList<>(); |
| } |
| |
| public void addWorkToOffer(Windmill.GetWorkResponse work) { |
| workToOffer.add(work); |
| } |
| |
| public void addDataToOffer(Windmill.GetDataResponse data) { |
| dataToOffer.add((GetDataRequest request) -> data); |
| } |
| |
| public void addDataFnToOffer(Function<GetDataRequest, GetDataResponse> f) { |
| dataToOffer.add(f); |
| } |
| |
| @Override |
| public Windmill.GetWorkResponse getWork(Windmill.GetWorkRequest request) { |
| LOG.debug("getWorkRequest: {}", request.toString()); |
| Windmill.GetWorkResponse response = workToOffer.poll(); |
| if (response == null) { |
| return Windmill.GetWorkResponse.newBuilder().build(); |
| } |
| LOG.debug("getWorkResponse: {}", response.toString()); |
| return response; |
| } |
| |
| private void validateGetDataRequest(Windmill.GetDataRequest request) { |
| for (ComputationGetDataRequest computationRequest : request.getRequestsList()) { |
| for (KeyedGetDataRequest keyRequest : computationRequest.getRequestsList()) { |
| errorCollector.checkThat(keyRequest.hasWorkToken(), equalTo(true)); |
| errorCollector.checkThat( |
| keyRequest.getShardingKey(), allOf(greaterThan(0L), lessThan(Long.MAX_VALUE))); |
| errorCollector.checkThat(keyRequest.getMaxBytes(), greaterThanOrEqualTo(0L)); |
| } |
| } |
| } |
| |
| @Override |
| public Windmill.GetDataResponse getData(Windmill.GetDataRequest request) { |
| LOG.info("getDataRequest: {}", request.toString()); |
| validateGetDataRequest(request); |
| ++numGetDataRequests; |
| GetDataResponse response; |
| Function<GetDataRequest, GetDataResponse> responseFn = dataToOffer.poll(); |
| if (responseFn == null) { |
| response = Windmill.GetDataResponse.newBuilder().build(); |
| } else { |
| response = responseFn.apply(request); |
| try { |
| // Sleep for a little bit to ensure that *-windmill-read state-sampled counters |
| // show up. |
| sleepMillis(500); |
| } catch (InterruptedException e) { |
| Thread.currentThread().interrupt(); |
| } |
| } |
| LOG.debug("getDataResponse: {}", response.toString()); |
| return response; |
| } |
| |
| private void validateCommitWorkRequest(Windmill.CommitWorkRequest request) { |
| for (ComputationCommitWorkRequest computationRequest : request.getRequestsList()) { |
| for (WorkItemCommitRequest commit : computationRequest.getRequestsList()) { |
| errorCollector.checkThat(commit.hasWorkToken(), equalTo(true)); |
| errorCollector.checkThat( |
| commit.getShardingKey(), allOf(greaterThan(0L), lessThan(Long.MAX_VALUE))); |
| errorCollector.checkThat(commit.getCacheToken(), not(equalTo(0L))); |
| } |
| } |
| } |
| |
| @Override |
| public CommitWorkResponse commitWork(Windmill.CommitWorkRequest request) { |
| LOG.debug("commitWorkRequest: {}", request); |
| validateCommitWorkRequest(request); |
| for (ComputationCommitWorkRequest computationRequest : request.getRequestsList()) { |
| for (WorkItemCommitRequest commit : computationRequest.getRequestsList()) { |
| commitsReceived.put(commit.getWorkToken(), commit); |
| } |
| } |
| CommitWorkResponse response = CommitWorkResponse.newBuilder().build(); |
| LOG.debug("commitWorkResponse: {}", response); |
| return response; |
| } |
| |
| @Override |
| public Windmill.GetConfigResponse getConfig(Windmill.GetConfigRequest request) { |
| return Windmill.GetConfigResponse.newBuilder().build(); |
| } |
| |
| @Override |
| public Windmill.ReportStatsResponse reportStats(Windmill.ReportStatsRequest request) { |
| for (Windmill.Exception exception : request.getExceptionsList()) { |
| Uninterruptibles.putUninterruptibly(exceptions, exception); |
| } |
| |
| statsReceived.add(request); |
| if (request.getExceptionsList().isEmpty() || expectedExceptionCount.getAndDecrement() > 0) { |
| return Windmill.ReportStatsResponse.newBuilder().build(); |
| } else { |
| return Windmill.ReportStatsResponse.newBuilder().setFailed(true).build(); |
| } |
| } |
| |
| @Override |
| public long getAndResetThrottleTime() { |
| return (long) 0; |
| } |
| |
| @Override |
| public GetWorkStream getWorkStream(Windmill.GetWorkRequest request, WorkItemReceiver receiver) { |
| LOG.debug("getWorkStream: {}", request.toString()); |
| Instant startTime = Instant.now(); |
| final CountDownLatch done = new CountDownLatch(1); |
| return new GetWorkStream() { |
| @Override |
| public void closeAfterDefaultTimeout() { |
| while (done.getCount() > 0) { |
| Windmill.GetWorkResponse response = workToOffer.poll(); |
| if (response == null) { |
| try { |
| sleepMillis(500); |
| } catch (InterruptedException e) { |
| close(); |
| Thread.currentThread().interrupt(); |
| } |
| continue; |
| } |
| for (Windmill.ComputationWorkItems computationWork : response.getWorkList()) { |
| Instant inputDataWatermark = |
| WindmillTimeUtils.windmillToHarnessWatermark( |
| computationWork.getInputDataWatermark()); |
| for (Windmill.WorkItem workItem : computationWork.getWorkList()) { |
| receiver.receiveWork( |
| computationWork.getComputationId(), inputDataWatermark, Instant.now(), workItem); |
| } |
| } |
| } |
| } |
| |
| @Override |
| public void close() { |
| done.countDown(); |
| } |
| |
| @Override |
| public boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException { |
| return done.await(time, unit); |
| } |
| |
| @Override |
| public Instant startTime() { |
| return startTime; |
| } |
| }; |
| } |
| |
| @Override |
| public GetDataStream getDataStream() { |
| Instant startTime = Instant.now(); |
| return new GetDataStream() { |
| @Override |
| public Windmill.KeyedGetDataResponse requestKeyedData( |
| String computation, KeyedGetDataRequest request) { |
| Windmill.GetDataRequest getDataRequest = |
| GetDataRequest.newBuilder() |
| .addRequests( |
| ComputationGetDataRequest.newBuilder() |
| .setComputationId(computation) |
| .addRequests(request) |
| .build()) |
| .build(); |
| GetDataResponse getDataResponse = getData(getDataRequest); |
| if (getDataResponse.getDataList().isEmpty()) { |
| return null; |
| } |
| assertEquals(1, getDataResponse.getDataCount()); |
| if (getDataResponse.getData(0).getDataList().isEmpty()) { |
| return null; |
| } |
| assertEquals(1, getDataResponse.getData(0).getDataCount()); |
| return getDataResponse.getData(0).getData(0); |
| } |
| |
| @Override |
| public Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request) { |
| Windmill.GetDataRequest getDataRequest = |
| GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build(); |
| GetDataResponse getDataResponse = getData(getDataRequest); |
| if (getDataResponse.getGlobalDataList().isEmpty()) { |
| return null; |
| } |
| assertEquals(1, getDataResponse.getGlobalDataCount()); |
| return getDataResponse.getGlobalData(0); |
| } |
| |
| @Override |
| public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {} |
| |
| @Override |
| public void close() {} |
| |
| @Override |
| public boolean awaitTermination(int time, TimeUnit unit) { |
| return true; |
| } |
| |
| @Override |
| public void closeAfterDefaultTimeout() {} |
| |
| @Override |
| public Instant startTime() { |
| return startTime; |
| } |
| }; |
| } |
| |
| @Override |
| public CommitWorkStream commitWorkStream() { |
| Instant startTime = Instant.now(); |
| return new CommitWorkStream() { |
| @Override |
| public boolean commitWorkItem( |
| String computation, |
| WorkItemCommitRequest request, |
| Consumer<Windmill.CommitStatus> onDone) { |
| LOG.debug("commitWorkStream::commitWorkItem: {}", request); |
| errorCollector.checkThat(request.hasWorkToken(), equalTo(true)); |
| errorCollector.checkThat( |
| request.getShardingKey(), allOf(greaterThan(0L), lessThan(Long.MAX_VALUE))); |
| errorCollector.checkThat(request.getCacheToken(), not(equalTo(0L))); |
| commitsReceived.put(request.getWorkToken(), request); |
| onDone.accept(Windmill.CommitStatus.OK); |
| return true; // The request was accepted. |
| } |
| |
| @Override |
| public void flush() {} |
| |
| @Override |
| public void close() {} |
| |
| @Override |
| public boolean awaitTermination(int time, TimeUnit unit) { |
| return true; |
| } |
| |
| @Override |
| public void closeAfterDefaultTimeout() {} |
| |
| @Override |
| public Instant startTime() { |
| return startTime; |
| } |
| }; |
| } |
| |
| public void waitForEmptyWorkQueue() { |
| while (!workToOffer.isEmpty()) { |
| Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS); |
| } |
| } |
| |
| public Map<Long, WorkItemCommitRequest> waitForAndGetCommits(int numCommits) { |
| LOG.debug("waitForAndGetCommitsRequest: {}", numCommits); |
| int maxTries = 10; |
| while (maxTries-- > 0 && commitsReceived.size() < commitsRequested + numCommits) { |
| Uninterruptibles.sleepUninterruptibly(1000, TimeUnit.MILLISECONDS); |
| } |
| |
| assertFalse( |
| "Should have received " |
| + numCommits |
| + " more commits beyond " |
| + commitsRequested |
| + " commits already seen, but after 10s have only seen " |
| + commitsReceived |
| + ". Exceptions seen: " |
| + exceptions, |
| commitsReceived.size() < commitsRequested + numCommits); |
| commitsRequested += numCommits; |
| |
| LOG.debug("waitForAndGetCommitsResponse: {}", commitsReceived); |
| return commitsReceived; |
| } |
| |
| public void setExpectedExceptionCount(int i) { |
| expectedExceptionCount.getAndAdd(i); |
| } |
| |
| public Windmill.Exception getException() throws InterruptedException { |
| return exceptions.take(); |
| } |
| |
| public int numGetDataRequests() { |
| return numGetDataRequests; |
| } |
| |
| public ArrayList<Windmill.ReportStatsRequest> getStatsReceived() { |
| return statsReceived; |
| } |
| |
| @Override |
| public void setWindmillServiceEndpoints(Set<HostAndPort> endpoints) throws IOException { |
| isReady = true; |
| } |
| |
| @Override |
| public boolean isReady() { |
| return isReady; |
| } |
| |
| public void setIsReady(boolean ready) { |
| this.isReady = ready; |
| } |
| } |