Add convenience functions coroutines thread context
diff --git a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
index c266702..b9376f1 100644
--- a/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
+++ b/log4j-api-kotlin/src/main/kotlin/org/apache/logging/log4j/kotlin/CoroutineThreadContext.kt
@@ -17,7 +17,9 @@
package org.apache.logging.log4j.kotlin
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
+import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ThreadContextElement
+import kotlinx.coroutines.withContext
import org.apache.logging.log4j.ThreadContext
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
@@ -34,7 +36,12 @@
data class ThreadContextData(
val map: Map<String, String>? = ContextMap.view,
val stack: Collection<String>? = ContextStack.view
-)
+) {
+ operator fun plus(data: ThreadContextData) = ThreadContextData(
+ map = this.map.orEmpty() + data.map.orEmpty(),
+ stack = this.stack.orEmpty() + data.stack.orEmpty(),
+ )
+}
/**
* Log4j2 [ThreadContext] element for [CoroutineContext].
@@ -59,6 +66,9 @@
* Use `withContext(CoroutineThreadContext()) { ... }` to capture updated map of Thread keys and values
* for the specified block of code.
*
+ * See [loggingContext] and [additionalLoggingContext] for convenience functions that make working with a
+ * [CoroutineThreadContext] simpler.
+ *
* @param contextData the value of [Thread] context map and context stack.
* Default value is the copy of the current thread's context map that is acquired via
* [ContextMap.view] and [ContextStack.view].
@@ -95,3 +105,73 @@
contextData.stack?.let { ContextStack.set(it) }
}
}
+
+/**
+ * Convenience function to obtain a [CoroutineThreadContext] with the given map and stack, which default
+ * to no context. Any existing logging context in scope is ignored.
+ *
+ * Example:
+ *
+ * ```
+ * launch(loggingContext(mapOf("kotlin" to "rocks"))) {
+ * logger.info { "..." } // The Thread context contains the mapping here
+ * }
+ * ```
+ */
+fun loggingContext(
+ map: Map<String, String>? = null,
+ stack: Collection<String>? = null,
+): CoroutineThreadContext = CoroutineThreadContext(ThreadContextData(map = map, stack = stack))
+
+/**
+ * Convenience function to obtain a [CoroutineThreadContext] that inherits the current context (if any), plus adds
+ * the context from the given map and stack, which default to nothing.
+ *
+ * Example:
+ *
+ * ```
+ * launch(additionalLoggingContext(mapOf("kotlin" to "rocks"))) {
+ * logger.info { "..." } // The Thread context contains the mapping plus whatever context was in scope at launch
+ * }
+ * ```
+ */
+fun additionalLoggingContext(
+ map: Map<String, String>? = null,
+ stack: Collection<String>? = null,
+): CoroutineThreadContext = CoroutineThreadContext(ThreadContextData() + ThreadContextData(map = map, stack = stack))
+
+/**
+ * Run the given block with the provided logging context, which default to no context. Any existing logging context
+ * in scope is ignored.
+ *
+ * Example:
+ *
+ * ```
+ * withLoggingContext(mapOf("kotlin" to "rocks")) {
+ * logger.info { "..." } // The Thread context contains the mapping
+ * }
+ * ```
+ */
+suspend fun <R> withLoggingContext(
+ map: Map<String, String>? = null,
+ stack: Collection<String>? = null,
+ block: suspend CoroutineScope.() -> R,
+): R = withContext(loggingContext(map, stack), block)
+
+/**
+ * Run the given block with the provided additional logging context. The given context is added to any existing
+ * logging context in scope.
+ *
+ * Example:
+ *
+ * ```
+ * withAdditionalLoggingContext(mapOf("kotlin" to "rocks")) {
+ * logger.info { "..." } // The Thread context contains the mapping plus whatever context was in the scope previously
+ * }
+ * ```
+ */
+suspend fun <R> withAdditionalLoggingContext(
+ map: Map<String, String>? = null,
+ stack: Collection<String>? = null,
+ block: suspend CoroutineScope.() -> R,
+): R = withContext(additionalLoggingContext(map, stack), block)
diff --git a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
index 64f71ab..f6ed109 100644
--- a/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
+++ b/log4j-api-kotlin/src/test/kotlin/org.apache.logging.log4j.kotlin/ThreadContextTest.kt
@@ -53,6 +53,10 @@
assertNull(ContextMap["myKey"])
assertTrue(ContextStack.empty)
}.join()
+ GlobalScope.launch(loggingContext()) {
+ assertNull(ContextMap["myKey"])
+ assertTrue(ContextStack.empty)
+ }.join()
}
@DelicateCoroutinesApi
@@ -65,6 +69,10 @@
assertEquals("myValue", ContextMap["myKey"])
assertEquals("test", ContextStack.peek())
}.join()
+ GlobalScope.launch(additionalLoggingContext()) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("test", ContextStack.peek())
+ }.join()
}
@Test
@@ -75,11 +83,37 @@
withContext(CoroutineThreadContext()) {
ContextMap["myKey"] = "myValue2"
ContextStack.push("test2")
- // Scoped launch with inherited MDContext element
+ // Scoped launch with non-inherited MDContext element
launch(Dispatchers.Default) {
assertEquals("myValue", ContextMap["myKey"])
assertEquals("test", ContextStack.peek())
}
+ // Scoped launch with non-inherited MDContext element
+ launch(Dispatchers.Default + loggingContext()) {
+ assertTrue(ContextMap.empty)
+ assertTrue(ContextStack.empty)
+ }
+ // Scoped launch with non-inherited MDContext element
+ launch(Dispatchers.Default + loggingContext(mapOf("myKey2" to "myValue2"), listOf("test3"))) {
+ assertEquals(null, ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test3"), ContextStack.view.asList())
+ }
+ // Scoped launch with inherited MDContext element
+ launch(Dispatchers.Default + CoroutineThreadContext()) {
+ assertEquals("myValue2", ContextMap["myKey"])
+ assertEquals("test2", ContextStack.peek())
+ }
+ // Scoped launch with inherited plus additional empty MDContext element
+ launch(Dispatchers.Default + additionalLoggingContext()) {
+ assertEquals("myValue2", ContextMap["myKey"])
+ assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+ }
+ launch(Dispatchers.Default + additionalLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test3"))) {
+ assertEquals("myValue2", ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test", "test2", "test3"), ContextStack.view.asList())
+ }
}
assertEquals("myValue", ContextMap["myKey"])
assertEquals("test", ContextStack.peek())
@@ -104,6 +138,10 @@
assertEquals("myValue", ContextMap["myKey"])
assertEquals("test", ContextStack.peek())
}
+ runBlocking(additionalLoggingContext()) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("test", ContextStack.peek())
+ }
}
@Test
@@ -112,10 +150,18 @@
assertTrue(ContextMap.empty)
assertTrue(ContextStack.empty)
}
+ runBlocking(loggingContext()) {
+ assertTrue(ContextMap.empty)
+ assertTrue(ContextStack.empty)
+ }
+ runBlocking(additionalLoggingContext()) {
+ assertTrue(ContextMap.empty)
+ assertTrue(ContextStack.empty)
+ }
}
@Test
- fun `Context with context`() = runBlocking {
+ fun `Context using withContext`() = runBlocking {
ContextMap["myKey"] = "myValue"
ContextStack.push("test")
val mainDispatcher = coroutineContext[ContinuationInterceptor]!!
@@ -127,6 +173,61 @@
assertEquals("test", ContextStack.peek())
}
}
+ withContext(Dispatchers.Default + additionalLoggingContext()) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("test", ContextStack.peek())
+ withContext(mainDispatcher) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("test", ContextStack.peek())
+ }
+ }
+ withContext(Dispatchers.Default + additionalLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test2"))) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+ withContext(mainDispatcher) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+ }
+ }
+ withContext(Dispatchers.Default + loggingContext(mapOf("myKey2" to "myValue2"), listOf("test2"))) {
+ assertEquals(null, ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test2"), ContextStack.view.asList())
+ withContext(mainDispatcher) {
+ assertEquals(null, ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test2"), ContextStack.view.asList())
+ }
+ }
+ }
+
+ @Test
+ fun `Context using withLoggingContext`() = runBlocking {
+ ContextMap["myKey"] = "myValue"
+ ContextStack.push("test")
+ val mainDispatcher = coroutineContext[ContinuationInterceptor]!!
+ withAdditionalLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test2")) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+ withContext(mainDispatcher) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test", "test2"), ContextStack.view.asList())
+ }
+ }
+ withLoggingContext(mapOf("myKey2" to "myValue2"), listOf("test2")) {
+ assertEquals(null, ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test2"), ContextStack.view.asList())
+ withContext(mainDispatcher) {
+ assertEquals(null, ContextMap["myKey"])
+ assertEquals("myValue2", ContextMap["myKey2"])
+ assertEquals(listOf("test2"), ContextStack.view.asList())
+ }
+ }
}
@Test
@@ -139,5 +240,19 @@
}
assertTrue(ContextMap.empty)
assertTrue(ContextStack.empty)
+
+ withContext(loggingContext(mapOf("myKey" to "myValue"), listOf("test"))) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("test", ContextStack.peek())
+ }
+ assertTrue(ContextMap.empty)
+ assertTrue(ContextStack.empty)
+
+ withLoggingContext(mapOf("myKey" to "myValue"), listOf("test")) {
+ assertEquals("myValue", ContextMap["myKey"])
+ assertEquals("test", ContextStack.peek())
+ }
+ assertTrue(ContextMap.empty)
+ assertTrue(ContextStack.empty)
}
}