blob: 73e1f8bafeb4efdc7e4689f4498456ba8c01e2a4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE
* file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file
* to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package org.apache.tuweni.rlpx.wire;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.concurrent.AsyncCompletion;
import org.apache.tuweni.concurrent.CompletableAsyncCompletion;
import org.apache.tuweni.concurrent.CompletableAsyncResult;
import org.apache.tuweni.crypto.SECP256K1;
import org.apache.tuweni.rlpx.RLPxMessage;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import com.google.common.collect.TreeRangeMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A stateful connection between two peers under the Devp2p wire protocol.
 *
 * <p>
 * The connection drives the HELLO handshake, answers PING with PONG, records DISCONNECT messages and routes every other
 * message to the sub-protocol whose negotiated message-id range contains the message id.
 */
public final class DefaultWireConnection implements WireConnection {

  private static final Logger logger = LoggerFactory.getLogger(DefaultWireConnection.class);

  /** Message id of the devp2p HELLO message. */
  private static final int HELLO_MESSAGE_ID = 0;
  /** Message id of the devp2p DISCONNECT message. */
  private static final int DISCONNECT_MESSAGE_ID = 1;
  /** Message id of the devp2p PING message. */
  private static final int PING_MESSAGE_ID = 2;
  /** Message id of the devp2p PONG message. */
  private static final int PONG_MESSAGE_ID = 3;
  /** First message id assigned to negotiated sub-protocols. */
  private static final int FIRST_SUBPROTOCOL_MESSAGE_ID = 16;

  private final Bytes nodeId;
  private final Bytes peerNodeId;
  private final Consumer<RLPxMessage> writer;
  private final Consumer<HelloMessage> afterHandshakeListener;
  private final Runnable disconnectHandler;
  private final LinkedHashMap<SubProtocolIdentifier, SubProtocolHandler> subprotocols;
  private final int p2pVersion;
  private final String clientId;
  private final int advertisedPort;
  private final CompletableAsyncResult<WireConnection> ready;
  private final String peerHost;
  private final int peerPort;
  private final AtomicBoolean disconnectRequested = new AtomicBoolean(false);
  // Maps each negotiated sub-protocol to its contiguous slice of the message-id space.
  private final RangeMap<Integer, SubProtocolIdentifier> subprotocolRangeMap = TreeRangeMap.create();

  private CompletableAsyncCompletion awaitingPong;
  private HelloMessage myHelloMessage;
  private HelloMessage peerHelloMessage;
  private DisconnectReason disconnectReason;
  private boolean disconnectReceived;
  // Only set through registerListener; may still be null when events fire.
  private EventListener eventListener;

  /**
   * Default constructor.
   *
   * @param nodeId the node id of this node
   * @param peerNodeId the node id of the peer
   * @param writer the message writer
   * @param afterHandshakeListener a listener called after the handshake is complete with the peer hello message.
   * @param disconnectHandler the handler to run upon receiving a disconnect message
   * @param subprotocols the subprotocols supported by this connection
   * @param p2pVersion the version of the devp2p protocol supported by this client
   * @param clientId the client ID to announce in HELLO messages
   * @param advertisedPort the port we listen to, to announce in HELLO messages
   * @param ready a handle to complete when the connection is ready for use.
   * @param peerHost the peer's host
   * @param peerPort the peer's port
   */
  public DefaultWireConnection(
      Bytes nodeId,
      Bytes peerNodeId,
      Consumer<RLPxMessage> writer,
      Consumer<HelloMessage> afterHandshakeListener,
      Runnable disconnectHandler,
      LinkedHashMap<SubProtocolIdentifier, SubProtocolHandler> subprotocols,
      int p2pVersion,
      String clientId,
      int advertisedPort,
      CompletableAsyncResult<WireConnection> ready,
      String peerHost,
      int peerPort) {
    this.nodeId = nodeId;
    this.peerNodeId = peerNodeId;
    this.writer = writer;
    this.afterHandshakeListener = afterHandshakeListener;
    this.disconnectHandler = disconnectHandler;
    this.subprotocols = subprotocols;
    this.p2pVersion = p2pVersion;
    this.clientId = clientId;
    this.advertisedPort = advertisedPort;
    this.ready = ready;
    this.peerHost = peerHost;
    this.peerPort = peerPort;
    logger.debug("New wire connection created");
  }

