blob: 23997fc349f0f13618a7c07a69e27510e899675f [file] [log] [blame]
/*
* Copyright 2017 HugeGraph Authors
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with this
* work for additional information regarding copyright ownership. The ASF
* licenses this file to You under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.baidu.hugegraph.computer.core.network.session;
import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import org.slf4j.Logger;
import com.baidu.hugegraph.computer.core.common.exception.ComputerException;
import com.baidu.hugegraph.computer.core.common.exception.TransportException;
import com.baidu.hugegraph.computer.core.network.TransportConf;
import com.baidu.hugegraph.computer.core.network.TransportState;
import com.baidu.hugegraph.computer.core.network.buffer.NetworkBuffer;
import com.baidu.hugegraph.computer.core.network.buffer.NioBuffer;
import com.baidu.hugegraph.computer.core.network.message.AbstractMessage;
import com.baidu.hugegraph.computer.core.network.message.DataMessage;
import com.baidu.hugegraph.computer.core.network.message.FinishMessage;
import com.baidu.hugegraph.computer.core.network.message.Message;
import com.baidu.hugegraph.computer.core.network.message.MessageType;
import com.baidu.hugegraph.computer.core.network.message.StartMessage;
import com.baidu.hugegraph.util.E;
import com.baidu.hugegraph.util.Log;
/**
 * Client-side transport session.
 *
 * <p>It sends the start, data and finish messages through the
 * {@code sendFunction} supplied at construction, reacts to the acks passed
 * to {@link #onRecvAck(int)}, and maintains the {@link #flowBlocking()} flag
 * from the number of requests that have been sent but not yet acked.
 *
 * <p>State changes performed directly by this class:
 * {@code READY -> START_SENT} in {@link #startAsync()} and
 * {@code ESTABLISHED -> FINISH_SENT} in {@link #finishAsync()}. A start-ack
 * calls {@code stateEstablished()} and a finish-ack calls
 * {@link #stateReady()}.
 *
 * <p>The {@code start*}, {@code finish*} and {@code sendAsync} methods are
 * {@code synchronized}; {@link #onRecvAck(int)} is not.
 */
public class ClientSession extends TransportSession {

    private static final Logger LOG = Log.logger(ClientSession.class);

    /** Pending-request count at or above which flow blocking is turned on. */
    private final int maxPendingRequests;
    /** Pending-request count below which flow blocking is turned off. */
    private final int minPendingRequests;

    /** Serializes the recomputation done in {@link #updateFlowBlocking()}. */
    private final Lock lock;
    private volatile boolean flowBlocking;

    /** Future of the start exchange in progress, or null when there is none. */
    private final AtomicReference<CompletableFuture<Void>> startedFutureRef;
    /** Future of the finish exchange in progress, or null when there is none. */
    private final AtomicReference<CompletableFuture<Void>> finishedFutureRef;

    /**
     * Hands a message over for sending. The returned future is not inspected
     * by this class; only exceptions thrown by the call itself are handled.
     */
    private final Function<Message, Future<Void>> sendFunction;

    /**
     * @param conf         transport configuration; the max/min pending
     *                     request thresholds are read from it
     * @param sendFunction function used to send every message of this session
     */
    public ClientSession(TransportConf conf,
                         Function<Message, Future<Void>> sendFunction) {
        super(conf);
        this.maxPendingRequests = this.conf.maxPendingRequests();
        this.minPendingRequests = this.conf.minPendingRequests();
        this.lock = new ReentrantLock();
        this.flowBlocking = false;
        this.startedFutureRef = new AtomicReference<>();
        this.finishedFutureRef = new AtomicReference<>();
        this.sendFunction = sendFunction;
    }

    /**
     * Clears the flow-blocking flag and then delegates to the superclass.
     */
    @Override
    protected void stateReady() {
        this.flowBlocking = false;
        super.stateReady();
    }

    /** Marks the session as having sent the start message. */
    private void stateStartSent() {
        this.maxRequestId = AbstractMessage.START_SEQ;
        this.state = TransportState.START_SENT;
    }

    /**
     * Marks the session as having sent the finish message.
     *
     * @param finishId id carried by the finish message; the matching ack id
     *                 is recognised in {@link #onRecvAck(int)}
     */
    private void stateFinishSent(int finishId) {
        this.finishId = finishId;
        this.state = TransportState.FINISH_SENT;
    }

