Restore context correctly
diff --git a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
index 9c38f56..c266702 100644
--- a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
+++ b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
@@ -89,7 +89,9 @@
   }
 
   private fun setCurrent(contextData: ThreadContextData) {
-    contextData.map?.let { ContextMap += it } ?: ContextMap.clear()
-    contextData.stack?.let { ContextStack.set(it) } ?: ContextStack.clear()
+    ContextMap.clear()
+    ContextStack.clear()
+    contextData.map?.let { ContextMap += it }
+    contextData.stack?.let { ContextStack.set(it) }
   }
 }
diff --git a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
index f725f26..daab166 100644
--- a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
+++ b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
@@ -38,6 +38,7 @@
     ContextStack.clear()
   }
 
+  @DelicateCoroutinesApi
   @Test
   fun `Context is not passed by default between coroutines`() = runBlocking {
     ContextMap["myKey"] = "myValue"
@@ -49,6 +50,7 @@
     }.join()
   }
 
+  @DelicateCoroutinesApi
   @Test
   fun `Context can be passed between coroutines`() = runBlocking {
     ContextMap["myKey"] = "myValue"
@@ -121,4 +123,16 @@
       }
     }
   }
+
+  @Test
+  fun `Context is restored after a context block is complete`() = runBlocking {
+    assertTrue(ContextMap.empty)
+    assertTrue(ContextStack.empty)
+    withContext(CoroutineThreadContext(ThreadContextData(mapOf("myKey" to "myValue"), listOf("test")))) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("test", ContextStack.peek())
+    }
+    assertTrue(ContextMap.empty)
+    assertTrue(ContextStack.empty)
+  }
 }