blob: 2b7cf54a8068c6615e2be63512bf1b91fcf28cf9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.reef.vortex.driver;
import org.apache.reef.driver.task.RunningTask;
import org.apache.reef.io.serialization.Codec;
import org.apache.reef.io.serialization.SerializableCodec;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.exceptions.InjectionException;
import org.apache.reef.vortex.util.VoidCodec;
import org.apache.reef.util.Optional;
import org.apache.reef.vortex.api.VortexFunction;
import org.apache.reef.vortex.api.VortexFuture;
import org.apache.reef.vortex.common.*;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Utility methods for tests.
 *
 * Each instance keeps its own counters for worker and tasklet ids, its own
 * executor for the futures it creates, and its own mocked {@link VortexMaster}.
 */
public final class TestUtil {
  private static final Codec<Void> VOID_CODEC = new VoidCodec();
  private static final Codec<Integer> INTEGER_CODEC = new SerializableCodec<>();
  private final AtomicInteger taskletId = new AtomicInteger(0);
  private final AtomicInteger workerId = new AtomicInteger(0);
  private final Executor executor = Executors.newFixedThreadPool(5);
  private final VortexMaster vortexMaster = mock(VortexMaster.class);

  /**
   * @return a new mocked worker, with a mocked {@link VortexMaster}.
   */
  public VortexWorkerManager newWorker() {
    return newWorker(vortexMaster);
  }

  /**
   * Creates a worker backed by a mocked {@link RunningTask} and a mocked {@link VortexRequestor}.
   * The mocked requestor answers every {@link TaskletCancellationRequest} by immediately reporting
   * a {@link TaskletCancelledReport} for that tasklet to {@code master}; all other requests are ignored.
   *
   * @param master the master that receives the cancellation reports.
   * @return a new mocked worker, with the {@link VortexMaster} passed in.
   */
  public VortexWorkerManager newWorker(final VortexMaster master) {
    final RunningTask reefTask = mock(RunningTask.class);
    when(reefTask.getId()).thenReturn("worker" + workerId.getAndIncrement());
    final VortexRequestor vortexRequestor = mock(VortexRequestor.class);
    final VortexWorkerManager workerManager = new VortexWorkerManager(vortexRequestor, reefTask);

    doAnswer(new Answer<Object>() {
      @Override
      public Object answer(final InvocationOnMock invocation) throws Throwable {
        // sendAsync(RunningTask, VortexRequest): the request is the second argument.
        final VortexRequest request = (VortexRequest) invocation.getArguments()[1];
        if (request instanceof TaskletCancellationRequest) {
          final TaskletReport cancelReport = new TaskletCancelledReport(
              ((TaskletCancellationRequest) request).getTaskletId());
          master.workerReported(workerManager.getId(), new WorkerReport(Collections.singleton(cancelReport)));
        }
        return null;
      }
    }).when(vortexRequestor).sendAsync(any(RunningTask.class), any(VortexRequest.class));

    return workerManager;
  }

  /**
   * Each call consumes the next tasklet id of this instance; the same id is used for
   * the tasklet and for the {@link VortexFuture} attached to it.
   *
   * @return a new dummy tasklet.
   */
  public Tasklet newTasklet() {
    final int id = taskletId.getAndIncrement();
    return new Tasklet(id, Optional.empty(), null, null, new VortexFuture(executor, vortexMaster, id, VOID_CODEC));
  }

  /**
   * @return a new {@link AggregateFunctionRepository}
   * @throws InjectionException if Tang fails to instantiate the repository.
   */
  public AggregateFunctionRepository newAggregateFunctionRepository() throws InjectionException {
    return Tang.Factory.getTang().newInjector().getInstance(AggregateFunctionRepository.class);
  }

  /**
   * @return a new dummy function that ignores its input and returns {@code null}.
   */
  public VortexFunction<Void, Void> newFunction() {
    return new VortexFunction<Void, Void>() {
      @Override
      public Void call(final Void input) throws Exception {
        return null;
      }

      @Override
      public Codec<Void> getInputCodec() {
        return VOID_CODEC;
      }

      @Override
      public Codec<Void> getOutputCodec() {
        return VOID_CODEC;
      }
    };
  }

  /**
   * @return a queryable {@link org.apache.reef.vortex.driver.TestUtil.TestSchedulingPolicy}
   */
  public TestSchedulingPolicy newSchedulingPolicy() {
    return new TestSchedulingPolicy();
  }

  /**
   * @return a new function that never returns normally: it loops, sleeping 100 ms per iteration,
   * until the executing thread is interrupted, at which point an {@link InterruptedException} is thrown.
   */
  public VortexFunction<Void, Void> newInfiniteLoopFunction() {
    return new VortexFunction<Void, Void>() {
      @Override
      public Void call(final Void input) throws Exception {
        while (true) {
          Thread.sleep(100);
          // Also honor an interrupt that arrives between two sleeps.
          if (Thread.currentThread().isInterrupted()) {
            throw new InterruptedException();
          }
        }
      }

      @Override
      public Codec<Void> getInputCodec() {
        return VOID_CODEC;
      }

      @Override
      public Codec<Void> getOutputCodec() {
        return VOID_CODEC;
      }
    };
  }

  /**
   * @return a dummy integer-integer function that returns 1 for every input.
   */
  public VortexFunction<Integer, Integer> newIntegerFunction() {
    return new VortexFunction<Integer, Integer>() {
      @Override
      public Integer call(final Integer input) throws Exception {
        return 1;
      }

      @Override
      public Codec<Integer> getInputCodec() {
        return INTEGER_CODEC;
      }

      @Override
      public Codec<Integer> getOutputCodec() {
        return INTEGER_CODEC;
      }
    };
  }

  /**
   * A {@link SchedulingPolicy} that delegates every call to a {@link RandomSchedulingPolicy}
   * and additionally records the ids of the tasklets reported as done, so tests can query them.
   */
  static final class TestSchedulingPolicy implements SchedulingPolicy {
    private final SchedulingPolicy policy = new RandomSchedulingPolicy();
    private final Set<Integer> doneTasklets = new HashSet<>();

    private TestSchedulingPolicy() {
    }

    @Override
    public Optional<String> trySchedule(final Tasklet tasklet) {
      return policy.trySchedule(tasklet);
    }

    @Override
    public void workerAdded(final VortexWorkerManager vortexWorker) {
      policy.workerAdded(vortexWorker);
    }

    @Override
    public void workerRemoved(final VortexWorkerManager vortexWorker) {
      policy.workerRemoved(vortexWorker);
    }

    @Override
    public void taskletLaunched(final VortexWorkerManager vortexWorker, final Tasklet tasklet) {
      policy.taskletLaunched(vortexWorker, tasklet);
    }

    /**
     * Forwards to the wrapped policy, then remembers the id of every tasklet in {@code tasklets}.
     */
    @Override
    public void taskletsDone(final VortexWorkerManager vortexWorker, final List<Tasklet> tasklets) {
      policy.taskletsDone(vortexWorker, tasklets);
      for (final Tasklet t : tasklets) {
        doneTasklets.add(t.getId());
      }
    }

    /**
     * @param taskletId the id of the tasklet to look up.
     * @return true if Tasklet with taskletId is done, false otherwise.
     */
    public boolean taskletIsDone(final int taskletId) {
      return doneTasklets.contains(taskletId);
    }
  }
}