refactor and general tidy-up of websockets detection and examples. origin validation.
diff --git a/examples/websockets/README.txt b/examples/websockets/README.txt
new file mode 100644
index 0000000..378e274
--- /dev/null
+++ b/examples/websockets/README.txt
@@ -0,0 +1,9 @@
+There are two ways to use websockets, active and passive.
+
+In passive mode, you can call the blocking get_data() to read a frame
+from the websocket. Passive mode is best suited to a transmit-only
+server, or where you read from the client then send, then read in a loop.
+
+Active mode allows for fully asynchronous sending and receiving.
+Your process gets sent {frame, Msg} when messages arrive, and you can
+call send at any time.
diff --git a/examples/websockets/websockets_active.erl b/examples/websockets/websockets_active.erl
index 53fa3ed..1ed6ba3 100644
--- a/examples/websockets/websockets_active.erl
+++ b/examples/websockets/websockets_active.erl
@@ -8,9 +8,16 @@
start(Options) ->
{DocRoot, Options1} = get_option(docroot, Options),
Loop = fun (Req) -> ?MODULE:loop(Req, DocRoot) end,
- % websockets options:
+ % How we validate origin for cross-domain checks:
+ OriginValidator = fun(Origin) ->
+ io:format("Origin '~s' -> OK~n", [Origin]),
+ true
+ end,
+ % websocket options
WsOpts = [ {active, true},
+ {origin_validator, OriginValidator},
{loop, {?MODULE, wsloop_active}} ],
+
mochiweb_http:start([{name, ?MODULE},
{loop, Loop},
{websocket_opts, WsOpts} | Options1]).
@@ -19,7 +26,7 @@
mochiweb_http:stop(?MODULE).
wsloop_active(Pid) ->
- mochiweb_websocket_delegate:send(Pid, "WELCOME MSG!"),
+ mochiweb_websocket_delegate:send(Pid, "WELCOME MSG FROM THE SERVER!"),
wsloop_active0(Pid).
wsloop_active0(Pid) ->
@@ -30,9 +37,8 @@
{error, Reason} ->
io:format("client api got error ~p~n", [Reason]),
ok;
- % {legacy_frame, M} or {utf8_frame, M}
- {_, X} ->
- Msg = io_lib:format("SRVER_GOT: ~p", [X]),
+ {frame, Frame} ->
+ Msg = ["Dear client, thanks for sending us this msg: ", Frame],
mochiweb_websocket_delegate:send(Pid, Msg),
wsloop_active0(Pid)
after 10000 ->
diff --git a/examples/websockets/websockets_passive.erl b/examples/websockets/websockets_passive.erl
index 7632799..4e4151f 100644
--- a/examples/websockets/websockets_passive.erl
+++ b/examples/websockets/websockets_passive.erl
@@ -8,8 +8,14 @@
start(Options) ->
{DocRoot, Options1} = get_option(docroot, Options),
Loop = fun (Req) -> ?MODULE:loop(Req, DocRoot) end,
+ % How we validate origin for cross-domain checks:
+ OriginValidator = fun(Origin) ->
+ io:format("Origin '~s' -> OK~n", [Origin]),
+ true
+ end,
% websockets options:
WsOpts = [ {active, false},
+ {origin_validator, OriginValidator},
{loop, {?MODULE, wsloop}} ],
mochiweb_http:start([{name, ?MODULE},
{loop, Loop},
@@ -19,23 +25,18 @@
mochiweb_http:stop(?MODULE).
wsloop(Ws) ->
- io:format("Websocket request, path: ~p~n", [Ws:get(path)]),
+ Ws:send("WELCOME MSG FROM THE SERVER part 1/2"),
+ Ws:send("WELCOME MSG FROM THE SERVER part 2/2"),
+ wsloop0(Ws).
+
+wsloop0(Ws) ->
case Ws:get_data() of
- closed -> ok;
+ closed -> ok;
closing -> ok;
timeout -> timeout;
-
- % older websockets spec which is in the wild, messages are framed with
- % 0x00...0xFF
- {legacy_frame, Body} ->
- Ws:send(["YOU SENT US LEGACY FRAME: ", Body]),
- wsloop(Ws);
-
- % current spec, each message has a 0xFF/<64bit length> header
- % and must contain utf8 bytestream
- {utf8_frame, Body} ->
- Ws:send(["YOU SENT US MODERN FRAME: ", Body]),
- wsloop(Ws)
+ {frame, Body} ->
+ Ws:send(["Dear client, thanks for sending this message: ", Body]),
+ wsloop0(Ws)
end.
loop(Req, DocRoot) ->
diff --git a/src/mochiweb_http.erl b/src/mochiweb_http.erl
index 7da890c..55bef18 100644
--- a/src/mochiweb_http.erl
+++ b/src/mochiweb_http.erl
@@ -17,22 +17,32 @@
-define(DEFAULTS, [{name, ?MODULE},
{port, 8888}]).
-% client loop holds fun/info on how to hand off request to client code
--record(body, {http_loop, websocket_loop, websocket_active}).
+%% unless specified, we accept any origin:
+-define(DEFAULT_ORIGIN_VALIDATOR, fun(_Origin) -> true end).
+
+-record(body, {http_loop, % normal http handler fun
+ websocket_loop, % websocket handler fun
+ websocket_active, % boolean: active or passive api
+ websocket_origin_validator % fun(Origin) -> true/false
+ }).
parse_options(Options) ->
HttpLoop = proplists:get_value(loop, Options),
case proplists:get_value(websocket_opts, Options) of
WsProps when is_list(WsProps) ->
WsLoop = proplists:get_value(loop, WsProps),
+ WsOrigin = proplists:get_value(origin_validator, WsProps,
+ ?DEFAULT_ORIGIN_VALIDATOR),
WsActive = proplists:get_value(active, WsProps, false);
_ ->
WsLoop = undefined,
+ WsOrigin = undefined,
WsActive = undefined
end,
- Body = #body{http_loop = HttpLoop,
- websocket_loop = WsLoop,
- websocket_active = WsActive},
+ Body = #body{http_loop = HttpLoop,
+ websocket_loop = WsLoop,
+ websocket_origin_validator = WsOrigin,
+ websocket_active = WsActive},
Loop = fun (S) -> ?MODULE:loop(S, Body) end,
Options1 = [{loop, Loop} |
proplists:delete(loop,
@@ -135,35 +145,11 @@
case mochiweb_socket:recv(Socket, 0, ?HEADERS_RECV_TIMEOUT) of
{ok, http_eoh} ->
mochiweb_socket:setopts(Socket, [{packet, raw}]),
- %% Examine headers to decide if this a websocket upgrade request:
H = mochiweb_headers:make(Headers),
- HeaderFun = fun(K) ->
- case mochiweb_headers:get_value(K, H) of
- undefined -> "";
- V -> string:to_lower(V)
- end
- end,
- case {HeaderFun("upgrade"), HeaderFun("connection")} of
- {"websocket", "upgrade"} ->
- io:format("notmal -> ws~n",[]),
- {_, {abs_path,Path}, _} = Request,
- ok = websocket_init(Socket, Path, H),
- case Body#body.websocket_active of
- true ->
- {ok, WSPid} = mochiweb_websocket_delegate:start_link(Path, H, self()),
- mochiweb_websocket_delegate:go(WSPid, Socket),
- call_body(Body#body.websocket_loop, WSPid);
- false ->
- WsReq = mochiweb_wsrequest:new(Socket, Path, H),
- call_body(Body#body.websocket_loop, WsReq);
- undefined ->
- Req = mochiweb:new_request({Socket, Request,
- lists:reverse(Headers)}),
- io:format("Websocket upgrade requested, but no websocket handler provided: ~s~n",[Req:get(path)]),
- Req:not_found()
- end;
- X -> %% not websocket:
- io:format("notmal~p~n",[X]),
+ case is_websocket_upgrade_requested(H) of
+ true ->
+ headers_ws_upgrade(Socket, Request, Headers, Body, H);
+ false ->
Req = mochiweb:new_request({Socket, Request,
lists:reverse(Headers)}),
call_body(Body#body.http_loop, Req),
@@ -179,6 +165,36 @@
handle_invalid_request(Socket, Request, Headers)
end.
+% checks if these headers are a valid websocket upgrade request
+is_websocket_upgrade_requested(H) ->
+ Hdr = fun(K) -> case mochiweb_headers:get_value(K, H) of
+ undefined -> undefined;
+ V when is_list(V) -> string:to_lower(V)
+ end
+ end,
+ Hdr("upgrade") == "websocket" andalso Hdr("connection") == "upgrade".
+
+% entered once we've seen valid websocket upgrade headers
+headers_ws_upgrade(Socket, Request, Headers, Body, H) ->
+ {_, {abs_path,Path}, _} = Request,
+ OriginValidator = Body#body.websocket_origin_validator,
+ % websocket_init will exit() if anything looks fishy
+ websocket_init(Socket, Path, H, OriginValidator),
+ case Body#body.websocket_active of
+ true ->
+ {ok, WSPid} = mochiweb_websocket_delegate:start_link(Path,H,self()),
+ mochiweb_websocket_delegate:go(WSPid, Socket),
+ call_body(Body#body.websocket_loop, WSPid);
+ false ->
+ WsReq = mochiweb_wsrequest:new(Socket, Path, H),
+ call_body(Body#body.websocket_loop, WsReq);
+ undefined ->
+ % what is the correct way to respond when a server doesn't
+ % support websockets, but the client requests the upgrade?
+ % use a 400 for now:
+ handle_invalid_request(Socket, Request, Headers)
+ end.
+
call_body({M, F}, Req) ->
M:F(Req);
call_body(Body, Req) ->
@@ -246,9 +262,19 @@
%% Respond to the websocket upgrade request with valid signature
%% or exit() if any of the sec- headers look suspicious.
-websocket_init(Socket, Path, Headers) ->
+websocket_init(Socket, Path, Headers, OriginValidator) ->
+ Origin = mochiweb_headers:get_value("origin", Headers),
+ %% If origin is invalid, just uncerimoniously close the socket
+ case Origin /= undefiend andalso OriginValidator(Origin) == true of
+ true ->
+ websocket_init_with_origin_validated(Socket, Path, Headers, Origin);
+ false ->
+ mochiweb_socket:close(Socket),
+ exit(websocket_origin_check_failed)
+ end.
+
+websocket_init_with_origin_validated(Socket, Path, Headers, _Origin) ->
Host = mochiweb_headers:get_value("Host", Headers),
- %Origin = mochiweb_headers:get_value("origin", Headers), % TODO
SubProto = mochiweb_headers:get_value("Sec-Websocket-Protocol", Headers),
Key1 = mochiweb_headers:get_value("Sec-Websocket-Key1", Headers),
Key2 = mochiweb_headers:get_value("Sec-Websocket-Key2", Headers),
diff --git a/src/mochiweb_websocket_delegate.erl b/src/mochiweb_websocket_delegate.erl
index e77c79e..6682b1b 100644
--- a/src/mochiweb_websocket_delegate.erl
+++ b/src/mochiweb_websocket_delegate.erl
@@ -10,7 +10,9 @@
%% an older version of the websocket spec, where messages are framed 0x00...0xFF
%% so the newer protocol with length headers has not been tested with a browser.
%%
-%% Guarantees that 'closed' will be sent to the client pid once the socket dies
+%% Guarantees that 'closed' will be sent to the client pid once the socket dies,
+%% Messages are:
+%% closed, {error, Reason}, {frame, Data}
-module(mochiweb_websocket_delegate).
-behaviour(gen_server).
@@ -40,7 +42,6 @@
gen_server:cast(Pid, {go, Socket}).
send(Pid, Msg) ->
- io:format("send:~s~n",[Msg]),
gen_server:call(Pid, {send, Msg}).
close(Pid) ->
@@ -87,20 +88,16 @@
{noreply, State#state{socket=Socket}}.
handle_info({'EXIT', _, _}, State) ->
- io:format("TRAP EXIT~n",[]),
State#state.dest ! closed,
{stop, normal, State};
handle_info({tcp_closed, Sock}, State = #state{socket=Sock}) ->
- io:format("TCP CLOSED~n",[]),
State#state.dest ! closed,
{stop, normal, State};
handle_info({tcp_error, Sock, Reason}, State = #state{socket=Sock}) ->
- io:format("TCP ERROR~n",[]),
State#state.dest ! {error, Reason},
State#state.dest ! closed,
{stop, normal, State};
handle_info({tcp, Sock, Data}, State = #state{socket=Sock, buffer=Buffer}) ->
- %mochiweb_socket:setopts(Sock, [{active, once}]),
NewState = process_data(State#state{buffer= <<Buffer/binary,Data/binary>>}),
{noreply, NewState}.
@@ -113,38 +110,33 @@
%%% Internal functions
process_data(State = #state{buffer= <<>>}) ->
- %io:format("A 0~n", []),
State;
process_data(State = #state{buffer= <<FrameType:8,Buffer/binary>>, ft=undefined}) ->
- %io:format("A 1~n", []),
process_data(State#state{buffer=Buffer, ft=FrameType, partial= <<>>});
% "Legacy" frames, 0x00...0xFF
% or modern closing handshake 0x00{8}
process_data(State = #state{buffer= <<0,0,0,0,0,0,0,0, Buffer/binary>>, ft=0}) ->
- %io:format("A 2~n", []),
State#state.dest ! closing_handshake,
process_data(State#state{buffer=Buffer, ft=undefined});
process_data(State = #state{buffer= <<255, Rest/binary>>, ft=0}) ->
- %io:format("A 3~n", []),
- State#state.dest ! {legacy_frame, State#state.partial},
+ % message received in full
+ State#state.dest ! {frame, State#state.partial},
process_data(State#state{partial= <<>>, ft=undefined, buffer=Rest});
process_data(State = #state{buffer= <<Byte:8, Rest/binary>>, ft=0, partial=Partial}) ->
- %io:format("A 4, byte=~p state:~p~n", [Byte,State]),
NewPartial = case Partial of <<>> -> <<Byte>> ; _ -> <<Partial/binary, <<Byte>>/binary>> end,
NewState = State#state{buffer=Rest, partial=NewPartial},
process_data(NewState);
% "Modern" frames, starting with 0xFF, followed by 64 bit length
process_data(State = #state{buffer= <<Len:64/unsigned-integer,Buffer/binary>>, ft=255, flen=undefined}) ->
- %io:format("A 5~n", []),
BitsLen = Len*8,
case Buffer of
<<Frame:BitsLen/binary, Rest/binary>> ->
- State#state.dest ! {utf8_frame, Frame},
+ State#state.dest ! {frame, Frame},
process_data(State#state{ft=undefined, flen=undefined, buffer=Rest});
_ ->
@@ -152,11 +144,10 @@
end;
process_data(State = #state{buffer=Buffer, ft=255, flen=Len}) when is_integer(Len) ->
- %io:format("A 6~n", []),
BitsLen = Len*8,
case Buffer of
<<Frame:BitsLen/binary, Rest/binary>> ->
- State#state.dest ! {utf8_frame, Frame},
+ State#state.dest ! {frame, Frame},
process_data(State#state{ft=undefined, flen=undefined, buffer=Rest});
_ ->
diff --git a/src/mochiweb_wsrequest.erl b/src/mochiweb_wsrequest.erl
index 1d107c2..9991e3b 100644
--- a/src/mochiweb_wsrequest.erl
+++ b/src/mochiweb_wsrequest.erl
@@ -27,7 +27,7 @@
{ok, <<Len:64/unsigned-integer>>} =
mochiweb_socket:recv(Socket, 8, ?TIMEOUT),
{ok, Frame} = mochiweb_socket:recv(Socket, Len, ?TIMEOUT),
- {utf8_frame, Frame};
+ {frame, Frame};
<<0>> -> % modern close request, or older no-length-frame msg
case mochiweb_socket:recv(Socket, 1, ?TIMEOUT) of
{ok, <<0>>} ->
@@ -37,12 +37,12 @@
{ok, <<255>>} ->
% empty legacy frame.
erlang:put(legacy, true),
- {legacy_frame, <<>>};
+ {frame, <<>>};
{ok, Byte2} ->
% Read up to the first 0xFF for the body
erlang:put(legacy, true),
Body = read_until_FF(Socket, Byte2),
- {legacy_frame, Body}
+ {frame, Body}
end
end
end.