send "Connection: close" header when the server is going to force-close the connection #146
diff --git a/CHANGES.md b/CHANGES.md
index b591a43..24a59d6 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,10 @@
+Version 2.12.0 released 2015-01-16
+
+* Send "Connection: close" header when the server is going to close
+  a Keep-Alive connection, usually due to unread data from the
+  client
+  https://github.com/mochi/mochiweb/issues/146
+
 Version 2.11.2 released 2015-01-16
 
 * Fix regression introduced in #147
diff --git a/src/mochiweb.app.src b/src/mochiweb.app.src
index 97fc90d..8cb43ac 100644
--- a/src/mochiweb.app.src
+++ b/src/mochiweb.app.src
@@ -1,7 +1,7 @@
 %% This is generated from src/mochiweb.app.src
 {application, mochiweb,
  [{description, "MochiMedia Web Server"},
-  {vsn, "2.11.2"},
+  {vsn, "2.12.0"},
   {modules, []},
   {registered, []},
   {env, []},
diff --git a/src/mochiweb_request.erl b/src/mochiweb_request.erl
index 9622926..c97070f 100644
--- a/src/mochiweb_request.erl
+++ b/src/mochiweb_request.erl
@@ -302,11 +302,17 @@
 format_response_header({Code, ResponseHeaders}, {?MODULE, [_Socket, _Opts, _Method, _RawPath, Version, _Headers]}=THIS) ->
     HResponse = mochiweb_headers:make(ResponseHeaders),
     HResponse1 = mochiweb_headers:default_from_list(server_headers(), HResponse),
+    HResponse2 = case should_close(THIS) of
+                     true ->
+                         mochiweb_headers:enter("Connection", "close", HResponse1);
+                     false ->
+                         HResponse1
+                 end,
     F = fun ({K, V}, Acc) ->
                 [mochiweb_util:make_io(K), <<": ">>, V, <<"\r\n">> | Acc]
         end,
-    End = lists:foldl(F, [<<"\r\n">>], mochiweb_headers:to_list(HResponse1)),
-    Response = mochiweb:new_response({THIS, Code, HResponse1}),
+    End = lists:foldl(F, [<<"\r\n">>], mochiweb_headers:to_list(HResponse2)),
+    Response = mochiweb:new_response({THIS, Code, HResponse2}),
     {[make_version(Version), make_code(Code), <<"\r\n">> | End], Response};
 format_response_header({Code, ResponseHeaders, Length},
                        {?MODULE, [_Socket, _Opts, _Method, _RawPath, _Version, _Headers]}=THIS) ->
diff --git a/test/mochiweb_test_util.erl b/test/mochiweb_test_util.erl
index 0801ab5..2fbf14f 100644
--- a/test/mochiweb_test_util.erl
+++ b/test/mochiweb_test_util.erl
@@ -1,6 +1,8 @@
 -module(mochiweb_test_util).
--export([with_server/3, client_request/4, sock_fun/2]).
+-export([with_server/3, client_request/4, sock_fun/2,
+         read_server_headers/1, drain_reply/3]).
 -include("mochiweb_test_util.hrl").
+-include_lib("eunit/include/eunit.hrl").
 
 ssl_cert_opts() ->
     EbinDir = filename:dirname(code:which(?MODULE)),
@@ -66,6 +68,7 @@
                client_headers(Body, Rest =:= []),
                "\r\n",
                Body],
+    ok = SockFun({setopts, [{packet, http}]}),
     ok = SockFun({send, Request}),
     case Method of
         'GET' ->
@@ -75,18 +78,30 @@
         'CONNECT' ->
             {ok, {http_response, {1,1}, 200, "OK"}} = SockFun(recv)
     end,
-    ok = SockFun({setopts, [{packet, httph}]}),
-    {ok, {http_header, _, 'Server', _, "MochiWeb" ++ _}} = SockFun(recv),
-    {ok, {http_header, _, 'Date', _, _}} = SockFun(recv),
-    {ok, {http_header, _, 'Content-Type', _, _}} = SockFun(recv),
-    {ok, {http_header, _, 'Content-Length', _, ConLenStr}} = SockFun(recv),
-    ContentLength = list_to_integer(ConLenStr),
-    {ok, http_eoh} = SockFun(recv),
-    ok = SockFun({setopts, [{packet, raw}]}),
+    Headers = read_server_headers(SockFun),
+    ?assertMatch("MochiWeb" ++ _, mochiweb_headers:get_value("Server", Headers)),
+    ?assert(mochiweb_headers:get_value("Date", Headers) =/= undefined),
+    ?assert(mochiweb_headers:get_value("Content-Type", Headers) =/= undefined),
+    ContentLength = list_to_integer(mochiweb_headers:get_value("Content-Length", Headers)),
     {payload, ExReply} = {payload, drain_reply(SockFun, ContentLength, <<>>)},
-    ok = SockFun({setopts, [{packet, http}]}),
     client_request(SockFun, Method, Rest).
 
+read_server_headers(SockFun) -> % Read one block of response headers through SockFun and return it.
+    ok = SockFun({setopts, [{packet, httph}]}), % recv is then expected to yield http_header / http_eoh
+    Headers = read_server_headers(SockFun, mochiweb_headers:empty()),
+    ok = SockFun({setopts, [{packet, raw}]}), % leave raw mode for the caller (e.g. drain_reply/3)
+    Headers. % built with mochiweb_headers:insert/3 on top of mochiweb_headers:empty/0
+
+read_server_headers(SockFun, Headers) -> % Accumulate received headers into Headers until http_eoh.
+    case SockFun(recv) of
+        {ok, http_eoh} -> % end of the header block: return what was collected
+            Headers;
+        {ok, {http_header, _, Header, _, Value}} -> % one header: keep its name and value, then recurse
+            read_server_headers(
+              SockFun,
+              mochiweb_headers:insert(Header, Value, Headers))
+    end. % any other recv result has no clause here, so it fails with case_clause
+
 client_headers(Body, IsLastRequest) ->
     ["Host: localhost\r\n",
      case Body of
diff --git a/test/mochiweb_tests.erl b/test/mochiweb_tests.erl
index fdda3fd..209971b 100644
--- a/test/mochiweb_tests.erl
+++ b/test/mochiweb_tests.erl
@@ -150,3 +150,54 @@
             mochiweb_test_util:client_request(Transport, Port, Method, TestReqs)
     end.
 
+close_on_unread_data_test() -> % Test for issue #146; the assertions live in close_on_unread_data_client/2.
+    ok = with_server(
+           plain, % NOTE(review): presumably the transport handed on to the client fun - confirm
+           fun mochiweb_request:not_found/1, % NOTE(review): presumably the server-side handler - confirm
+           fun close_on_unread_data_client/2). % takes (Transport, Port) and drives both requests
+
+close_on_unread_data_client(Transport, Port) -> % Sends a GET then a POST on one socket; returns ok.
+    SockFun = mochiweb_test_util:sock_fun(Transport, Port),
+    %% First request: a GET without a body must not get a Connection header
+    Request0 = string:join(
+                 ["GET / HTTP/1.1",
+                  "Host: localhost",
+                  "",
+                  ""],
+                 "\r\n"),
+    ok = SockFun({setopts, [{packet, http}]}), % the next recv is matched as an http_response tuple
+    ok = SockFun({send, Request0}),
+    ?assertMatch(
+       {ok, {http_response, {1, 1}, 404, _}},
+       SockFun(recv)),
+    Headers0 = mochiweb_test_util:read_server_headers(SockFun),
+    ?assertEqual(
+       undefined,
+       mochiweb_headers:get_value("Connection", Headers0)),
+    Len0 = list_to_integer(
+             mochiweb_headers:get_value("Content-Length", Headers0)),
+    _Body0 = mochiweb_test_util:drain_reply(SockFun, Len0, <<>>), % consume the first response body
+    %% Second request on the same socket: a POST carrying a 2-byte body
+    Request = string:join(
+                ["POST / HTTP/1.1",
+                 "Host: localhost",
+                 "Content-Type: application/json",
+                 "Content-Length: 2",
+                 "",
+                 "{}"],
+                "\r\n"),
+    ok = SockFun({setopts, [{packet, http}]}), % back to status-line parsing for the second response
+    ok = SockFun({send, Request}),
+    ?assertMatch(
+       {ok, {http_response, {1, 1}, 404, _}},
+       SockFun(recv)),
+    Headers = mochiweb_test_util:read_server_headers(SockFun),
+    %% The server is about to close this connection, so the response
+    %% must carry a Connection: close header (issue #146)
+    ?assertEqual(
+       "close",
+       mochiweb_headers:get_value("Connection", Headers)),
+    Len = list_to_integer(mochiweb_headers:get_value("Content-Length", Headers)),
+    _Body = mochiweb_test_util:drain_reply(SockFun, Len, <<>>),
+    ?assertEqual({error, closed}, SockFun(recv)), % and the socket really is closed afterwards
+    ok.
diff --git a/test/mochiweb_websocket_tests.erl b/test/mochiweb_websocket_tests.erl
index 5711a55..eb8de5b 100644
--- a/test/mochiweb_websocket_tests.erl
+++ b/test/mochiweb_websocket_tests.erl
@@ -116,16 +116,12 @@
                     ""], "\r\n"),
     ok = S({send, UpgradeReq}),
     {ok, {http_response, {1,1}, 101, _}} = S(recv),
-    ok = S({setopts, [{packet, httph}]}),
-    D = read_expected_headers(
-          S,
-          gb_from_list(
-            [{'Upgrade', "websocket"},
-             {'Connection', "Upgrade"},
-             {'Content-Length', "0"},
-             {"Sec-Websocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}])),
-    ?assertEqual([], gb_trees:to_list(D)),
-    ok = S({setopts, [{packet, raw}]}),
+    read_expected_headers(
+      S,
+      [{'Upgrade', "websocket"},
+       {'Connection', "Upgrade"},
+       {'Content-Length', "0"},
+       {"Sec-Websocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}]),
     %% The first message sent over telegraph :)
     SmallMessage = <<"What hath God wrought?">>,
     ok = S({send,
@@ -149,25 +145,13 @@
     ?assertEqual(MsgSize, byte_size(SmallMessage)),
     ok.
 
-gb_from_list(L) ->
-    lists:foldl(
-      fun ({K, V}, D) -> gb_trees:insert(K, V, D) end,
-      gb_trees:empty(),
-      L).
-
 read_expected_headers(S, D) ->
-    case S(recv) of
-        {ok, http_eoh} ->
-            D;
-        {ok, {http_header, _, K, _, V}} ->
-            case gb_trees:lookup(K, D) of
-                {value, V1} ->
-                    ?assertEqual({K, V}, {K, V1}),
-                    read_expected_headers(S, gb_trees:delete(K, D));
-                none ->
-                    read_expected_headers(S, D)
-            end
-    end.
+    Headers = mochiweb_test_util:read_server_headers(S), % read the whole header block up front
+    lists:foreach(
+      fun ({K, V}) -> % every expected {Name, Value} pair must be present with exactly that value
+              ?assertEqual(V, mochiweb_headers:get_value(K, Headers))
+      end,
+      D). % D is the list of expected pairs; received headers that are not listed are ignored
 
 end_to_end_http_test() ->
     end_to_end_test_factory(plain).