  /**
   * Processes a message received from the peer.
   *
   * <p>
   * HELLO and DISCONNECT are always handled. Any other message that arrives before both HELLO messages have been
   * exchanged is a protocol breach: the connection is disconnected and the message is dropped.
   *
   * @param message the message received from the peer
   */
  public void messageReceived(RLPxMessage message) {
    int messageId = message.messageId();
    if (messageId == HELLO_MESSAGE_ID) {
      handlePeerHello(message);
      return;
    }
    if (messageId == DISCONNECT_MESSAGE_ID) {
      handlePeerDisconnect(message);
      return;
    }
    if (peerHelloMessage == null || myHelloMessage == null) {
      logger.debug("Message sent before hello exchanged {}", messageId);
      disconnect(DisconnectReason.PROTOCOL_BREACH);
      completeReadyIfPending();
      // Do not act on a message from a peer that has not completed the handshake.
      return;
    }
    if (messageId == PING_MESSAGE_ID) {
      sendPong();
    } else if (messageId == PONG_MESSAGE_ID) {
      CompletableAsyncCompletion pong = awaitingPong;
      if (pong != null) {
        pong.complete();
      }
    } else {
      dispatchToSubProtocol(messageId, message.content());
    }
  }

  /**
   * Handles the peer's HELLO.
   *
   * <p>
   * Validates the peer's identity, devp2p version and capabilities, replying with our own HELLO if we have not sent it
   * yet. Once every agreed sub-protocol has accepted the connection, the handle is completed and CONNECTED is fired.
   */
  private void handlePeerHello(RLPxMessage message) {
    peerHelloMessage = HelloMessage.read(message.content());
    logger.debug("Received peer Hello message {}", peerHelloMessage);
    initSupportedRange(peerHelloMessage.capabilities());

    Bytes helloNodeId = peerHelloMessage.nodeId();
    if (helloNodeId == null || helloNodeId.isEmpty()) {
      abortHandshake(DisconnectReason.NULL_NODE_IDENTITY_RECEIVED);
      return;
    }
    if (!helloNodeId.equals(peerNodeId)) {
      abortHandshake(DisconnectReason.UNEXPECTED_IDENTITY);
      return;
    }
    if (helloNodeId.equals(nodeId)) {
      abortHandshake(DisconnectReason.CONNECTED_TO_SELF);
      return;
    }
    if (peerHelloMessage.p2pVersion() > p2pVersion) {
      abortHandshake(DisconnectReason.INCOMPATIBLE_DEVP2P_VERSION);
      return;
    }
    if (subprotocolRangeMap.asMapOfRanges().isEmpty()) {
      logger
          .debug(
              "Useless peer detected, caps {}, our caps {}",
              peerHelloMessage.capabilities(),
              subprotocols.keySet());
      abortHandshake(DisconnectReason.USELESS_PEER);
      return;
    }

    // When the peer initiated the handshake we have not sent our HELLO yet.
    if (myHelloMessage == null) {
      sendHello();
    }
    afterHandshakeListener.accept(peerHelloMessage);
    AsyncCompletion allSubProtocols = AsyncCompletion
        .allOf(
            subprotocolRangeMap
                .asMapOfRanges()
                .values()
                .stream()
                .map(subprotocols::get)
                .map(handler -> handler.handleNewPeerConnection(this)));
    allSubProtocols.thenRun(() -> {
      ready.complete(this);
      fireEvent(Event.CONNECTED);
    });
  }

  /**
   * Records a DISCONNECT sent by the peer.
   *
   * <p>
   * Runs the disconnect handler, hands the connection back through the ready handle if still pending, and fires
   * DISCONNECTED.
   */
  private void handlePeerDisconnect(RLPxMessage message) {
    DisconnectMessage disconnect = DisconnectMessage.read(message.content());
    logger.debug("Received disconnect {}", disconnect);
    disconnectReceived = true;
    disconnectReason = DisconnectReason.valueOf(disconnect.reason());
    disconnectHandler.run();
    completeReadyIfPending(); // Return the connection as is.
    fireEvent(Event.DISCONNECTED);
  }

