Merge pull request #144 from mochi/websocket-ssl

Fix ssl receive support for websocket
diff --git a/src/mochiweb_websocket.erl b/src/mochiweb_websocket.erl
index cc3127e..2768a3e 100644
--- a/src/mochiweb_websocket.erl
+++ b/src/mochiweb_websocket.erl
@@ -27,6 +27,9 @@
 
 -export([loop/5, upgrade_connection/2, request/5]).
 -export([send/3]).
+-ifdef(TEST).
+-compile(export_all).
+-endif.
 
 loop(Socket, Body, State, WsVersion, ReplyChannel) ->
     ok = mochiweb_socket:setopts(Socket, [{packet, 0}, {active, once}]),
@@ -44,7 +47,7 @@
         {tcp_error, _, _} ->
             mochiweb_socket:close(Socket),
             exit(normal);
-        {tcp, _, WsFrames} ->
+        {Proto, _, WsFrames} when Proto =:= tcp orelse Proto =:= ssl ->
             case parse_frames(WsVersion, WsFrames, Socket) of
                 close ->
                     mochiweb_socket:close(Socket),
@@ -214,7 +217,7 @@
         {tcp_error, _, _} ->
             mochiweb_socket:close(Socket),
             exit(normal);
-        {tcp, _, Continuation} ->
+        {Proto, _, Continuation} when Proto =:= tcp orelse Proto =:= ssl ->
             parse_hybi_frames(Socket, <<PartFrame/binary, Continuation/binary>>,
                               Acc);
         _ ->
@@ -276,11 +279,3 @@
   {Buffer, Rest};
 parse_hixie(<<H, T/binary>>, Buffer) ->
   parse_hixie(T, <<Buffer/binary, H>>).
