Add convenience functions coroutines thread context
diff --git a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
index c266702..b9376f1 100644
--- a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
+++ b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
@@ -17,7 +17,9 @@
 package org.apache.logging.log4j.kotlin
 
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
+import kotlinx.coroutines.CoroutineScope
 import kotlinx.coroutines.ThreadContextElement
+import kotlinx.coroutines.withContext
 import org.apache.logging.log4j.ThreadContext
 import kotlin.coroutines.AbstractCoroutineContextElement
 import kotlin.coroutines.CoroutineContext
@@ -34,7 +36,12 @@
 data class ThreadContextData(
   val map: Map<String, String>? = ContextMap.view,
   val stack: Collection<String>? = ContextStack.view
-)
+) {
+  operator fun plus(data: ThreadContextData) = ThreadContextData(
+    map = this.map.orEmpty() + data.map.orEmpty(),
+    stack = this.stack.orEmpty() + data.stack.orEmpty(),
+  )
+}
 
 /**
  * Log4j2 [ThreadContext] element for [CoroutineContext].
@@ -59,6 +66,9 @@
  * Use `withContext(CoroutineThreadContext()) { ... }` to capture updated map of Thread keys and values
  * for the specified block of code.
  *
+ * See [loggingContext] and [additionalLoggingContext] for convenience functions that make working with a
+ * [CoroutineThreadContext] simpler.
+ *
  * @param contextData the value of [Thread] context map and context stack.
  * Default value is the copy of the current thread's context map that is acquired via
  * [ContextMap.view] and [ContextStack.view].
@@ -95,3 +105,73 @@
     contextData.stack?.let { ContextStack.set(it) }
   }
 }
+
+/**
+ * Convenience function to obtain a [CoroutineThreadContext] with the given map and stack, which default
+ * to no context. Any existing logging context in scope is ignored.
+ *
+ * Example:
+ *
+ * ```
+ * launch(loggingContext(mapOf("kotlin" to "rocks"))) {
+ *     logger.info { "..." }   // The Thread context contains the mapping here
+ * }
+ * ```
+ */
+fun loggingContext(
+  map: Map<String, String>? = null,
+  stack: Collection<String>? = null,
+): CoroutineThreadContext = CoroutineThreadContext(ThreadContextData(map = map, stack = stack))
+
+/**
+ * Convenience function to obtain a [CoroutineThreadContext] that inherits the current context (if any), plus adds
+ * the context from the given map and stack, which default to nothing.
+ *
+ * Example:
+ *
+ * ```
+ * launch(additionalLoggingContext(mapOf("kotlin" to "rocks"))) {
+ *     logger.info { "..." }   // The Thread context contains the mapping plus whatever context was in scope at launch
+ * }
+ * ```
+ */
+fun additionalLoggingContext(
+  map: Map<String, String>? = null,
+  stack: Collection<String>? = null,
+): CoroutineThreadContext = CoroutineThreadContext(ThreadContextData() + ThreadContextData(map = map, stack = stack))
+
+/**
+ * Run the given block with the provided logging context, which default to no context. Any existing logging context
+ * in scope is ignored.
+ *
+ * Example:
+ *
+ * ```
+ * withLoggingContext(mapOf("kotlin" to "rocks")) {
+ *     logger.info { "..." }   // The Thread context contains the mapping
+ * }
+ * ```
+ */
+suspend fun <R> withLoggingContext(
+  map: Map<String, String>? = null,
+  stack: Collection<String>? = null,
+  block: suspend CoroutineScope.() -> R,
+): R = withContext(loggingContext(map, stack), block)
+
+/**
+ * Run the given block with the provided additional logging context. The given context is added to any existing
+ * logging context in scope.
+ *
+ * Example:
+ *
+ * ```
+ * withAdditionalLoggingContext(mapOf("kotlin" to "rocks")) {
+ *     logger.info { "..." }   // The Thread context contains the mapping plus whatever context was in the scope previously
+ * }
+ * ```
+ */
+suspend fun <R> withAdditionalLoggingContext(
+  map: Map<String, String>? = null,
+  stack: Collection<String>? = null,
+  block: suspend CoroutineScope.() -> R,
+): R = withContext(additionalLoggingContext(map, stack), block)
diff --git a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
index 64f71ab..f6ed109 100644
--- a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
+++ b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
@@ -53,6 +53,10 @@
       assertNull(ContextMap["myKey"])
       assertTrue(ContextStack.empty)
     }.join()
+    GlobalScope.launch(loggingContext()) {
+      assertNull(ContextMap["myKey"])
+      assertTrue(ContextStack.empty)
+    }.join()
   }
 
   @DelicateCoroutinesApi
@@ -65,6 +69,10 @@
       assertEquals("myValue", ContextMap["myKey"])
       assertEquals("test", ContextStack.peek())
     }.join()
+    GlobalScope.launch(additionalLoggingContext()) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("test", ContextStack.peek())
+    }.join()
   }
 
   @Test
