NET-477 TFTP sendFile retry broken

git-svn-id: https://svn.apache.org/repos/asf/commons/proper/net/trunk@1782352 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 6719450..2a1aa2b 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -87,6 +87,9 @@
   The POP3Mail examples can now get password from console, stdin or an environment variable.
   
 ">
+            <action issue="NET-477" type="fix" dev="sebb" due-to="John Walton">
+            TFTP sendFile retry broken
+            </action>
             <action issue="NET-612" type="update" dev="sebb">
             Allow TFTP socket IO tracing
             </action>
diff --git a/src/main/java/org/apache/commons/net/tftp/TFTPClient.java b/src/main/java/org/apache/commons/net/tftp/TFTPClient.java
index 79e31fc..53cc243 100644
--- a/src/main/java/org/apache/commons/net/tftp/TFTPClient.java
+++ b/src/main/java/org/apache/commons/net/tftp/TFTPClient.java
@@ -146,166 +146,115 @@
     public int receiveFile(String filename, int mode, OutputStream output,
                            InetAddress host, int port) throws IOException
     {
-        int bytesRead, timeouts, lastBlock, block, hostPort, dataLength;
-        TFTPPacket sent, received = null;
-        TFTPErrorPacket error;
-        TFTPDataPacket data;
-        TFTPAckPacket ack = new TFTPAckPacket(host, port, 0);
+        int bytesRead = 0;
+        int lastBlock = 0;
+        int block = 1;
+        int hostPort = 0;
+        int dataLength = 0;
 
-        beginBufferedOps();
-
-        dataLength = lastBlock = hostPort = bytesRead = 0;
         totalBytesReceived = 0;
-        block = 1;
 
         if (mode == TFTP.ASCII_MODE) {
             output = new FromNetASCIIOutputStream(output);
         }
 
-        sent =
-            new TFTPReadRequestPacket(host, port, filename, mode);
+        TFTPPacket sent = new TFTPReadRequestPacket(host, port, filename, mode);
+        TFTPAckPacket ack = new TFTPAckPacket(host, port, 0);
 
-_sendPacket:
-        do
-        {
-            bufferedSend(sent);
+        beginBufferedOps();
 
-_receivePacket:
-            while (true)
-            {
-                timeouts = 0;
-                do {
-                    try
-                    {
-                        received = bufferedReceive();
-                        break;
-                    }
-                    catch (SocketException e)
-                    {
-                        if (++timeouts >= __maxTimeouts)
+        try {
+            do { // while more data to fetch
+                bufferedSend(sent); // start the fetch/send an ack
+                boolean wantReply = true;
+                int timeouts = 0;
+                do { // until successful response
+                    try {
+                        TFTPPacket received = bufferedReceive();
+                        // The first time we receive we get the port number and
+                        // answering host address (for hosts with multiple IPs)
+                        final int recdPort = received.getPort();
+                        final InetAddress recdAddress = received.getAddress();
+                        if (lastBlock == 0)
                         {
-                            endBufferedOps();
+                            hostPort = recdPort;
+                            ack.setPort(hostPort);
+                            if(!host.equals(recdAddress))
+                            {
+                                host = recdAddress;
+                                ack.setAddress(host);
+                                sent.setAddress(host);
+                            }
+                        }
+                        // Comply with RFC 783 indication that an error acknowledgment
+                        // should be sent to originator if unexpected TID or host.
+                        if (host.equals(recdAddress) && recdPort == hostPort) {
+                            switch (received.getType()) {
+
+                            case TFTPPacket.ERROR:
+                                TFTPErrorPacket error = (TFTPErrorPacket)received;
+                                throw new IOException("Error code " + error.getError() +
+                                                      " received: " + error.getMessage());
+                            case TFTPPacket.DATA:
+                                TFTPDataPacket data = (TFTPDataPacket)received;
+                                dataLength = data.getDataLength();
+                                lastBlock = data.getBlockNumber();
+
+                                if (lastBlock == block) { // is the next block number?
+                                    try {
+                                        output.write(data.getData(), data.getDataOffset(), dataLength);
+                                    } catch (IOException e) {
+                                        error = new TFTPErrorPacket(host, hostPort,
+                                                                    TFTPErrorPacket.OUT_OF_SPACE,
+                                                                    "File write failed.");
+                                        bufferedSend(error);
+                                        throw e;
+                                    }
+                                    ++block;
+                                    if (block > 65535) {
+                                        // wrap the block number
+                                        block = 0;
+                                    }
+                                    wantReply = false; // got the next block, drop out to ack it
+                                } else { // unexpected block number
+                                    discardPackets();
+                                    if (lastBlock == (block == 0 ? 65535 : (block - 1))) {
+                                        wantReply = false; // Resend last acknowledgemen
+                                    }
+                                }
+                                break;
+
+                            default:
+                                throw new IOException("Received unexpected packet type (" + received.getType() + ")");
+                            }
+                        } else { // incorrect host or TID
+                            TFTPErrorPacket error = new TFTPErrorPacket(recdAddress, recdPort,
+                                    TFTPErrorPacket.UNKNOWN_TID,
+                                    "Unexpected host or port.");
+                            bufferedSend(error);
+                        }
+                    } catch (SocketException e) {
+                        if (++timeouts >= __maxTimeouts) {
                             throw new IOException("Connection timed out.");
                         }
-                        continue _sendPacket;
-                    }
-                    catch (InterruptedIOException e)
-                    {
-                        if (++timeouts >= __maxTimeouts)
-                        {
-                            endBufferedOps();
+                    } catch (InterruptedIOException e) {
+                        if (++timeouts >= __maxTimeouts) {
                             throw new IOException("Connection timed out.");
                         }
-                        continue _sendPacket;
-                    }
-                    catch (TFTPPacketException e)
-                    {
-                        endBufferedOps();
+                    } catch (TFTPPacketException e) {
                         throw new IOException("Bad packet: " + e.getMessage());
                     }
-                } while (timeouts < __maxTimeouts); // __maxTimeouts >=1 so will always do loop at least once
+                } while(wantReply); // waiting for response
 
-                // The first time we receive we get the port number and
-                // answering host address (for hosts with multiple IPs)
-                if (lastBlock == 0)
-                {
-                    hostPort = received.getPort();
-                    ack.setPort(hostPort);
-                    if(!host.equals(received.getAddress()))
-                    {
-                        host = received.getAddress();
-                        ack.setAddress(host);
-                        sent.setAddress(host);
-                    }
-                }
-
-                // Comply with RFC 783 indication that an error acknowledgment
-                // should be sent to originator if unexpected TID or host.
-                if (host.equals(received.getAddress()) &&
-                        received.getPort() == hostPort)
-                {
-
-                    switch (received.getType())
-                    {
-                    case TFTPPacket.ERROR:
-                        error = (TFTPErrorPacket)received;
-                        endBufferedOps();
-                        throw new IOException("Error code " + error.getError() +
-                                              " received: " + error.getMessage());
-                    case TFTPPacket.DATA:
-                        data = (TFTPDataPacket)received;
-                        dataLength = data.getDataLength();
-
-                        lastBlock = data.getBlockNumber();
-
-                        if (lastBlock == block)
-                        {
-                            try
-                            {
-                                output.write(data.getData(), data.getDataOffset(),
-                                             dataLength);
-                            }
-                            catch (IOException e)
-                            {
-                                error = new TFTPErrorPacket(host, hostPort,
-                                                            TFTPErrorPacket.OUT_OF_SPACE,
-                                                            "File write failed.");
-                                bufferedSend(error);
-                                endBufferedOps();
-                                throw e;
-                            }
-                            ++block;
-                            if (block > 65535)
-                            {
-                                // wrap the block number
-                                block = 0;
-                            }
-
-                            break _receivePacket;
-                        }
-                        else
-                        {
-                            discardPackets();
-
-                            if (lastBlock == (block == 0 ? 65535 : (block - 1))) {
-                                continue _sendPacket;  // Resend last acknowledgement.
-                            }
-
-                            continue _receivePacket; // Start fetching packets again.
-                        }
-                        //break;
-
-                    default:
-                        endBufferedOps();
-                        throw new IOException("Received unexpected packet type.");
-                    }
-                }
-                else
-                {
-                    error = new TFTPErrorPacket(received.getAddress(),
-                                                received.getPort(),
-                                                TFTPErrorPacket.UNKNOWN_TID,
-                                                "Unexpected host or port.");
-                    bufferedSend(error);
-                    continue _sendPacket;
-                }
-
-                // We should never get here, but this is a safety to avoid
-                // infinite loop.  If only Java had the goto statement.
-                //break;
-            }
-
-            ack.setBlockNumber(lastBlock);
-            sent = ack;
-            bytesRead += dataLength;
-            totalBytesReceived += dataLength;
-        } // First data packet less than 512 bytes signals end of stream.
-
-        while (dataLength == TFTPPacket.SEGMENT_SIZE);
-
-        bufferedSend(sent);
-        endBufferedOps();
-
+                ack.setBlockNumber(lastBlock);
+                sent = ack;
+                bytesRead += dataLength;
+                totalBytesReceived += dataLength;
+            } while (dataLength == TFTPPacket.SEGMENT_SIZE); // not eof
+            bufferedSend(sent); // send the final ack
+        } finally {
+            endBufferedOps();
+        }
         return bytesRead;
     }
 
@@ -396,182 +345,120 @@
     public void sendFile(String filename, int mode, InputStream input,
                          InetAddress host, int port) throws IOException
     {
-        int bytesRead, timeouts, lastBlock, block, hostPort, dataLength, offset, totalThisPacket;
-        TFTPPacket sent, received = null;
-        TFTPErrorPacket error;
-        TFTPDataPacket data =
-            new TFTPDataPacket(host, port, 0, _sendBuffer, 4, 0);
-        TFTPAckPacket ack;
-
+        int block = 0;
+        int hostPort = 0;
         boolean justStarted = true;
-
-        beginBufferedOps();
-
-        dataLength = lastBlock = hostPort = bytesRead = totalThisPacket = 0;
-        totalBytesSent = 0L;
-        block = 0;
         boolean lastAckWait = false;
 
+        totalBytesSent = 0L;
+
         if (mode == TFTP.ASCII_MODE) {
             input = new ToNetASCIIInputStream(input);
         }
 
-        sent =
-            new TFTPWriteRequestPacket(host, port, filename, mode);
+        TFTPPacket sent = new TFTPWriteRequestPacket(host, port, filename, mode);
+        TFTPDataPacket data = new TFTPDataPacket(host, port, 0, _sendBuffer, 4, 0);
 
-_sendPacket:
-        do
-        {
-            // first time: block is 0, lastBlock is 0, send a request packet.
-            // subsequent: block is integer starting at 1, send data packet.
-            bufferedSend(sent);
+        beginBufferedOps();
 
-            // this is trying to receive an ACK
-_receivePacket:
-            while (true)
-            {
-
-
-                timeouts = 0;
+        try {
+            do { // until eof
+                // first time: block is 0, lastBlock is 0, send a request packet.
+                // subsequent: block is integer starting at 1, send data packet.
+                bufferedSend(sent);
+                boolean wantReply = true;
+                int timeouts = 0;
                 do {
-                    try
-                    {
-                        received = bufferedReceive();
-                        break;
-                    }
-                    catch (SocketException e)
-                    {
-                        if (++timeouts >= __maxTimeouts)
-                        {
-                            endBufferedOps();
+                    try {
+                        TFTPPacket received = bufferedReceive();
+                        final InetAddress recdAddress = received.getAddress();
+                        final int recdPort = received.getPort();
+                        // The first time we receive we get the port number and
+                        // answering host address (for hosts with multiple IPs)
+                        if (justStarted) {
+                            justStarted = false;
+                            hostPort = recdPort;
+                            data.setPort(hostPort);
+                            if (!host.equals(recdAddress)) {
+                                host = recdAddress;
+                                data.setAddress(host);
+                                sent.setAddress(host);
+                            }
+                        }
+                        // Comply with RFC 783 indication that an error acknowledgment
+                        // should be sent to originator if unexpected TID or host.
+                        if (host.equals(recdAddress) && recdPort == hostPort) {
+
+                            switch (received.getType()) {
+                            case TFTPPacket.ERROR:
+                                TFTPErrorPacket error = (TFTPErrorPacket)received;
+                                throw new IOException("Error code " + error.getError() +
+                                                      " received: " + error.getMessage());
+                            case TFTPPacket.ACKNOWLEDGEMENT:
+
+                                int lastBlock = ((TFTPAckPacket)received).getBlockNumber();
+
+                                if (lastBlock == block) {
+                                    ++block;
+                                    if (block > 65535) {
+                                        // wrap the block number
+                                        block = 0;
+                                    }
+                                    wantReply = false; // got the ack we want
+                                } else { 
+                                    discardPackets();
+                                }
+                                break;
+                            default:
+                                throw new IOException("Received unexpected packet type.");
+                            }
+                        } else { // wrong host or TID; send error
+                            TFTPErrorPacket error = new TFTPErrorPacket(recdAddress,
+                                                        recdPort,
+                                                        TFTPErrorPacket.UNKNOWN_TID,
+                                                        "Unexpected host or port.");
+                            bufferedSend(error);
+                        }
+                    } catch (SocketException e) {
+                        if (++timeouts >= __maxTimeouts) {
                             throw new IOException("Connection timed out.");
                         }
-                        continue _sendPacket;
-                    }
-                    catch (InterruptedIOException e)
-                    {
-                        if (++timeouts >= __maxTimeouts)
-                        {
-                            endBufferedOps();
+                    } catch (InterruptedIOException e) {
+                        if (++timeouts >= __maxTimeouts) {
                             throw new IOException("Connection timed out.");
                         }
-                        continue _sendPacket;
-                    }
-                    catch (TFTPPacketException e)
-                    {
-                        endBufferedOps();
+                    } catch (TFTPPacketException e) {
                         throw new IOException("Bad packet: " + e.getMessage());
                     }
-                } // end of while loop over tries to receive
-                while (timeouts < __maxTimeouts); // __maxTimeouts >=1 so will always do loop at least once
+                    // retry until a good ack
+                } while(wantReply);
 
-
-                // The first time we receive we get the port number and
-                // answering host address (for hosts with multiple IPs)
-                if (justStarted)
-                {
-                    justStarted = false;
-                    hostPort = received.getPort();
-                    data.setPort(hostPort);
-                    if(!host.equals(received.getAddress()))
-                    {
-                        host = received.getAddress();
-                        data.setAddress(host);
-                        sent.setAddress(host);
-                    }
+                if (lastAckWait) {
+                    break; // we were waiting for this; now all done
                 }
 
-                // Comply with RFC 783 indication that an error acknowledgment
-                // should be sent to originator if unexpected TID or host.
-                if (host.equals(received.getAddress()) &&
-                        received.getPort() == hostPort)
-                {
-
-                    switch (received.getType())
-                    {
-                    case TFTPPacket.ERROR:
-                        error = (TFTPErrorPacket)received;
-                        endBufferedOps();
-                        throw new IOException("Error code " + error.getError() +
-                                              " received: " + error.getMessage());
-                    case TFTPPacket.ACKNOWLEDGEMENT:
-                        ack = (TFTPAckPacket)received;
-
-                        lastBlock = ack.getBlockNumber();
-
-                        if (lastBlock == block)
-                        {
-                            ++block;
-                            if (block > 65535)
-                            {
-                                // wrap the block number
-                                block = 0;
-                            }
-                            if (lastAckWait) {
-
-                              break _sendPacket;
-                            }
-                            else {
-                              break _receivePacket;
-                            }
-                        }
-                        else
-                        {
-                            discardPackets();
-
-                            continue _receivePacket; // Start fetching packets again.
-                        }
-                        //break;
-
-                    default:
-                        endBufferedOps();
-                        throw new IOException("Received unexpected packet type.");
-                    }
+                int dataLength = TFTPPacket.SEGMENT_SIZE;
+                int offset = 4;
+                int totalThisPacket = 0;
+                int bytesRead = 0;
+                while (dataLength > 0 &&
+                        (bytesRead = input.read(_sendBuffer, offset, dataLength)) > 0) {
+                    offset += bytesRead;
+                    dataLength -= bytesRead;
+                    totalThisPacket += bytesRead;
                 }
-                else
-                {
-                    error = new TFTPErrorPacket(received.getAddress(),
-                                                received.getPort(),
-                                                TFTPErrorPacket.UNKNOWN_TID,
-                                                "Unexpected host or port.");
-                    bufferedSend(error);
-                    continue _sendPacket;
+                if( totalThisPacket < TFTPPacket.SEGMENT_SIZE ) {
+                    /* this will be our last packet -- send, wait for ack, stop */
+                    lastAckWait = true;
                 }
-
-                // We should never get here, but this is a safety to avoid
-                // infinite loop.  If only Java had the goto statement.
-                //break;
-            }
-
-            // OK, we have just gotten ACK about the last data we sent. Make another
-            // and send it
-
-            dataLength = TFTPPacket.SEGMENT_SIZE;
-            offset = 4;
-            totalThisPacket = 0;
-            while (dataLength > 0 &&
-                    (bytesRead = input.read(_sendBuffer, offset, dataLength)) > 0)
-            {
-                offset += bytesRead;
-                dataLength -= bytesRead;
-                totalThisPacket += bytesRead;
-            }
-
-            if( totalThisPacket < TFTPPacket.SEGMENT_SIZE ) {
-                /* this will be our last packet -- send, wait for ack, stop */
-                lastAckWait = true;
-            }
-            data.setBlockNumber(block);
-            data.setData(_sendBuffer, 4, totalThisPacket);
-            sent = data;
-            totalBytesSent += totalThisPacket;
+                data.setBlockNumber(block);
+                data.setData(_sendBuffer, 4, totalThisPacket);
+                sent = data;
+                totalBytesSent += totalThisPacket;
+            } while (true); // loops until after lastAckWait is set
+        } finally {
+            endBufferedOps();
         }
-        while ( totalThisPacket > 0 || lastAckWait );
-        // Note: this was looping while dataLength == 0 || lastAckWait,
-        // which was discarding the last packet if it was not full size
-        // Should send the packet.
-
-        endBufferedOps();
     }