  /**
   * Routes a message to the sub-protocol owning its message id.
   *
   * <p>
   * The message type is rebased so the handler sees ids relative to its own range. An id outside every negotiated range
   * is a protocol breach.
   */
  private void dispatchToSubProtocol(int messageId, Bytes content) {
    Map.Entry<Range<Integer>, SubProtocolIdentifier> subProtocolEntry = subprotocolRangeMap.getEntry(messageId);
    if (subProtocolEntry == null) {
      logger.debug("Unknown message received {}", messageId);
      disconnect(DisconnectReason.PROTOCOL_BREACH);
      completeReadyIfPending();
      return;
    }
    int offset = subProtocolEntry.getKey().lowerEndpoint();
    int subProtocolMessageType = messageId - offset;
    logger.trace("Received message of type {}", subProtocolMessageType);
    SubProtocolHandler handler = subprotocols.get(subProtocolEntry.getValue());
    handler
        .handle(this, subProtocolMessageType, content)
        .exceptionally(t -> logger.error("Handler {} threw an exception", handler, t));
  }

  /** Rejects the handshake: disconnects with the given reason and hands the connection back through the ready handle. */
  private void abortHandshake(DisconnectReason reason) {
    disconnect(reason);
    ready.complete(this);
  }

  /** Completes the ready handle with this connection unless it has already been completed. */
  private void completeReadyIfPending() {
    if (!ready.isDone()) {
      ready.complete(this);
    }
  }

  /** Notifies the registered listener, if any, of a connection event. */
  private void fireEvent(Event event) {
    EventListener listener = eventListener;
    if (listener != null) {
      listener.onEvent(event);
    }
  }

  /**
   * Computes the message-id ranges of the sub-protocols shared with the peer.
   *
   * <p>
   * For each capability name, the highest version that both sides support is picked. Each picked sub-protocol then
   * gets {@code versionRange()} consecutive ids, starting at 16, in the order the peer listed its capabilities.
   *
   * <p>
   * NOTE(review): the RLPx specification assigns offsets to shared capabilities in alphabetical order; this follows
   * the peer's list order instead — confirm peers always advertise capabilities sorted.
   *
   * @param capabilities the capabilities advertised by the peer in its HELLO message
   */
  void initSupportedRange(List<Capability> capabilities) {
    int startRange = FIRST_SUBPROTOCOL_MESSAGE_ID;
    Map<String, SubProtocolIdentifier> pickedCapabilities = new HashMap<>();
    // find the max capability supported by the subprotocol
    for (SubProtocolIdentifier sp : subprotocols.keySet()) {
      for (Capability cap : capabilities) {
        if (sp.equals(SubProtocolIdentifier.of(cap.name(), cap.version()))) {
          SubProtocolIdentifier oldPick = pickedCapabilities.get(cap.name());
          if (oldPick == null || oldPick.version() < cap.version()) {
            pickedCapabilities.put(cap.name(), sp);
          }
        }
      }
    }
    for (Capability cap : capabilities) {
      SubProtocolIdentifier capSp = SubProtocolIdentifier.of(cap.name(), cap.version());
      if (!Objects.equals(pickedCapabilities.get(cap.name()), capSp)) {
        continue;
      }
      // Store our own identifier instance: it carries versionRange() and is the key into subprotocols.
      for (SubProtocolIdentifier sp : subprotocols.keySet()) {
        if (sp.equals(capSp)) {
          int numberOfMessageTypes = sp.versionRange();
          subprotocolRangeMap.put(Range.closedOpen(startRange, startRange + numberOfMessageTypes), sp);
          startRange += numberOfMessageTypes;
          break;
        }
      }
    }
  }

  /**
   * Sends a message to the peer explaining that we are about to disconnect.
   *
   * <p>
   * Only the first call has any effect. It writes the DISCONNECT message, records the reason, runs the disconnect
   * handler and fires DISCONNECTED.
   *
   * @param reason the reason for disconnection
   */
  @Override
  public void disconnect(DisconnectReason reason) {
    if (disconnectRequested.compareAndSet(false, true)) {
      logger.debug("Sending disconnect message with reason {}", reason);
      writer.accept(new RLPxMessage(DISCONNECT_MESSAGE_ID, new DisconnectMessage(reason).toBytes()));
      disconnectReason = reason;
      disconnectHandler.run();
      fireEvent(Event.DISCONNECTED);
    }
  }

