blob: 60b5a163dfa636a744a9e801fbb2f970ce0731c7 [file] [log] [blame]
/****************************************************************
* Licensed to the Apache Software Foundation (ASF) under one *
* or more contributor license agreements. See the NOTICE file *
* distributed with this work for additional information *
* regarding copyright ownership. The ASF licenses this file *
* to you under the Apache License, Version 2.0 (the *
* "License"); you may not use this file except in compliance *
* with the License. You may obtain a copy of the License at *
* *
* http://www.apache.org/licenses/LICENSE-2.0 *
* *
* Unless required by applicable law or agreed to in writing, *
* software distributed under the License is distributed on an *
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY *
* KIND, either express or implied. See the License for the *
* specific language governing permissions and limitations *
* under the License. *
****************************************************************/
package org.apache.james.task;
import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.apache.james.util.MDCBuilder;
import org.apache.james.util.concurrent.NamedThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.Sets;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;
public class SerialTaskManagerWorker implements TaskManagerWorker {
private static final Logger LOGGER = LoggerFactory.getLogger(SerialTaskManagerWorker.class);
private final ExecutorService taskExecutor;
private final Listener listener;
private final AtomicReference<Tuple2<TaskId, Future<?>>> runningTask;
private final Semaphore semaphore;
private final Set<TaskId> cancelledTasks;
public SerialTaskManagerWorker(Listener listener) {
this.taskExecutor = Executors.newSingleThreadExecutor(NamedThreadFactory.withName("task executor"));
this.listener = listener;
this.cancelledTasks = Sets.newConcurrentHashSet();
this.runningTask = new AtomicReference<>();
this.semaphore = new Semaphore(1);
}
@Override
public Mono<Task.Result> executeTask(TaskWithId taskWithId) {
return Mono
.using(
acquireSemaphore(taskWithId, listener),
executeWithSemaphore(taskWithId, listener),
Semaphore::release);
}
private Callable<Semaphore> acquireSemaphore(TaskWithId taskWithId, Listener listener) {
return () -> {
try {
semaphore.acquire();
return semaphore;
} catch (InterruptedException e) {
listener.cancelled(taskWithId.getId(), taskWithId.getTask().details());
throw e;
}
};
}
private Function<Semaphore, Mono<Task.Result>> executeWithSemaphore(TaskWithId taskWithId, Listener listener) {
return semaphore -> {
if (!cancelledTasks.remove(taskWithId.getId())) {
CompletableFuture<Task.Result> future = CompletableFuture.supplyAsync(() -> runWithMdc(taskWithId, listener), taskExecutor);
runningTask.set(Tuples.of(taskWithId.getId(), future));
return Mono.using(
() -> pollAdditionalInformation(taskWithId).subscribe(),
ignored -> Mono.fromFuture(future)
.doOnError(exception -> handleExecutionError(taskWithId, listener, exception))
.onErrorReturn(Task.Result.PARTIAL),
polling -> polling.dispose());
} else {
listener.cancelled(taskWithId.getId(), taskWithId.getTask().details());
return Mono.empty();
}
};
}
private void handleExecutionError(TaskWithId taskWithId, Listener listener, Throwable exception) {
if (exception instanceof CancellationException) {
listener.cancelled(taskWithId.getId(), taskWithId.getTask().details());
} else {
listener.failed(taskWithId.getId(), taskWithId.getTask().details(), exception);
}
}
private Flux<TaskExecutionDetails.AdditionalInformation> pollAdditionalInformation(TaskWithId taskWithId) {
return Mono.fromCallable(() -> taskWithId.getTask().details())
.delayElement(Duration.ofSeconds(1))
.repeat()
.flatMap(Mono::justOrEmpty)
.doOnNext(information -> listener.updated(taskWithId.getId(), information));
}
private Task.Result runWithMdc(TaskWithId taskWithId, Listener listener) {
return MDCBuilder.withMdc(
MDCBuilder.create()
.addContext(Task.TASK_ID, taskWithId.getId())
.addContext(Task.TASK_TYPE, taskWithId.getTask().type())
.addContext(Task.TASK_DETAILS, taskWithId.getTask().details()),
() -> run(taskWithId, listener));
}
private Task.Result run(TaskWithId taskWithId, Listener listener) {
listener.started(taskWithId.getId());
try {
return taskWithId.getTask()
.run()
.onComplete(result -> listener.completed(taskWithId.getId(), result, taskWithId.getTask().details()))
.onFailure(() -> {
LOGGER.error("Task was partially performed. Check logs for more details. Taskid : " + taskWithId.getId());
listener.failed(taskWithId.getId(), taskWithId.getTask().details());
});
} catch (InterruptedException e) {
listener.cancelled(taskWithId.getId(), taskWithId.getTask().details());
return Task.Result.PARTIAL;
} catch (Exception e) {
LOGGER.error("Error while running task {}", taskWithId.getId(), e);
listener.failed(taskWithId.getId(), taskWithId.getTask().details(), e);
return Task.Result.PARTIAL;
}
}
@Override
public void cancelTask(TaskId taskId) {
cancelledTasks.add(taskId);
Optional.ofNullable(runningTask.get())
.filter(task -> task.getT1().equals(taskId))
.ifPresent(task -> task.getT2().cancel(true));
}
@Override
public void fail(TaskId taskId, Optional<TaskExecutionDetails.AdditionalInformation> additionalInformation, String errorMessage, Throwable reason) {
listener.failed(taskId, additionalInformation, errorMessage, reason);
}
@Override
public void close() throws IOException {
taskExecutor.shutdownNow();
}
}