-
-%%
-%% Tests
-%%
--ifdef(TEST).
--include_lib("eunit/include/eunit.hrl").
--compile(export_all).
--endif.
diff --git a/test/mochiweb_test_util.erl b/test/mochiweb_test_util.erl
new file mode 100644
index 0000000..0801ab5
--- /dev/null
+++ b/test/mochiweb_test_util.erl
@@ -0,0 +1,111 @@
+-module(mochiweb_test_util).
+-export([with_server/3, client_request/4, sock_fun/2]).
+-include("mochiweb_test_util.hrl").
+
+%% SSL server options: test cert and key in ../support/test-materials, relative to the .beam dir.
+ssl_cert_opts() ->
+    BeamDir = filename:dirname(code:which(?MODULE)),
+    Dir = filename:join([BeamDir, "..", "support", "test-materials"]),
+    [{certfile, filename:join(Dir, "test_ssl_cert.pem")},
+     {keyfile, filename:join(Dir, "test_ssl_key.pem")}].
+
+%% Run ClientFun(Transport, Port) against a mochiweb server on an OS-assigned 127.0.0.1 port,
+%% then stop it. ClientFun runs under `catch`, so its failures are returned, not raised.
+with_server(Transport, ServerFun, ClientFun) ->
+    TransportOpts = case Transport of
+                        plain -> [];
+                        ssl -> [{ssl, true}, {ssl_opts, ssl_cert_opts()}]
+                    end,
+    Opts = [{ip, "127.0.0.1"}, {port, 0}, {loop, ServerFun} | TransportOpts],
+    {ok, Server} = mochiweb_http:start_link(Opts),
+    Port = mochiweb_socket_server:get(Server, port),
+    Result = (catch ClientFun(Transport, Port)),
+    mochiweb_http:stop(Server),
+    Result.
+
+%% Open a client connection to 127.0.0.1:Port over Transport (plain | ssl) and
+%% wrap it in a fun so tests drive both socket kinds through one interface.
+%% The socket starts passive, binary and in {packet, http} mode.
+%% The returned fun accepts:
+%%   recv | {recv, Length} | {send, Data} | {setopts, Opts} | get
+%% where get yields Socket for plain and {ssl, Socket} for ssl. The socket is
+%% never closed here; a failed connect crashes with badmatch.
+sock_fun(Transport, Port) ->
+    Opts = [binary, {active, false}, {packet, http}],
+    case Transport of
+        plain ->
+            {ok, Sock} = gen_tcp:connect("127.0.0.1", Port, Opts),
+            fun (recv) -> gen_tcp:recv(Sock, 0);
+                ({recv, Len}) -> gen_tcp:recv(Sock, Len);
+                ({send, Data}) -> gen_tcp:send(Sock, Data);
+                %% plain sockets take their options through inet
+                ({setopts, NewOpts}) -> inet:setopts(Sock, NewOpts);
+                (get) -> Sock
+            end;
+        ssl ->
+            %% NOTE(review): {ssl_imp, new} is a legacy option; confirm the OTP
+            %% releases this project targets still accept it.
+            {ok, Sock} = ssl:connect("127.0.0.1", Port, [{ssl_imp, new} | Opts]),
+            fun (recv) -> ssl:recv(Sock, 0);
+                ({recv, Len}) -> ssl:recv(Sock, Len);
+                ({send, Data}) -> ssl:send(Sock, Data);
+                ({setopts, NewOpts}) -> ssl:setopts(Sock, NewOpts);
+                (get) -> {ssl, Sock}
+            end
+    end.
+
+%% Connect with sock_fun/2 and run TestReqs over that single connection.
+client_request(Transport, Port, Method, TestReqs) ->
+    client_request(sock_fun(Transport, Port), Method, TestReqs).
+
+%% Send each #treq{} in turn (Method is 'GET', 'POST' or 'CONNECT'), check the
+%% status line and headers, and require the body to equal xreply. Once the list
+%% is exhausted the server must have closed the connection.
+client_request(SockFun, _Method, []) ->
+    {the_end, {error, closed}} = {the_end, SockFun(recv)},
+    ok;
+client_request(SockFun, Method,
+               [#treq{path=Path, body=Body, xreply=ExReply} | Rest]) ->
+    ok = SockFun({send, [atom_to_list(Method), " ", Path, " HTTP/1.1\r\n",
+                         client_headers(Body, Rest =:= []), "\r\n", Body]}),
+    {Status, Phrase} = case Method of
+                           'GET' -> {200, "OK"};
+                           'POST' -> {201, "Created"};
+                           'CONNECT' -> {200, "OK"}
+                       end,
+    {ok, {http_response, {1,1}, Status, Phrase}} = SockFun(recv),
+    %% The response must carry exactly these headers, in this order.
+    ok = SockFun({setopts, [{packet, httph}]}),
+    {ok, {http_header, _, 'Server', _, "MochiWeb" ++ _}} = SockFun(recv),
+    {ok, {http_header, _, 'Date', _, _}} = SockFun(recv),
+    {ok, {http_header, _, 'Content-Type', _, _}} = SockFun(recv),
+    {ok, {http_header, _, 'Content-Length', _, ConLenStr}} = SockFun(recv),
+    BodyLen = list_to_integer(ConLenStr),
+    {ok, http_eoh} = SockFun(recv),
+    ok = SockFun({setopts, [{packet, raw}]}),
+    {payload, ExReply} = {payload, drain_reply(SockFun, BodyLen, <<>>)},
+    ok = SockFun({setopts, [{packet, http}]}),
+    client_request(SockFun, Method, Rest).
+
+%% Headers for one request: Host always, Content-Type and Content-Length only
+%% for a non-empty Body, and Connection: close on the last request.
+client_headers(Body, IsLastRequest) ->
+    BodyHeaders = case Body of
+                      <<>> ->
+                          "";
+                      _ ->
+                          ["Content-Type: application/octet-stream\r\n",
+                           "Content-Length: ", integer_to_list(byte_size(Body)), "\r\n"]
+                  end,
+    CloseHeader = case IsLastRequest of
+                      true -> "Connection: close\r\n";
+                      false -> ""
+                  end,
+    ["Host: localhost\r\n", BodyHeaders, CloseHeader].
+
+%% Request Remaining more bytes via SockFun in chunks of at most 1024 and append them to Acc.
+drain_reply(_SockFun, 0, Acc) -> Acc;
+drain_reply(SockFun, Remaining, Acc) ->
+    ChunkSize = erlang:min(Remaining, 1024),
+    {ok, Chunk} = SockFun({recv, ChunkSize}),
+    drain_reply(SockFun, Remaining - ChunkSize, <<Acc/bytes, Chunk/bytes>>).
diff --git a/test/mochiweb_test_util.hrl b/test/mochiweb_test_util.hrl
new file mode 100644
index 0000000..99fdc4e
--- /dev/null
+++ b/test/mochiweb_test_util.hrl
@@ -0,0 +1 @@
+-record(treq, {path, body= <<>>, xreply= <<>>}). %% one test request: path, request body, expected reply body
diff --git a/test/mochiweb_tests.erl b/test/mochiweb_tests.erl
index c8bc8ac..fdda3fd 100644
--- a/test/mochiweb_tests.erl
+++ b/test/mochiweb_tests.erl
@@ -1,28 +1,9 @@
 -module(mochiweb_tests).
 -include_lib("eunit/include/eunit.hrl").
-
--record(treq, {path, body= <<>>, xreply= <<>>}).
-
-ssl_cert_opts() ->
-    EbinDir = filename:dirname(code:which(?MODULE)),
-    CertDir = filename:join([EbinDir, "..", "support", "test-materials"]),
-    CertFile = filename:join(CertDir, "test_ssl_cert.pem"),
-    KeyFile = filename:join(CertDir, "test_ssl_key.pem"),
-    [{certfile, CertFile}, {keyfile, KeyFile}].
+-include("mochiweb_test_util.hrl").
 
 with_server(Transport, ServerFun, ClientFun) ->
-    ServerOpts0 = [{ip, "127.0.0.1"}, {port, 0}, {loop, ServerFun}],
-    ServerOpts = case Transport of
-        plain ->
-            ServerOpts0;
-        ssl ->
-            ServerOpts0 ++ [{ssl, true}, {ssl_opts, ssl_cert_opts()}]
-    end,
-    {ok, Server} = mochiweb_http:start_link(ServerOpts),
-    Port = mochiweb_socket_server:get(Server, port),
-    Res = (catch ClientFun(Transport, Port)),
-    mochiweb_http:stop(Server),
-    Res.
+    mochiweb_test_util:with_server(Transport, ServerFun, ClientFun). %% shared plain/ssl harness
 
 request_test() ->
     R = mochiweb_request:new(z, z, "/foo/bar/baz%20wibble+quux?qs=2", z, []),
@@ -148,6 +129,7 @@
     ClientFun = new_client_fun('GET', TestReqs),
     ok = with_server(Transport, ServerFun, ClientFun),
     ok.
+
 do_POST(Transport, Size, Times) ->
     ServerFun = fun (Req) ->
                         Body = Req:recv_body(),
@@ -165,86 +147,6 @@
 
 new_client_fun(Method, TestReqs) ->
     fun (Transport, Port) ->
-            client_request(Transport, Port, Method, TestReqs)
+            mochiweb_test_util:client_request(Transport, Port, Method, TestReqs) %% all reqs, one conn
     end.
 
-client_request(Transport, Port, Method, TestReqs) ->
-    Opts = [binary, {active, false}, {packet, http}],
-    SockFun = case Transport of
-        plain ->
-            {ok, Socket} = gen_tcp:connect("127.0.0.1", Port, Opts),
-            fun (recv) ->
-                    gen_tcp:recv(Socket, 0);
-                ({recv, Length}) ->
-                    gen_tcp:recv(Socket, Length);
-                ({send, Data}) ->
-                    gen_tcp:send(Socket, Data);
-                ({setopts, L}) ->
-                    inet:setopts(Socket, L)
-            end;
-        ssl ->
-            {ok, Socket} = ssl:connect("127.0.0.1", Port, [{ssl_imp, new} | Opts]),
-            fun (recv) ->
-                    ssl:recv(Socket, 0);
-                ({recv, Length}) ->
-                    ssl:recv(Socket, Length);
-                ({send, Data}) ->
-                    ssl:send(Socket, Data);
-                ({setopts, L}) ->
-                    ssl:setopts(Socket, L)
-            end
-    end,
-    client_request(SockFun, Method, TestReqs).
-
-client_request(SockFun, _Method, []) ->
-    {the_end, {error, closed}} = {the_end, SockFun(recv)},
-    ok;
-client_request(SockFun, Method,
-               [#treq{path=Path, body=Body, xreply=ExReply} | Rest]) ->
-    Request = [atom_to_list(Method), " ", Path, " HTTP/1.1\r\n",
-               client_headers(Body, Rest =:= []),
-               "\r\n",
-               Body],
-    ok = SockFun({send, Request}),
-    case Method of
-        'GET' ->
-            {ok, {http_response, {1,1}, 200, "OK"}} = SockFun(recv);
-        'POST' ->
-            {ok, {http_response, {1,1}, 201, "Created"}} = SockFun(recv);
-        'CONNECT' ->
-            {ok, {http_response, {1,1}, 200, "OK"}} = SockFun(recv)
-    end,
-    ok = SockFun({setopts, [{packet, httph}]}),
-    {ok, {http_header, _, 'Server', _, "MochiWeb" ++ _}} = SockFun(recv),
-    {ok, {http_header, _, 'Date', _, _}} = SockFun(recv),
-    {ok, {http_header, _, 'Content-Type', _, _}} = SockFun(recv),
-    {ok, {http_header, _, 'Content-Length', _, ConLenStr}} = SockFun(recv),
-    ContentLength = list_to_integer(ConLenStr),
-    {ok, http_eoh} = SockFun(recv),
-    ok = SockFun({setopts, [{packet, raw}]}),
-    {payload, ExReply} = {payload, drain_reply(SockFun, ContentLength, <<>>)},
-    ok = SockFun({setopts, [{packet, http}]}),
-    client_request(SockFun, Method, Rest).
-
-client_headers(Body, IsLastRequest) ->
-    ["Host: localhost\r\n",
-     case Body of
-        <<>> ->
-            "";
-        _ ->
-            ["Content-Type: application/octet-stream\r\n",
-             "Content-Length: ", integer_to_list(byte_size(Body)), "\r\n"]
-     end,
-     case IsLastRequest of
-         true ->
-             "Connection: close\r\n";
-         false ->
-             ""
-     end].
-
-drain_reply(_SockFun, 0, Acc) ->
-    Acc;
-drain_reply(SockFun, Length, Acc) ->
-    Sz = erlang:min(Length, 1024),
-    {ok, B} = SockFun({recv, Sz}),
-    drain_reply(SockFun, Length - Sz, <<Acc/bytes, B/bytes>>).
diff --git a/test/mochiweb_websocket_tests.erl b/test/mochiweb_websocket_tests.erl
index 890aa17..5711a55 100644
--- a/test/mochiweb_websocket_tests.erl
+++ b/test/mochiweb_websocket_tests.erl
@@ -82,3 +82,95 @@
        mochiweb_websocket:parse_hixie_frames(
          <<0,102,111,111,255,0,98,97,114,255>>,
          [])).
+
+%% Run end_to_end_client/1 against a fresh server over ServerTransport (plain | ssl).
+end_to_end_test_factory(ServerTransport) ->
+    ClientFun = fun (T, P) -> end_to_end_client(mochiweb_test_util:sock_fun(T, P)) end,
+    Result = mochiweb_test_util:with_server(ServerTransport, fun end_to_end_server/1, ClientFun),
+    %% with_server/3 wraps the client in `catch` and returns the caught failure
+    %% (e.g. {'EXIT', Reason}) instead of raising, so assert or failures pass silently.
+    ?assertEqual(ok, Result).
+
+%% Check the upgrade request headers, upgrade, then call the re-entry fun with state ok.
+end_to_end_server(Req) ->
+    ?assertEqual("Upgrade", Req:get_header_value("connection")),
+    ?assertEqual("websocket", Req:get_header_value("upgrade")),
+    {Reenter, _Reply} =
+        mochiweb_websocket:upgrade_connection(Req, fun end_to_end_ws_loop/3),
+    Reenter(ok).
+
+end_to_end_ws_loop(Payloads, State, Reply) ->
+    %% Echo: hand every received payload straight back to Reply; State passes through.
+    lists:foreach(fun (P) -> Reply(P) end, Payloads),
+    State.
+
+%% Do the RFC 6455 opening handshake over S, send one masked frame, check the echo.
+end_to_end_client(S) ->
+    %% Key and Accept per https://tools.ietf.org/html/rfc6455
+    UpgradeReq = string:join(
+                   ["GET / HTTP/1.1",
+                    "Host: localhost",
+                    "Upgrade: websocket",
+                    "Connection: Upgrade",
+                    "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
+                    "", ""], "\r\n"),
+    ok = S({send, UpgradeReq}),
+    {ok, {http_response, {1,1}, 101, _}} = S(recv),
+    ok = S({setopts, [{packet, httph}]}),
+    D = read_expected_headers(
+          S,
+          gb_from_list(
+            [{'Upgrade', "websocket"},
+             {'Connection', "Upgrade"},
+             {'Content-Length', "0"},
+             {"Sec-Websocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}])),
+    ?assertEqual([], gb_trees:to_list(D)),
+    ok = S({setopts, [{packet, raw}]}),
+    %% The first message sent over telegraph :)
+    SmallMessage = <<"What hath God wrought?">>,
+    ok = S({send,
+       << 1:1, %% Fin
+          0:1, %% Rsv1
+          0:1, %% Rsv2
+          0:1, %% Rsv3
+          2:4, %% Opcode 2 = binary frame (the echo still comes back as text)
+          1:1, %% Mask bit set (RFC 6455 requires client frames to be masked)
+          (byte_size(SmallMessage)):7, %% Payload length, 7-bit form (0..125 bytes)
+          0:32, %% Masking key of zeros, so the payload bytes go out unchanged
+          SmallMessage/binary >>}),
+    {ok, WsFrames} = S(recv),
+    << 1:1, %% Fin
+       0:1, %% Rsv1
+       0:1, %% Rsv2
+       0:1, %% Rsv3
+       1:4, %% Opcode, text frame (all mochiweb supports for now)
+       MsgSize:8, %% Mask bit + 7-bit length; equals the length only when unmasked
+       SmallMessage/binary >> = WsFrames,
+    ?assertEqual(byte_size(SmallMessage), MsgSize),
+    ok.
+
+%% Build a gb_tree from a list of {Key, Value} pairs.
+gb_from_list(Pairs) ->
+    %% gb_trees:insert/3 crashes on a duplicate key, so keys must be unique.
+    lists:foldl(fun ({K, V}, Tree) -> gb_trees:insert(K, V, Tree) end,
+                gb_trees:empty(), Pairs).
+
+%% Read headers from S up to http_eoh, checking each header listed in Expected;
+%% returns Expected minus the headers that were seen (unlisted headers are skipped).
+read_expected_headers(S, Expected) ->
+    case S(recv) of
+        {ok, http_eoh} -> Expected;
+        {ok, {http_header, _, K, _, V}} ->
+            case gb_trees:lookup(K, Expected) of
+                {value, Want} ->
+                    ?assertEqual({K, Want}, {K, V}),
+                    read_expected_headers(S, gb_trees:delete(K, Expected));
+                none -> read_expected_headers(S, Expected)
+            end
+    end.
+
+%% The same websocket round trip, once over plain TCP and once over SSL; the
+%% SSL run exercises the ssl receive path in mochiweb_websocket.
+end_to_end_http_test() -> end_to_end_test_factory(plain).
+
+end_to_end_https_test() -> end_to_end_test_factory(ssl).