Merge pull request #150 from dszoboszlay/bug/unload-race-condition

Fix race condition between meck:unload/1 and calls to the mocked module
diff --git a/src/meck_code_gen.erl b/src/meck_code_gen.erl
index 6010aef..6974244 100644
--- a/src/meck_code_gen.erl
+++ b/src/meck_code_gen.erl
@@ -144,7 +144,7 @@
 -spec exec(CallerPid::pid(), Mod::atom(), Func::atom(), Args::[any()]) ->
         Result::any().
 exec(Pid, Mod, Func, Args) ->
-    case meck_proc:get_result_spec(Mod, Func, Args) of
+    try meck_proc:get_result_spec(Mod, Func, Args) of
         undefined ->
             meck_proc:invalidate(Mod),
             raise(Pid, Mod, Func, Args, error, function_clause);
@@ -160,6 +160,9 @@
             after
                 erase(?CURRENT_CALL)
             end
+    catch
+        error:{not_mocked, Mod} ->
+            apply(Mod, Func, Args)
     end.
 
 -spec handle_exception(CallerPid::pid(), Mod::atom(), Func::atom(),
diff --git a/src/meck_proc.erl b/src/meck_proc.erl
index cb29768..bce3bfe 100644
--- a/src/meck_proc.erl
+++ b/src/meck_proc.erl
@@ -179,7 +179,7 @@
 
 -spec invalidate(Mod::atom()) -> ok.
 invalidate(Mod) ->
-    gen_server(call, Mod, invalidate).
+    gen_server(cast, Mod, invalidate). % was a call; now a cast handled by handle_cast/2
 
 -spec stop(Mod::atom()) -> ok.
 stop(Mod) ->
@@ -264,14 +264,14 @@
     end;
 handle_call(reset, _From, S) ->
     {reply, ok, S#state{history = []}};
-handle_call(invalidate, _From, S) ->
-    {reply, ok, S#state{valid = false}};
 handle_call(validate, _From, S) ->
     {reply, S#state.valid, S};
 handle_call(stop, _From, S) ->
     {stop, normal, ok, S}.
 
 %% @hidden
+handle_cast(invalidate, S) ->
+    {noreply, S#state{valid = false}};
 handle_cast({add_history, HistoryRecord}, S = #state{history = undefined,
                                                      trackers = Trackers}) ->
     UpdTracker = update_trackers(HistoryRecord, Trackers),
diff --git a/test/meck_tests.erl b/test/meck_tests.erl
index fdd517b..6215e41 100644
--- a/test/meck_tests.erl
+++ b/test/meck_tests.erl
@@ -1082,6 +1082,30 @@
     ?assert(not lists:member(self(), Links)),
     ok = meck:unload(mymod).
 
+%% @doc A concurrent process calling into the mocked module while it's
+%% being unloaded gets either the mocked response or the original
+%% response, but won't crash.
+atomic_unload_test() ->
+    ok = meck:new(meck_test_module),
+    ok = meck:expect(meck_test_module, a, fun () -> c end),
+    %% Suspend the meck_proc so both requests below are queued in its
+    %% inbox, in this order (unload first, then the call), before it
+    %% processes either of them.
+    Proc = meck_util:proc_name(meck_test_module),
+    sys:suspend(Proc),
+    StopReq = concurrent_req(
+                Proc,
+                fun () -> ?assertEqual(ok, meck:unload(meck_test_module)) end),
+    SpecReq = concurrent_req(
+                Proc,
+                fun () -> ?assertMatch(V when V =:= a orelse V =:= c,
+                                       meck_test_module:a())
+                end),
+    sys:resume(Proc),
+    %% Reap both requests before asserting, so a failure in one still
+    %% waits for the other and the report shows both exit reasons.
+    ?assertEqual([normal, normal], [wait_concurrent_req(R) || R <- [StopReq, SpecReq]]).
+
 %% @doc Exception is thrown when you run expect on a non-existing (and not yet
 %% mocked) module.
 expect_without_new_test() ->
@@ -1439,3 +1463,49 @@
 assert_called(Mod, Function, Args, Pid, WasCalled) ->
     ?assertEqual(WasCalled, meck:called(Mod, Function, Args, Pid)),
     ?assert(meck:validate(Mod)).
+
+%% @doc Spawn a new process to concurrently call `Fun'. `Fun' is
+%% expected to send a request to the specified process (registered name
+%% or pid), and this function will wait for this message to arrive.
+%% (Therefore the process should be suspended and not consuming its
+%% message queue.) Exits with `noproc' if the process is not alive.
+%% The returned request handle can be passed to {@link
+%% wait_concurrent_req/1} to wait for the concurrent process to terminate.
+concurrent_req(Name, Fun) when is_atom(Name) ->
+    case whereis(Name) of
+        Pid when is_pid(Pid) -> concurrent_req(Pid, Fun);
+        undefined -> exit(noproc)
+    end;
+concurrent_req(Pid, Fun) when is_pid(Pid) ->
+    case process_info(Pid, message_queue_len) of
+        {message_queue_len, Msgs} ->
+            Req = spawn_monitor(Fun),
+            wait_message(Pid, Msgs + 1, 100),
+            Req;
+        undefined -> exit(noproc)
+    end.
+
+%% @doc Block until the process behind a {@link concurrent_req/2}
+%% handle has terminated, and return its exit reason. A process that
+%% is still alive after a second is killed, in which case the reason is
+%% normally `killed'.
+wait_concurrent_req({ReqPid, ReqRef} = Handle) ->
+    receive
+        {'DOWN', ReqRef, process, ReqPid, ExitReason} -> ExitReason
+    after 1000 ->
+        %% Too slow: force termination, then collect its 'DOWN' message.
+        exit(ReqPid, kill),
+        wait_concurrent_req(Handle)
+    end.
+
+%% @doc Poll (1 ms apart) until `Pid' has `ExpMsgs' or more queued messages.
+%% When `Retries' drops below zero, kill `Pid' and exit(wait_message_timeout).
+wait_message(Pid, _ExpMsgs, Retries) when Retries < 0 ->
+    exit(Pid, kill),
+    exit(wait_message_timeout);
+wait_message(Pid, ExpMsgs, Retries) ->
+    {message_queue_len, QueueLen} = process_info(Pid, message_queue_len),
+    case QueueLen >= ExpMsgs of
+        true -> ok;
+        false -> timer:sleep(1), wait_message(Pid, ExpMsgs, Retries - 1)
+    end.