    /**
     * Sends the start message and blocks until the start-ack arrives.
     *
     * <p>On any failure the session is put back via {@link #stateReady()}.
     * If the wait was interrupted, the thread's interrupt status is restored
     * before the {@link TransportException} is thrown.
     *
     * @param timeout maximum time to wait, in milliseconds
     * @throws TransportException if the wait times out, is interrupted or
     *                            fails for any other reason
     */
    public synchronized void start(long timeout) throws TransportException {
        CompletableFuture<Void> startFuture = this.startAsync();
        try {
            startFuture.get(timeout, TimeUnit.MILLISECONDS);
        } catch (Throwable e) {
            this.stateReady();
            restoreInterruptIfNeeded(e);
            if (e instanceof TimeoutException) {
                throw new TransportException(
                      "Timeout(%sms) to wait start-response", timeout);
            } else {
                throw new TransportException("Failed to wait start-response",
                                             e);
            }
        } finally {
            // No-op when the future already completed; otherwise abandons it
            startFuture.cancel(false);
            this.startedFutureRef.compareAndSet(startFuture, null);
        }
    }

    /**
     * Sends the start message without waiting for the ack.
     *
     * @return a future that is completed by the start-ack handling in
     *         {@link #onRecvAck(int)}
     * @throws IllegalArgumentException presumably raised by the
     *         {@code E.checkArgument} checks when the state is not READY or
     *         another start is still pending (exception type not visible
     *         here - confirm against {@code E})
     */
    public synchronized CompletableFuture<Void> startAsync() {
        E.checkArgument(this.state == TransportState.READY,
                        "The state must be READY instead of %s " +
                        "at startAsync()", this.state);

        CompletableFuture<Void> startedFuture = new CompletableFuture<>();
        boolean success = this.startedFutureRef.compareAndSet(null,
                                                              startedFuture);
        E.checkArgument(success, "The startedFutureRef value must be null " +
                                 "at startAsync()");

        this.stateStartSent();
        try {
            this.sendFunction.apply(StartMessage.INSTANCE);
        } catch (Throwable e) {
            // Undo everything done above so a later start can be attempted
            this.stateReady();
            startedFuture.cancel(false);
            this.startedFutureRef.compareAndSet(startedFuture, null);
            throw e;
        }
        return startedFuture;
    }

    /**
     * Sends the finish message and blocks until the finish-ack arrives.
     *
     * <p>On any failure {@code stateEstablished()} is called. If the wait
     * was interrupted, the thread's interrupt status is restored before the
     * {@link TransportException} is thrown.
     *
     * @param timeout maximum time to wait, in milliseconds
     * @throws TransportException if the wait times out, is interrupted or
     *                            fails for any other reason
     */
    public synchronized void finish(long timeout) throws TransportException {
        CompletableFuture<Void> finishFuture = this.finishAsync();
        try {
            finishFuture.get(timeout, TimeUnit.MILLISECONDS);
        } catch (Throwable e) {
            this.stateEstablished();
            restoreInterruptIfNeeded(e);
            if (e instanceof TimeoutException) {
                throw new TransportException(
                      "Timeout(%sms) to wait finish-response", timeout);
            } else {
                throw new TransportException("Failed to wait finish-response",
                                             e);
            }
        } finally {
            // Same cleanup as start(): no-op when already completed,
            // otherwise the abandoned future is cancelled
            finishFuture.cancel(false);
            this.finishedFutureRef.compareAndSet(finishFuture, null);
        }
    }

    /**
     * Sends the finish message without waiting for the ack.
     *
     * @return a future that is completed by the finish-ack handling in
     *         {@link #onRecvAck(int)}
     * @throws IllegalArgumentException presumably raised by the
     *         {@code E.checkArgument} checks when the state is not
     *         ESTABLISHED or another finish is still pending (exception type
     *         not visible here - confirm against {@code E})
     */
    public synchronized CompletableFuture<Void> finishAsync() {
        E.checkArgument(this.state == TransportState.ESTABLISHED,
                        "The state must be ESTABLISHED instead of %s " +
                        "at finishAsync()", this.state);

        CompletableFuture<Void> finishedFuture = new CompletableFuture<>();
        boolean success = this.finishedFutureRef.compareAndSet(null,
                                                               finishedFuture);
        E.checkArgument(success, "The finishedFutureRef value must be null " +
                                 "at finishAsync()");

        int finishId = this.genFinishId();
        this.stateFinishSent(finishId);
        try {
            FinishMessage finishMessage = new FinishMessage(finishId);
            this.sendFunction.apply(finishMessage);
        } catch (Throwable e) {
            // Undo everything done above so a later finish can be attempted
            this.stateEstablished();
            finishedFuture.cancel(false);
            this.finishedFutureRef.compareAndSet(finishedFuture, null);
            throw e;
        }
        return finishedFuture;
    }

