Merge pull request #144 from mochi/websocket-ssl
Fix ssl receive support for websocket
diff --git a/src/mochiweb_websocket.erl b/src/mochiweb_websocket.erl
index cc3127e..2768a3e 100644
--- a/src/mochiweb_websocket.erl
+++ b/src/mochiweb_websocket.erl
@@ -27,6 +27,9 @@
-export([loop/5, upgrade_connection/2, request/5]).
-export([send/3]).
+-ifdef(TEST).
+-compile(export_all).
+-endif.
loop(Socket, Body, State, WsVersion, ReplyChannel) ->
ok = mochiweb_socket:setopts(Socket, [{packet, 0}, {active, once}]),
@@ -44,7 +47,7 @@
{tcp_error, _, _} ->
mochiweb_socket:close(Socket),
exit(normal);
- {tcp, _, WsFrames} ->
+ {Proto, _, WsFrames} when Proto =:= tcp orelse Proto =:= ssl ->
case parse_frames(WsVersion, WsFrames, Socket) of
close ->
mochiweb_socket:close(Socket),
@@ -214,7 +217,7 @@
{tcp_error, _, _} ->
mochiweb_socket:close(Socket),
exit(normal);
- {tcp, _, Continuation} ->
+ {Proto, _, Continuation} when Proto =:= tcp orelse Proto =:= ssl ->
parse_hybi_frames(Socket, <<PartFrame/binary, Continuation/binary>>,
Acc);
_ ->
@@ -276,11 +279,3 @@
{Buffer, Rest};
parse_hixie(<<H, T/binary>>, Buffer) ->
parse_hixie(T, <<Buffer/binary, H>>).
-
-%%
-%% Tests
-%%
--ifdef(TEST).
--include_lib("eunit/include/eunit.hrl").
--compile(export_all).
--endif.
diff --git a/test/mochiweb_test_util.erl b/test/mochiweb_test_util.erl
new file mode 100644
index 0000000..0801ab5
--- /dev/null
+++ b/test/mochiweb_test_util.erl
@@ -0,0 +1,111 @@
+-module(mochiweb_test_util).
+-export([with_server/3, client_request/4, sock_fun/2]).
+-include("mochiweb_test_util.hrl").
+
+%% Certificate and key options for the SSL test server. The PEM files are
+%% looked up in support/test-materials, relative to the parent of the
+%% directory this module's object code was loaded from.
+ssl_cert_opts() ->
+    BeamDir = filename:dirname(code:which(?MODULE)),
+    MaterialsDir = filename:join([BeamDir, "..", "support", "test-materials"]),
+    [{certfile, filename:join(MaterialsDir, "test_ssl_cert.pem")},
+     {keyfile, filename:join(MaterialsDir, "test_ssl_key.pem")}].
+
+%% @doc Start a mochiweb HTTP server on an ephemeral 127.0.0.1 port with
+%% ServerFun as its request loop, run ClientFun(Transport, Port) against it,
+%% then stop the server. Transport is `plain' or `ssl' (ssl uses the test
+%% certificate from ssl_cert_opts/0).
+%% Returns whatever ClientFun returns. If ClientFun raises, the server is
+%% still stopped and the exception propagates to the caller with its
+%% stacktrace.
+with_server(Transport, ServerFun, ClientFun) ->
+    ServerOpts0 = [{ip, "127.0.0.1"}, {port, 0}, {loop, ServerFun}],
+    ServerOpts = case Transport of
+                     plain ->
+                         ServerOpts0;
+                     ssl ->
+                         ServerOpts0 ++ [{ssl, true}, {ssl_opts, ssl_cert_opts()}]
+                 end,
+    {ok, Server} = mochiweb_http:start_link(ServerOpts),
+    Port = mochiweb_socket_server:get(Server, port),
+    %% try/after instead of `catch': the server is stopped either way, and a
+    %% client failure is re-raised rather than returned as {'EXIT', Reason},
+    %% which a caller that does not match on the result (e.g. an eunit test
+    %% returning it) would silently treat as success.
+    try
+        ClientFun(Transport, Port)
+    after
+        mochiweb_http:stop(Server)
+    end.
+
+%% @doc Connect to 127.0.0.1:Port over Transport (`plain' or `ssl') with
+%% options [binary, {active, false}, {packet, http}] and return a fun that
+%% hides the transport. The fun accepts:
+%%   recv            -> receive with length 0
+%%   {recv, Length}  -> receive Length bytes
+%%   {send, Data}    -> send Data
+%%   {setopts, Opts} -> set socket options
+%%   get             -> the socket (plain) or {ssl, Socket} (ssl)
+sock_fun(Transport, Port) ->
+    Opts = [binary, {active, false}, {packet, http}],
+    case Transport of
+        plain ->
+            {ok, Sock} = gen_tcp:connect("127.0.0.1", Port, Opts),
+            transport_fun(Sock, Sock,
+                          fun gen_tcp:recv/2, fun gen_tcp:send/2,
+                          fun inet:setopts/2);
+        ssl ->
+            {ok, Sock} = ssl:connect("127.0.0.1", Port, [{ssl_imp, new} | Opts]),
+            transport_fun(Sock, {ssl, Sock},
+                          fun ssl:recv/2, fun ssl:send/2,
+                          fun ssl:setopts/2)
+    end.
+
+%% Build the command-dispatching fun returned by sock_fun/2 from the
+%% transport's recv/send/setopts functions; Handle is what `get' returns.
+transport_fun(Sock, Handle, Recv, Send, SetOpts) ->
+    fun (recv) -> Recv(Sock, 0);
+        ({recv, Length}) -> Recv(Sock, Length);
+        ({send, Data}) -> Send(Sock, Data);
+        ({setopts, L}) -> SetOpts(Sock, L);
+        (get) -> Handle
+    end.
+
+%% @doc Open a connection with sock_fun/2 and run every #treq{} in TestReqs
+%% over it as a Method request (see client_request/3).
+client_request(Transport, Port, Method, TestReqs) ->
+    SockFun = sock_fun(Transport, Port),
+    client_request(SockFun, Method, TestReqs).
+
+%% Send each #treq{} in the list as an HTTP/1.1 Method request, one after
+%% another on the same connection (SockFun, from sock_fun/2), and check each
+%% reply:
+%%   - the status line expected for Method;
+%%   - the Server, Date, Content-Type and Content-Length headers, in exactly
+%%     that order;
+%%   - a body equal to the request's xreply.
+%% The last request carries "Connection: close", so after the list is
+%% exhausted the socket must read as closed. Returns ok.
+client_request(SockFun, _Method, []) ->
+ %% No requests left: the server should have closed the connection.
+ {the_end, {error, closed}} = {the_end, SockFun(recv)},
+ ok;
+client_request(SockFun, Method,
+ [#treq{path=Path, body=Body, xreply=ExReply} | Rest]) ->
+ Request = [atom_to_list(Method), " ", Path, " HTTP/1.1\r\n",
+ client_headers(Body, Rest =:= []),
+ "\r\n",
+ Body],
+ ok = SockFun({send, Request}),
+ %% POST is expected to answer 201 Created; GET and CONNECT 200 OK.
+ case Method of
+ 'GET' ->
+ {ok, {http_response, {1,1}, 200, "OK"}} = SockFun(recv);
+ 'POST' ->
+ {ok, {http_response, {1,1}, 201, "Created"}} = SockFun(recv);
+ 'CONNECT' ->
+ {ok, {http_response, {1,1}, 200, "OK"}} = SockFun(recv)
+ end,
+ %% Status line consumed; switch to httph to read header lines one by one.
+ ok = SockFun({setopts, [{packet, httph}]}),
+ {ok, {http_header, _, 'Server', _, "MochiWeb" ++ _}} = SockFun(recv),
+ {ok, {http_header, _, 'Date', _, _}} = SockFun(recv),
+ {ok, {http_header, _, 'Content-Type', _, _}} = SockFun(recv),
+ {ok, {http_header, _, 'Content-Length', _, ConLenStr}} = SockFun(recv),
+ ContentLength = list_to_integer(ConLenStr),
+ {ok, http_eoh} = SockFun(recv),
+ %% Read the body as raw bytes. ExReply is already bound, so this match
+ %% asserts the body; the {payload, _} tag identifies the failing check.
+ ok = SockFun({setopts, [{packet, raw}]}),
+ {payload, ExReply} = {payload, drain_reply(SockFun, ContentLength, <<>>)},
+ %% Back to http mode for the next request's status line.
+ ok = SockFun({setopts, [{packet, http}]}),
+ client_request(SockFun, Method, Rest).
+
+%% Request headers (as an iolist) for one test request:
+%%   - always Host;
+%%   - Content-Type and Content-Length when Body is not <<>>;
+%%   - "Connection: close" when this is the last request on the connection.
+client_headers(Body, IsLastRequest) ->
+    ["Host: localhost\r\n",
+     body_headers(Body),
+     connection_header(IsLastRequest)].
+
+%% Content headers describing a non-empty binary body; nothing for <<>>.
+body_headers(<<>>) ->
+    "";
+body_headers(Body) ->
+    ["Content-Type: application/octet-stream\r\n",
+     "Content-Length: ", integer_to_list(byte_size(Body)), "\r\n"].
+
+%% Ask the server to close the connection after the final request.
+connection_header(true) ->
+    "Connection: close\r\n";
+connection_header(false) ->
+    "".
+
+%% Read Length bytes in total through SockFun({recv, N}), in chunks of at
+%% most 1024 bytes, appending each chunk to Acc. Returns the accumulated
+%% binary.
+drain_reply(_SockFun, 0, Acc) ->
+    Acc;
+drain_reply(SockFun, Length, Acc) ->
+    Chunk = case Length > 1024 of
+                true -> 1024;
+                false -> Length
+            end,
+    {ok, Bytes} = SockFun({recv, Chunk}),
+    drain_reply(SockFun, Length - Chunk, <<Acc/binary, Bytes/binary>>).
diff --git a/test/mochiweb_test_util.hrl b/test/mochiweb_test_util.hrl
new file mode 100644
index 0000000..99fdc4e
--- /dev/null
+++ b/test/mochiweb_test_util.hrl
@@ -0,0 +1 @@
+%% One request/expected-reply pair for mochiweb_test_util:client_request/3,4:
+%%   path   - request path placed on the request line
+%%   body   - request body; <<>> means no Content-Type/Content-Length is sent
+%%   xreply - expected response body
+-record(treq, {path, body= <<>>, xreply= <<>>}).
diff --git a/test/mochiweb_tests.erl b/test/mochiweb_tests.erl
index c8bc8ac..fdda3fd 100644
--- a/test/mochiweb_tests.erl
+++ b/test/mochiweb_tests.erl
@@ -1,28 +1,9 @@
-module(mochiweb_tests).
-include_lib("eunit/include/eunit.hrl").
-
--record(treq, {path, body= <<>>, xreply= <<>>}).
-
-ssl_cert_opts() ->
- EbinDir = filename:dirname(code:which(?MODULE)),
- CertDir = filename:join([EbinDir, "..", "support", "test-materials"]),
- CertFile = filename:join(CertDir, "test_ssl_cert.pem"),
- KeyFile = filename:join(CertDir, "test_ssl_key.pem"),
- [{certfile, CertFile}, {keyfile, KeyFile}].
+-include("mochiweb_test_util.hrl").
+%% Local wrapper so this module's tests can keep calling with_server/3;
+%% delegates to mochiweb_test_util:with_server/3.
with_server(Transport, ServerFun, ClientFun) ->
- ServerOpts0 = [{ip, "127.0.0.1"}, {port, 0}, {loop, ServerFun}],
- ServerOpts = case Transport of
- plain ->
- ServerOpts0;
- ssl ->
- ServerOpts0 ++ [{ssl, true}, {ssl_opts, ssl_cert_opts()}]
- end,
- {ok, Server} = mochiweb_http:start_link(ServerOpts),
- Port = mochiweb_socket_server:get(Server, port),
- Res = (catch ClientFun(Transport, Port)),
- mochiweb_http:stop(Server),
- Res.
+ mochiweb_test_util:with_server(Transport, ServerFun, ClientFun).
request_test() ->
R = mochiweb_request:new(z, z, "/foo/bar/baz%20wibble+quux?qs=2", z, []),
@@ -148,6 +129,7 @@
ClientFun = new_client_fun('GET', TestReqs),
ok = with_server(Transport, ServerFun, ClientFun),
ok.
+
do_POST(Transport, Size, Times) ->
ServerFun = fun (Req) ->
Body = Req:recv_body(),
@@ -165,86 +147,6 @@
+%% Return a with_server/3 client fun that runs TestReqs as Method requests
+%% via mochiweb_test_util:client_request/4.
new_client_fun(Method, TestReqs) ->
fun (Transport, Port) ->
- client_request(Transport, Port, Method, TestReqs)
+ mochiweb_test_util:client_request(Transport, Port, Method, TestReqs)
end.
-client_request(Transport, Port, Method, TestReqs) ->
- Opts = [binary, {active, false}, {packet, http}],
- SockFun = case Transport of
- plain ->
- {ok, Socket} = gen_tcp:connect("127.0.0.1", Port, Opts),
- fun (recv) ->
- gen_tcp:recv(Socket, 0);
- ({recv, Length}) ->
- gen_tcp:recv(Socket, Length);
- ({send, Data}) ->
- gen_tcp:send(Socket, Data);
- ({setopts, L}) ->
- inet:setopts(Socket, L)
- end;
- ssl ->
- {ok, Socket} = ssl:connect("127.0.0.1", Port, [{ssl_imp, new} | Opts]),
- fun (recv) ->
- ssl:recv(Socket, 0);
- ({recv, Length}) ->
- ssl:recv(Socket, Length);
- ({send, Data}) ->
- ssl:send(Socket, Data);
- ({setopts, L}) ->
- ssl:setopts(Socket, L)
- end
- end,
- client_request(SockFun, Method, TestReqs).
-
-client_request(SockFun, _Method, []) ->
- {the_end, {error, closed}} = {the_end, SockFun(recv)},
- ok;
-client_request(SockFun, Method,
- [#treq{path=Path, body=Body, xreply=ExReply} | Rest]) ->
- Request = [atom_to_list(Method), " ", Path, " HTTP/1.1\r\n",
- client_headers(Body, Rest =:= []),
- "\r\n",
- Body],
- ok = SockFun({send, Request}),
- case Method of
- 'GET' ->
- {ok, {http_response, {1,1}, 200, "OK"}} = SockFun(recv);
- 'POST' ->
- {ok, {http_response, {1,1}, 201, "Created"}} = SockFun(recv);
- 'CONNECT' ->
- {ok, {http_response, {1,1}, 200, "OK"}} = SockFun(recv)
- end,
- ok = SockFun({setopts, [{packet, httph}]}),
- {ok, {http_header, _, 'Server', _, "MochiWeb" ++ _}} = SockFun(recv),
- {ok, {http_header, _, 'Date', _, _}} = SockFun(recv),
- {ok, {http_header, _, 'Content-Type', _, _}} = SockFun(recv),
- {ok, {http_header, _, 'Content-Length', _, ConLenStr}} = SockFun(recv),
- ContentLength = list_to_integer(ConLenStr),
- {ok, http_eoh} = SockFun(recv),
- ok = SockFun({setopts, [{packet, raw}]}),
- {payload, ExReply} = {payload, drain_reply(SockFun, ContentLength, <<>>)},
- ok = SockFun({setopts, [{packet, http}]}),
- client_request(SockFun, Method, Rest).
-
-client_headers(Body, IsLastRequest) ->
- ["Host: localhost\r\n",
- case Body of
- <<>> ->
- "";
- _ ->
- ["Content-Type: application/octet-stream\r\n",
- "Content-Length: ", integer_to_list(byte_size(Body)), "\r\n"]
- end,
- case IsLastRequest of
- true ->
- "Connection: close\r\n";
- false ->
- ""
- end].
-
-drain_reply(_SockFun, 0, Acc) ->
- Acc;
-drain_reply(SockFun, Length, Acc) ->
- Sz = erlang:min(Length, 1024),
- {ok, B} = SockFun({recv, Sz}),
- drain_reply(SockFun, Length - Sz, <<Acc/bytes, B/bytes>>).
diff --git a/test/mochiweb_websocket_tests.erl b/test/mochiweb_websocket_tests.erl
index 890aa17..5711a55 100644
--- a/test/mochiweb_websocket_tests.erl
+++ b/test/mochiweb_websocket_tests.erl
@@ -82,3 +82,95 @@
mochiweb_websocket:parse_hixie_frames(
<<0,102,111,111,255,0,98,97,114,255>>,
[])).
+
+%% Run the websocket echo round trip over ServerTransport (`plain' or `ssl').
+%% mochiweb_test_util:with_server/3 hands back whatever the client fun
+%% produced, and an eunit test passes whenever it returns normally. Matching
+%% on ok makes a client failure fail the test instead of being returned as a
+%% value.
+end_to_end_test_factory(ServerTransport) ->
+    ok = mochiweb_test_util:with_server(
+           ServerTransport,
+           fun end_to_end_server/1,
+           fun (Transport, Port) ->
+                   end_to_end_client(mochiweb_test_util:sock_fun(Transport, Port))
+           end).
+
+%% Server side of the end-to-end test. Asserts the request asks for a
+%% websocket upgrade, upgrades it with end_to_end_ws_loop/3 as the handler,
+%% and calls the returned re-entry fun with initial state `ok' (presumably
+%% entering mochiweb_websocket's receive loop).
+end_to_end_server(Req) ->
+    Connection = Req:get_header_value("connection"),
+    ?assertEqual("Upgrade", Connection),
+    Upgrade = Req:get_header_value("upgrade"),
+    ?assertEqual("websocket", Upgrade),
+    Handler = fun end_to_end_ws_loop/3,
+    {Reentry, _Reply} = mochiweb_websocket:upgrade_connection(Req, Handler),
+    Reentry(ok).
+
+%% Echo handler: pass every message in Payload, in order, to ReplyChannel
+%% and keep State unchanged.
+end_to_end_ws_loop(Payload, State, ReplyChannel) ->
+    [ReplyChannel(Message) || Message <- Payload],
+    State.
+
+%% Client side of the end-to-end test, driven through S (a
+%% mochiweb_test_util:sock_fun/2 fun, initially in http packet mode):
+%%   1. send an RFC 6455 upgrade request and check the 101 response and its
+%%      handshake headers;
+%%   2. send one masked binary frame;
+%%   3. expect the echo server to return the same payload as an unmasked
+%%      text frame.
+%% Returns ok.
+end_to_end_client(S) ->
+    %% Key and Accept per https://tools.ietf.org/html/rfc6455
+    UpgradeReq = string:join(
+                   ["GET / HTTP/1.1",
+                    "Host: localhost",
+                    "Upgrade: websocket",
+                    "Connection: Upgrade",
+                    "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
+                    "",
+                    ""], "\r\n"),
+    ok = S({send, UpgradeReq}),
+    {ok, {http_response, {1,1}, 101, _}} = S(recv),
+    ok = S({setopts, [{packet, httph}]}),
+    D = read_expected_headers(
+          S,
+          gb_from_list(
+            [{'Upgrade', "websocket"},
+             {'Connection', "Upgrade"},
+             {'Content-Length', "0"},
+             {"Sec-Websocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}])),
+    ?assertEqual([], gb_trees:to_list(D)),
+    ok = S({setopts, [{packet, raw}]}),
+    %% The first message sent over telegraph :)
+    SmallMessage = <<"What hath God wrought?">>,
+    ok = S({send,
+            << 1:1, %% Fin
+               0:1, %% Rsv1
+               0:1, %% Rsv2
+               0:1, %% Rsv3
+               2:4, %% Opcode, 2 = binary frame
+               1:1, %% Mask on (client-to-server frames must be masked)
+               (byte_size(SmallMessage)):7, %% Payload length, 7-bit form (=< 125)
+               0:32, %% Masking key; all zeros leaves the payload unchanged
+               SmallMessage/binary >>}),
+    {ok, WsFrames} = S(recv),
+    << 1:1, %% Fin
+       0:1, %% Rsv1
+       0:1, %% Rsv2
+       0:1, %% Rsv3
+       1:4, %% Opcode, 1 = text frame (all mochiweb supports for now)
+       0:1, %% Mask off: a server must not mask frames it sends
+       MsgSize:7, %% Payload length, 7-bit form (small message)
+       SmallMessage/binary >> = WsFrames,
+    ?assertEqual(byte_size(SmallMessage), MsgSize),
+    ok.
+
+%% Build a gb_tree from a list of {Key, Value} pairs, inserting them left to
+%% right. gb_trees:insert/3 crashes if a key appears more than once.
+gb_from_list(Pairs) ->
+    gb_insert_all(Pairs, gb_trees:empty()).
+
+%% Insert every {Key, Value} of the list into Tree, in list order.
+gb_insert_all([], Tree) ->
+    Tree;
+gb_insert_all([{Key, Value} | Rest], Tree) ->
+    gb_insert_all(Rest, gb_trees:insert(Key, Value, Tree)).
+
+%% Read response header lines from S (in httph mode) up to end-of-headers.
+%% For each header whose name is a key of the gb_tree D:
+%%   - assert its value equals the one stored in D;
+%%   - drop that key from D.
+%% Other headers are ignored. Returns the tree of expected headers that
+%% never arrived (empty when all were seen).
+read_expected_headers(S, D) ->
+    case S(recv) of
+        {ok, http_eoh} ->
+            D;
+        {ok, {http_header, _, K, _, V}} ->
+            case gb_trees:lookup(K, D) of
+                {value, Expected} ->
+                    %% ?assertEqual(Expect, Expr): expected value first so a
+                    %% mismatch is reported the right way round.
+                    ?assertEqual({K, Expected}, {K, V}),
+                    read_expected_headers(S, gb_trees:delete(K, D));
+                none ->
+                    read_expected_headers(S, D)
+            end
+    end.
+
+%% Websocket echo round trip over a plain TCP connection.
+end_to_end_http_test() -> end_to_end_test_factory(plain).
+
+%% Same round trip over SSL (exercises the ssl receive path).
+end_to_end_https_test() -> end_to_end_test_factory(ssl).