blob: 278c9dd9a824669f4323f19f6caf0bb81f6dfdd5 [file] [log] [blame]
/*
* Copyright 2017 HugeGraph Authors
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with this
* work for additional information regarding copyright ownership. The ASF
* licenses this file to You under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.baidu.hugegraph.computer.core.util;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.slf4j.Logger;
import com.baidu.hugegraph.computer.core.common.exception.ComputerException;
import com.baidu.hugegraph.util.Log;
public final class Consumers<V> {
private static final int CPU_CORE_NUM = 4;
public static final int THREADS = 4 + CPU_CORE_NUM / 4;
public static final int QUEUE_WORKER_SIZE = 1000;
public static final long CONSUMER_WAKE_PERIOD = 1;
private static final Logger LOG = Log.logger(Consumers.class);
private final ExecutorService executor;
private final Consumer<V> consumer;
private final Runnable done;
private final int workers;
private final int queueSize;
private final CountDownLatch latch;
private final BlockingQueue<V> queue;
private volatile boolean ending = false;
private volatile Throwable exception = null;
public Consumers(ExecutorService executor, Consumer<V> consumer) {
this(executor, consumer, null);
}
public Consumers(ExecutorService executor,
Consumer<V> consumer, Runnable done) {
this.executor = executor;
this.consumer = consumer;
this.done = done;
int workers = THREADS;
if (this.executor instanceof ThreadPoolExecutor) {
workers = ((ThreadPoolExecutor) this.executor).getCorePoolSize();
}
this.workers = workers;
this.queueSize = QUEUE_WORKER_SIZE * workers;
this.latch = new CountDownLatch(workers);
this.queue = new ArrayBlockingQueue<>(this.queueSize);
}
public void start(String name) {
this.ending = false;
this.exception = null;
if (this.executor == null) {
return;
}
LOG.info("Starting {} workers[{}] with queue size {}...",
this.workers, name, this.queueSize);
for (int i = 0; i < this.workers; i++) {
this.executor.execute(this::runAndDone);
}
}
private void runAndDone() {
try {
this.run();
} catch (Throwable e) {
// Only the first exception of one thread can be stored
this.exception = e;
if (!(e instanceof StopExecution)) {
LOG.error("Error when running task", e);
}
} finally {
this.done();
this.latch.countDown();
}
}
private void run() {
LOG.debug("Start to work...");
while (!this.ending) {
this.consume();
}
assert this.ending;
while (this.consume()) {
}
LOG.debug("Worker finished");
}
private boolean consume() {
V elem;
try {
elem = this.queue.poll(CONSUMER_WAKE_PERIOD, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
// ignore
return true;
}
if (elem == null) {
return false;
}
// do job
this.consumer.accept(elem);
return true;
}
private void done() {
if (this.done == null) {
return;
}
try {
this.done.run();
} catch (Throwable e) {
if (this.exception == null) {
this.exception = e;
} else {
LOG.warn("Error while calling done()", e);
}
}
}
private Throwable throwException() {
assert this.exception != null;
Throwable e = this.exception;
this.exception = null;
return e;
}
public void provide(V v) throws Throwable {
if (this.executor == null) {
assert this.exception == null;
// do job directly if without thread pool
this.consumer.accept(v);
} else if (this.exception != null) {
throw this.throwException();
} else {
try {
this.queue.put(v);
} catch (InterruptedException e) {
LOG.warn("Interrupted while enqueue", e);
}
}
}
public void await() throws Throwable {
this.ending = true;
if (this.executor == null) {
// call done() directly if without thread pool
this.done();
} else {
try {
this.latch.await();
} catch (InterruptedException e) {
String error = "Interrupted while waiting for consumers";
this.exception = new ComputerException(error, e);
LOG.warn(error, e);
}
}
if (this.exception != null) {
throw this.throwException();
}
}
public ExecutorService executor() {
return this.executor;
}
public static RuntimeException wrapException(Throwable e) {
if (e instanceof RuntimeException) {
throw (RuntimeException) e;
}
throw new ComputerException("Error when running task: %s",
ComputerException.rootCause(e).getMessage(),
e);
}
public static class StopExecution extends ComputerException {
private static final long serialVersionUID = -371829356182454517L;
public StopExecution(String message) {
super(message);
}
public StopExecution(String message, Object... args) {
super(message, args);
}
}
}