    /**
     * Wraps {@code buffer} in a data message with the next request id, sends
     * it, and re-evaluates the flow-blocking flag.
     *
     * @param messageType type of the data message
     * @param partition   partition the data belongs to
     * @param buffer      payload; wrapped as-is in a {@link NioBuffer}
     */
    public synchronized void sendAsync(MessageType messageType, int partition,
                                       ByteBuffer buffer) {
        E.checkArgument(this.state == TransportState.ESTABLISHED,
                        "The state must be ESTABLISHED instead of %s " +
                        "at sendAsync()", this.state);
        int requestId = this.nextRequestId();

        NetworkBuffer networkBuffer = new NioBuffer(buffer);
        DataMessage dataMessage = new DataMessage(messageType, requestId,
                                                  partition, networkBuffer);

        this.sendFunction.apply(dataMessage);

        this.updateFlowBlocking();
    }

    /**
     * Dispatches a received ack according to the current state.
     *
     * @param ackId id carried by the ack message
     * @throws ComputerException if the current state is none of START_SENT,
     *                           FINISH_SENT or ESTABLISHED
     */
    public void onRecvAck(int ackId) {
        switch (this.state) {
            case START_SENT:
                if (ackId == AbstractMessage.START_SEQ) {
                    this.onRecvStartAck();
                    break;
                }
                // No break: any other ack id received in START_SENT is
                // handled by the FINISH_SENT branch below.
                // NOTE(review): confirm this fall-through is intended.
                // fall through
            case FINISH_SENT:
                if (ackId == this.finishId) {
                    this.onRecvFinishAck();
                } else {
                    this.onRecvDataAck(ackId);
                }
                break;
            case ESTABLISHED:
                this.onRecvDataAck(ackId);
                break;
            default:
                throw new ComputerException("Receive one ack message, but " +
                                            "the state not match, state: %s, " +
                                            "ackId: %s", this.state, ackId);
        }
    }

    /**
     * Handles the start-ack: records it as the max ack id, calls
     * {@code stateEstablished()} and completes the pending start future.
     */
    private void onRecvStartAck() {
        E.checkArgument(this.state == TransportState.START_SENT,
                        "The state must be START_SENT instead of %s " +
                        "at completeStateStart()", this.state);

        this.maxAckId = AbstractMessage.START_SEQ;

        this.stateEstablished();

        completeAndClear(this.startedFutureRef,
                         "The startedFuture can't be completed");
    }

    /**
     * Handles the finish-ack: calls {@link #stateReady()} and completes the
     * pending finish future.
     */
    private void onRecvFinishAck() {
        E.checkArgument(this.state == TransportState.FINISH_SENT,
                        "The state must be FINISH_SENT instead of %s " +
                        "at completeStateFinish()", this.state);

        this.stateReady();

        completeAndClear(this.finishedFutureRef,
                         "The finishedFuture can't be completed");
    }

    /**
     * Handles a data ack: raises {@code maxAckId} if {@code ackId} is larger
     * and re-evaluates the flow-blocking flag.
     */
    private void onRecvDataAck(int ackId) {
        if (ackId > this.maxAckId) {
            this.maxAckId = ackId;
        }
        this.updateFlowBlocking();
    }

    /**
     * @return true while the number of un-acked requests has reached the
     *         configured maximum and has not yet dropped below the
     *         configured minimum
     */
    public boolean flowBlocking() {
        return this.flowBlocking;
    }

    /**
     * Recomputes the flow-blocking flag from the number of un-acked requests.
     * The flag turns on at {@code maxPendingRequests} or more, turns off
     * below {@code minPendingRequests}, and is left unchanged in between.
     */
    private void updateFlowBlocking() {
        this.lock.lock();
        try {
            int pendingRequests = this.maxRequestId - this.maxAckId;

            if (pendingRequests >= this.maxPendingRequests) {
                this.flowBlocking = true;
            } else if (pendingRequests < this.minPendingRequests) {
                this.flowBlocking = false;
            }
        } finally {
            this.lock.unlock();
        }
    }

    /**
     * Completes the future held by {@code futureRef}, unless it is absent or
     * already cancelled, and then clears the reference.
     *
     * @param futureRef   holder of the pending future
     * @param warnMessage logged when the future could not be completed
     */
    private static void completeAndClear(
            AtomicReference<CompletableFuture<Void>> futureRef,
            String warnMessage) {
        CompletableFuture<Void> future = futureRef.get();
        if (future == null) {
            return;
        }
        if (!future.isCancelled() && !future.complete(null)) {
            LOG.warn(warnMessage);
        }
        // Only clear the slot if it still holds the future handled here
        futureRef.compareAndSet(future, null);
    }

    /**
     * Re-asserts the current thread's interrupt status when {@code e} is an
     * {@link InterruptedException}, so that converting it into another
     * exception type does not hide the interruption from callers.
     */
    private static void restoreInterruptIfNeeded(Throwable e) {
        if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
    }
}