Merge pull request #150 from dszoboszlay/bug/unload-race-condition
Fix race condition between meck:unload/1 and calls to the mocked module
diff --git a/src/meck_code_gen.erl b/src/meck_code_gen.erl
index 6010aef..6974244 100644
--- a/src/meck_code_gen.erl
+++ b/src/meck_code_gen.erl
@@ -144,7 +144,7 @@
-spec exec(CallerPid::pid(), Mod::atom(), Func::atom(), Args::[any()]) ->
Result::any().
exec(Pid, Mod, Func, Args) ->
- case meck_proc:get_result_spec(Mod, Func, Args) of
+ try meck_proc:get_result_spec(Mod, Func, Args) of
undefined ->
meck_proc:invalidate(Mod),
raise(Pid, Mod, Func, Args, error, function_clause);
@@ -160,6 +160,9 @@
after
erase(?CURRENT_CALL)
end
+ catch
+ error:{not_mocked, Mod} ->
+ apply(Mod, Func, Args)
end.
-spec handle_exception(CallerPid::pid(), Mod::atom(), Func::atom(),
diff --git a/src/meck_proc.erl b/src/meck_proc.erl
index cb29768..bce3bfe 100644
--- a/src/meck_proc.erl
+++ b/src/meck_proc.erl
@@ -179,7 +179,7 @@
-spec invalidate(Mod::atom()) -> ok.
invalidate(Mod) ->
- gen_server(call, Mod, invalidate).
+ gen_server(cast, Mod, invalidate).
-spec stop(Mod::atom()) -> ok.
stop(Mod) ->
@@ -264,14 +264,14 @@
end;
handle_call(reset, _From, S) ->
{reply, ok, S#state{history = []}};
-handle_call(invalidate, _From, S) ->
- {reply, ok, S#state{valid = false}};
handle_call(validate, _From, S) ->
{reply, S#state.valid, S};
handle_call(stop, _From, S) ->
{stop, normal, ok, S}.
%% @hidden
+handle_cast(invalidate, S) ->
+ {noreply, S#state{valid = false}};
handle_cast({add_history, HistoryRecord}, S = #state{history = undefined,
trackers = Trackers}) ->
UpdTracker = update_trackers(HistoryRecord, Trackers),
diff --git a/test/meck_tests.erl b/test/meck_tests.erl
index fdd517b..6215e41 100644
--- a/test/meck_tests.erl
+++ b/test/meck_tests.erl
@@ -1082,6 +1082,30 @@
?assert(not lists:member(self(), Links)),
ok = meck:unload(mymod).
+%% @doc A concurrent process calling into the mocked module while it's
+%% being unloaded gets either the mocked response or the original
+%% response, but won't crash.
+atomic_unload_test() ->
+ ok = meck:new(meck_test_module),
+ ok = meck:expect(meck_test_module, a, fun () -> c end),
+
+ %% Suspend the meck_proc in order to ensure all messages are in
+ %% its inbox in the correct order before it starts processing them
+ Proc = meck_util:proc_name(meck_test_module),
+ sys:suspend(Proc),
+ StopReq = concurrent_req(
+ Proc,
+ fun () -> ?assertEqual(ok, meck:unload(meck_test_module)) end),
+ SpecReq = concurrent_req(
+ Proc,
+ fun () -> ?assertMatch(V when V =:= a orelse V =:= c,
+ meck_test_module:a())
+ end),
+ sys:resume(Proc),
+
+ ?assertEqual(normal, wait_concurrent_req(StopReq)),
+ ?assertEqual(normal, wait_concurrent_req(SpecReq)).
+
%% @doc Exception is thrown when you run expect on a non-existing (and not yet
%% mocked) module.
expect_without_new_test() ->
@@ -1439,3 +1463,49 @@
assert_called(Mod, Function, Args, Pid, WasCalled) ->
?assertEqual(WasCalled, meck:called(Mod, Function, Args, Pid)),
?assert(meck:validate(Mod)).
+
+%% @doc Spawn a new process to concurrently call `Fun'. `Fun' is
+%% expected to send a request to the specified process, and this
+%% function will wait for this message to arrive. (Therefore the
+%% process should be suspended so that it does not consume its message queue.)
+%%
+%% The returned request handle can be used later in {@link
+%% wait_concurrent_req/1} to wait for the concurrent process to
+%% terminate.
+concurrent_req(Name, Fun) when is_atom(Name) ->
+ case whereis(Name) of
+ Pid when is_pid(Pid) ->
+ concurrent_req(Pid, Fun);
+ undefined ->
+ exit(noproc)
+ end;
+concurrent_req(Pid, Fun) when is_pid(Pid) ->
+ {message_queue_len, Msgs} = process_info(Pid, message_queue_len),
+ Req = spawn_monitor(Fun),
+ wait_message(Pid, Msgs + 1, 100),
+ Req.
+
+%% @doc Wait for a concurrent request started with {@link
+%% concurrent_req/2} to terminate. The return value is the exit reason
+%% of the process.
+wait_concurrent_req(Req = {Pid, Monitor}) ->
+ receive
+ {'DOWN', Monitor, process, Pid, Reason} ->
+ Reason
+ after
+ 1000 ->
+ exit(Pid, kill),
+ wait_concurrent_req(Req)
+ end.
+
+wait_message(Pid, _ExpMsgs, Retries) when Retries < 0 ->
+ exit(Pid, kill),
+ exit(wait_message_timeout);
+wait_message(Pid, ExpMsgs, Retries) ->
+ {message_queue_len, Msgs} = process_info(Pid, message_queue_len),
+ if Msgs >= ExpMsgs ->
+ ok;
+ true ->
+ timer:sleep(1),
+ wait_message(Pid, ExpMsgs, Retries - 1)
+ end.