  /**
   * Sends a ping message to the remote peer.
   *
   * @return a handler marking completion when a pong response is received
   */
  public AsyncCompletion sendPing() {
    logger.debug("Sending ping message");
    // Install the completion before writing so a fast PONG cannot be missed.
    CompletableAsyncCompletion pong = AsyncCompletion.incomplete();
    awaitingPong = pong;
    writer.accept(new RLPxMessage(PING_MESSAGE_ID, Bytes.EMPTY));
    return pong;
  }

  /** Answers a peer PING. */
  private void sendPong() {
    logger.debug("Sending pong message");
    writer.accept(new RLPxMessage(PONG_MESSAGE_ID, Bytes.EMPTY));
  }

  /** Builds our HELLO message advertising every locally supported sub-protocol as a capability and sends it. */
  void sendHello() {
    myHelloMessage = HelloMessage
        .create(
            nodeId,
            advertisedPort,
            p2pVersion,
            clientId,
            subprotocols
                .keySet()
                .stream()
                .map(
                    subProtocolIdentifier -> new Capability(
                        subProtocolIdentifier.name(),
                        subProtocolIdentifier.version()))
                .collect(Collectors.toList()));
    logger.debug("Sending hello message {}", myHelloMessage);
    writer.accept(new RLPxMessage(HELLO_MESSAGE_ID, myHelloMessage.toBytes()));
  }

  /**
   * Checks whether a sub-protocol was agreed upon with the peer.
   *
   * @param subProtocolIdentifier the sub-protocol to look for
   * @return true if the sub-protocol has a message-id range on this connection
   */
  @Override
  public boolean supports(SubProtocolIdentifier subProtocolIdentifier) {
    return subprotocolRangeMap.asMapOfRanges().values().stream().anyMatch(sp -> sp.equals(subProtocolIdentifier));
  }

  /**
   * Lists the sub-protocols agreed with the peer.
   *
   * @return a new mutable list of the agreed sub-protocols, in message-id order
   */
  @Override
  public Collection<SubProtocolIdentifier> agreedSubprotocols() {
    return new ArrayList<>(subprotocolRangeMap.asMapOfRanges().values());
  }

  /**
   * Sends a sub-protocol message, translating its sub-protocol-relative type into a wire message id.
   *
   * @param subProtocolIdentifier the sub-protocol the message belongs to
   * @param messageType the message type, relative to the sub-protocol's range
   * @param message the message payload
   * @throws UnsupportedOperationException if the sub-protocol was not agreed on this connection
   */
  public void sendMessage(SubProtocolIdentifier subProtocolIdentifier, int messageType, Bytes message) {
    logger.trace("Sending sub-protocol message {} {}", messageType, message);
    Integer offset = null;
    for (Map.Entry<Range<Integer>, SubProtocolIdentifier> entry : subprotocolRangeMap.asMapOfRanges().entrySet()) {
      if (entry.getValue().equals(subProtocolIdentifier)) {
        offset = entry.getKey().lowerEndpoint();
        break;
      }
    }
    if (offset == null) {
      throw new UnsupportedOperationException(
          "No sub-protocol " + subProtocolIdentifier + " mapped to this connection");
    }
    writer.accept(new RLPxMessage(messageType + offset, message));
  }

  /** Starts the handshake from our side by sending our HELLO message. */
  public void handleConnectionStart() {
    sendHello();
  }

  @Override
  public String toString() {
    return peerNodeId.toHexString();
  }

  @Override
  public boolean isDisconnectReceived() {
    return disconnectReceived;
  }

  @Override
  public boolean isDisconnectRequested() {
    return disconnectRequested.get();
  }

  @Override
  public DisconnectReason getDisconnectReason() {
    return disconnectReason;
  }

  @Override
  public String peerHost() {
    return peerHost;
  }

  @Override
  public int peerPort() {
    return peerPort;
  }

  @Override
  public SECP256K1.PublicKey peerPublicKey() {
    return SECP256K1.PublicKey.fromBytes(peerNodeId);
  }

  @Override
  public HelloMessage getPeerHello() {
    return peerHelloMessage;
  }

  /**
   * Registers the listener notified of CONNECTED and DISCONNECTED events, replacing any previous one.
   *
   * @param listener the listener to notify
   */
  @Override
  public void registerListener(EventListener listener) {
    eventListener = listener;
  }
}