@@ -75,11 +83,37 @@
     withContext(CoroutineThreadContext()) {
       ContextMap["myKey"] = "myValue2"
       ContextStack.push("test2")
-      // Scoped launch with inherited MDContext element
+      // Scoped launch with non-inherited MDContext element
       launch(Dispatchers.Default) {
         assertEquals("myValue", ContextMap["myKey"])
         assertEquals("test", ContextStack.peek())
       }
+      // Scoped launch with non-inherited MDContext element
+      launch(Dispatchers.Default + loggingContext()) {
+        assertTrue(ContextMap.empty)
+        assertTrue(ContextStack.empty)
+      }
+      // Scoped launch with non-inherited MDContext element
+      launch(Dispatchers.Default + loggingContext(mapOf("myKey2" to "myValue2"), listOf("test3"))) {
+        assertEquals(null, ContextMap["myKey"])
+        assertEquals("myValue2", ContextMap["myKey2"])
+        assertEquals(listOf("test3"), ContextStack.view.asList())
+      }
+      // Scoped launch with inherited MDContext element
+      launch(Dispatchers.Default + CoroutineThreadContext()) {
+        assertEquals("myValue2", ContextMap["myKey"])
+        assertEquals("test2", ContextStack.peek())
+      }
+      // Scoped launch with inherited plus additional empty MDContext element
+      launch(Dispatchers.Default + additionalLoggingContext()) {
+        assertEquals("myValue2", ContextMap["myKey"])
+        assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+      }
+      launch(Dispatchers.Default + additionalLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test3"))) {
+        assertEquals("myValue2", ContextMap["myKey"])
+        assertEquals("myValue2", ContextMap["myKey2"])
+        assertEquals(listOf("test", "test2", "test3"), ContextStack.view.asList())
+      }
     }
     assertEquals("myValue", ContextMap["myKey"])
     assertEquals("test", ContextStack.peek())
@@ -104,6 +138,10 @@
       assertEquals("myValue", ContextMap["myKey"])
       assertEquals("test", ContextStack.peek())
     }
+    runBlocking(additionalLoggingContext()) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("test", ContextStack.peek())
+    }
   }
 
   @Test
@@ -112,10 +150,18 @@
       assertTrue(ContextMap.empty)
       assertTrue(ContextStack.empty)
     }
+    runBlocking(loggingContext()) {
+      assertTrue(ContextMap.empty)
+      assertTrue(ContextStack.empty)
+    }
+    runBlocking(additionalLoggingContext()) {
+      assertTrue(ContextMap.empty)
+      assertTrue(ContextStack.empty)
+    }
   }
 
   @Test
-  fun `Context with context`() = runBlocking {
+  fun `Context using withContext`() = runBlocking {
     ContextMap["myKey"] = "myValue"
     ContextStack.push("test")
     val mainDispatcher = coroutineContext[ContinuationInterceptor]!!
@@ -127,6 +173,61 @@
         assertEquals("test", ContextStack.peek())
       }
     }
+    withContext(Dispatchers.Default + additionalLoggingContext()) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("test", ContextStack.peek())
+      withContext(mainDispatcher) {
+        assertEquals("myValue", ContextMap["myKey"])
+        assertEquals("test", ContextStack.peek())
+      }
+    }
+    withContext(Dispatchers.Default + additionalLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test2"))) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("myValue2", ContextMap["myKey2"])
+      assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+      withContext(mainDispatcher) {
+        assertEquals("myValue", ContextMap["myKey"])
+        assertEquals("myValue2", ContextMap["myKey2"])
+        assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+      }
+    }
+    withContext(Dispatchers.Default + loggingContext(mapOf("myKey2" to "myValue2"), listOf("test2"))) {
+      assertEquals(null, ContextMap["myKey"])
+      assertEquals("myValue2", ContextMap["myKey2"])
+      assertEquals(listOf("test2"), ContextStack.view.asList())
+      withContext(mainDispatcher) {
+        assertEquals(null, ContextMap["myKey"])
+        assertEquals("myValue2", ContextMap["myKey2"])
+        assertEquals(listOf("test2"), ContextStack.view.asList())
+      }
+    }
+  }
+
+  @Test
+  fun `Context using withLoggingContext`() = runBlocking {
+    ContextMap["myKey"] = "myValue"
+    ContextStack.push("test")
+    val mainDispatcher = coroutineContext[ContinuationInterceptor]!!
+    withAdditionalLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test2")) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("myValue2", ContextMap["myKey2"])
+      assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+      withContext(mainDispatcher) {
+        assertEquals("myValue", ContextMap["myKey"])
+        assertEquals("myValue2", ContextMap["myKey2"])
+        assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+      }
+    }
+    withLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test2")) {
+      assertEquals(null, ContextMap["myKey"])
+      assertEquals("myValue2", ContextMap["myKey2"])
+      assertEquals(listOf("test2"), ContextStack.view.asList())
+      withContext(mainDispatcher) {
+        assertEquals(null, ContextMap["myKey"])
+        assertEquals("myValue2", ContextMap["myKey2"])
+        assertEquals(listOf("test2"), ContextStack.view.asList())
+      }
+    }
   }
 
   @Test
@@ -139,5 +240,19 @@
     }
     assertTrue(ContextMap.empty)
     assertTrue(ContextStack.empty)
+
+    withContext(loggingContext(mapOf("myKey" to "myValue"), listOf("test"))) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("test", ContextStack.peek())
+    }
+    assertTrue(ContextMap.empty)
+    assertTrue(ContextStack.empty)
+
+    withLoggingContext(mapOf("myKey" to "myValue"), listOf("test")) {
+      assertEquals("myValue", ContextMap["myKey"])
+      assertEquals("test", ContextStack.peek())
+    }
+    assertTrue(ContextMap.empty)
+    assertTrue(ContextStack.empty)
   }
 }