| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.sshd.common.channel; |
| |
| import java.io.EOFException; |
| import java.io.IOException; |
| import java.time.Duration; |
| import java.util.Objects; |
| import java.util.Queue; |
| import java.util.concurrent.ConcurrentLinkedQueue; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.concurrent.atomic.AtomicLong; |
| import java.util.concurrent.atomic.AtomicReference; |
| |
| import org.apache.sshd.common.Closeable; |
| import org.apache.sshd.common.PropertyResolver; |
| import org.apache.sshd.common.channel.exception.SshChannelBufferedOutputException; |
| import org.apache.sshd.common.io.IoOutputStream; |
| import org.apache.sshd.common.io.IoWriteFuture; |
| import org.apache.sshd.common.util.ValidateUtils; |
| import org.apache.sshd.common.util.buffer.Buffer; |
| import org.apache.sshd.common.util.closeable.AbstractInnerCloseable; |
| import org.apache.sshd.common.util.functors.UnaryEquator; |
| import org.apache.sshd.core.CoreModuleProperties; |
| |
| /** |
| * An {@link IoOutputStream} capable of queuing write requests. |
| */ |
| public class BufferedIoOutputStream extends AbstractInnerCloseable implements IoOutputStream, ChannelIdentifier { |
| protected final Object id; |
| protected final long channelId; |
| protected final int maxPendingBytesCount; |
| protected final Duration maxWaitForPendingWrites; |
| protected final IoOutputStream out; |
| protected final AtomicInteger pendingBytesCount = new AtomicInteger(); |
| protected final AtomicLong writtenBytesCount = new AtomicLong(); |
| protected final Queue<IoWriteFutureImpl> writes = new ConcurrentLinkedQueue<>(); |
| protected final AtomicReference<IoWriteFutureImpl> currentWrite = new AtomicReference<>(); |
| protected final AtomicReference<SshChannelBufferedOutputException> pendingException = new AtomicReference<>(); |
| |
| public BufferedIoOutputStream(Object id, long channelId, IoOutputStream out, PropertyResolver resolver) { |
| this(id, channelId, out, CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE.getRequired(resolver), |
| CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT.getRequired(resolver)); |
| } |
| |
| public BufferedIoOutputStream(Object id, long channelId, IoOutputStream out, int maxPendingBytesCount, |
| Duration maxWaitForPendingWrites) { |
| this.id = Objects.requireNonNull(id, "No stream identifier provided"); |
| this.channelId = channelId; |
| this.out = Objects.requireNonNull(out, "No delegate output stream provided"); |
| this.maxPendingBytesCount = maxPendingBytesCount; |
| ValidateUtils.checkTrue(maxPendingBytesCount > 0, "Invalid max. pending bytes count: %d", maxPendingBytesCount); |
| this.maxWaitForPendingWrites = Objects.requireNonNull(maxWaitForPendingWrites, "No max. pending time value provided"); |
| } |
| |
| @Override |
| public long getChannelId() { |
| return channelId; |
| } |
| |
| public Object getId() { |
| return id; |
| } |
| |
| @Override |
| public IoWriteFuture writeBuffer(Buffer buffer) throws IOException { |
| if (isClosing()) { |
| throw new EOFException("Closed/ing - state=" + state); |
| } |
| |
| waitForAvailableWriteSpace(buffer.available()); |
| |
| IoWriteFutureImpl future = new IoWriteFutureImpl(getId(), buffer); |
| writes.add(future); |
| startWriting(); |
| return future; |
| } |
| |
| protected void waitForAvailableWriteSpace(int requiredSize) throws IOException { |
| /* |
| * NOTE: this code allows a single pending write to give this mechanism "the slip" and |
| * exit the loop "unscathed" even though there is a pending exception. However, the goal |
| * here is to avoid an OOM by having an unlimited accumulation of pending write requests |
| * due to fact that the peer is not consuming the sent data. Please note that the pending |
| * exception is "sticky" - i.e., the next write attempt will fail. This also means that if |
| * the write request that "got away" was the last one by chance and it was consumed by the |
| * peer there will be no exception thrown - which is also fine since as mentioned the goal |
| * is not to enforce a strict limit on the pending bytes size but rather on the accumulation |
| * of the pending write requests. |
| * |
| * We could have counted pending requests rather than bytes. However, we also want to avoid |
| * having a large amount of data pending consumption by the peer as well. This code strikes |
| * such a balance by allowing a single pending request to exceed the limit, but at the same |
| * time prevents too many bytes from pending by having a bunch of pending requests that while |
| * below the imposed number limit may cumulatively represent a lot of pending bytes. |
| */ |
| |
| long expireTime = System.currentTimeMillis() + maxWaitForPendingWrites.toMillis(); |
| synchronized (pendingBytesCount) { |
| for (int count = pendingBytesCount.get(); |
| /* |
| * The (count > 0) condition is put in place to allow a single pending |
| * write to exceed the maxPendingBytesCount as long as there are no |
| * other pending ones. |
| */ |
| (count > 0) |
| // Not already over the limit or about to be over it |
| && ((count + requiredSize) > maxPendingBytesCount) |
| // No pending exception signaled |
| && (pendingException.get() == null); |
| count = pendingBytesCount.get()) { |
| long remTime = expireTime - System.currentTimeMillis(); |
| if (remTime <= 0L) { |
| pendingException.compareAndSet(null, |
| new SshChannelBufferedOutputException(channelId, |
| "Max. pending write timeout expired after " + writtenBytesCount + " bytes")); |
| throw pendingException.get(); |
| } |
| |
| try { |
| pendingBytesCount.wait(remTime); |
| } catch (InterruptedException e) { |
| pendingException.compareAndSet(null, |
| new SshChannelBufferedOutputException(channelId, |
| "Waiting for pending writes interrupted after " + writtenBytesCount + " bytes")); |
| throw pendingException.get(); |
| } |
| } |
| |
| IOException e = pendingException.get(); |
| if (e != null) { |
| throw e; |
| } |
| |
| pendingBytesCount.addAndGet(requiredSize); |
| } |
| } |
| |
| protected void startWriting() throws IOException { |
| IoWriteFutureImpl future = writes.peek(); |
| // No more pending requests |
| if (future == null) { |
| return; |
| } |
| |
| // Don't try to write any further if pending exception signaled |
| Throwable pendingError = pendingException.get(); |
| if (pendingError != null) { |
| log.error("startWriting({})[{}] propagate to {} write requests pending error={}[{}]", |
| getId(), out, writes.size(), getClass().getSimpleName(), pendingError.getMessage()); |
| |
| IoWriteFutureImpl currentFuture = currentWrite.getAndSet(null); |
| for (IoWriteFutureImpl pendingWrite : writes) { |
| // Checking reference by design |
| if (UnaryEquator.isSameReference(pendingWrite, currentFuture)) { |
| continue; // will be taken care of when its listener is eventually called |
| } |
| |
| future.setValue(pendingError); |
| } |
| |
| writes.clear(); |
| return; |
| } |
| |
| // Cannot honor this request yet since other pending one incomplete |
| if (!currentWrite.compareAndSet(null, future)) { |
| return; |
| } |
| |
| Buffer buffer = future.getBuffer(); |
| int bufferSize = buffer.available(); |
| out.writeBuffer(buffer).addListener(f -> { |
| if (f.isWritten()) { |
| future.setValue(Boolean.TRUE); |
| } else { |
| future.setValue(f.getException()); |
| } |
| finishWrite(future, bufferSize); |
| }); |
| } |
| |
| protected void finishWrite(IoWriteFutureImpl future, int bufferSize) { |
| /* |
| * Update the pending bytes count only if successfully written, |
| * otherwise signal an error |
| */ |
| if (future.isWritten()) { |
| long writtenSize = writtenBytesCount.addAndGet(bufferSize); |
| |
| int stillPending; |
| synchronized (pendingBytesCount) { |
| stillPending = pendingBytesCount.addAndGet(0 - bufferSize); |
| pendingBytesCount.notifyAll(); |
| } |
| |
| /* |
| * NOTE: since the pending exception is updated outside the synchronized block |
| * a pending write could be successfully enqueued, however this is acceptable |
| * - see comment in waitForAvailableWriteSpace |
| */ |
| if (stillPending < 0) { |
| log.error("finishWrite({})[{}] - pending byte counts underflow ({}) after {} bytes", getId(), out, stillPending, |
| writtenSize); |
| pendingException.compareAndSet(null, |
| new SshChannelBufferedOutputException(channelId, "Pending byte counts underflow")); |
| } |
| } else { |
| Throwable t = future.getException(); |
| if (t instanceof SshChannelBufferedOutputException) { |
| pendingException.compareAndSet(null, (SshChannelBufferedOutputException) t); |
| } else { |
| pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, t)); |
| } |
| |
| // In case someone waiting so that they can detect the exception |
| synchronized (pendingBytesCount) { |
| pendingBytesCount.notifyAll(); |
| } |
| } |
| |
| writes.remove(future); |
| currentWrite.compareAndSet(future, null); |
| try { |
| startWriting(); |
| } catch (IOException e) { |
| if (e instanceof SshChannelBufferedOutputException) { |
| pendingException.compareAndSet(null, (SshChannelBufferedOutputException) e); |
| } else { |
| pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, e)); |
| } |
| error("finishWrite({})[{}] failed ({}) re-start writing: {}", |
| getId(), out, e.getClass().getSimpleName(), e.getMessage(), e); |
| } |
| } |
| |
| @Override |
| protected Closeable getInnerCloseable() { |
| return builder().when(getId(), writes).close(out).build(); |
| } |
| |
| @Override |
| public String toString() { |
| return getClass().getSimpleName() + "(" + getId() + "@" + channelId + ")[" + out + "]"; |
| } |
| } |