[YUNIKORN-1709] Add event streaming logic (#533)
Closes: #533
Signed-off-by: Peter Bacsko <pbacsko@cloudera.com>
diff --git a/pkg/events/event_ringbuffer.go b/pkg/events/event_ringbuffer.go
index 479ea70..3ae9533 100644
--- a/pkg/events/event_ringbuffer.go
+++ b/pkg/events/event_ringbuffer.go
@@ -70,7 +70,25 @@
e.id++
}
-// GetEventsFromID returns "count" number of event records from "id" if possible. The id can be determined from
+// GetRecentEvents returns at most "count" of the most recent records held in the ring buffer.
+// A "count" which exceeds the number of stored records is allowed.
+func (e *eventRingBuffer) GetRecentEvents(count uint64) []*si.EventRecord {
+	e.RLock()
+	defer e.RUnlock()
+
+	// Start from ID 0 unless there are at least "count" IDs up to the newest one; the guard
+	// also keeps the unsigned subtraction below from wrapping around.
+	var firstID uint64
+	if newestID := e.getLastEventID(); newestID >= count {
+		firstID = newestID - count + 1
+	}
+
+	// the lowest/highest ID return values are of no interest here
+	events, _, _ := e.getEventsFromID(firstID, count)
+	return events
+}
+
+// GetEventsFromID returns "count" number of event records from id if possible. The id can be determined from
// the first call of the method - if it returns nothing because the id is not in the buffer, the lowest valid
// identifier is returned which can be used to get the first batch.
// If the caller does not want to pose limit on the number of events returned, "count" must be set to a high
@@ -78,6 +96,14 @@
func (e *eventRingBuffer) GetEventsFromID(id uint64, count uint64) ([]*si.EventRecord, uint64, uint64) {
e.RLock()
defer e.RUnlock()
+
+ return e.getEventsFromID(id, count)
+}
+
+// getEventsFromID is the unlocked version of GetEventsFromID.
+//
+// The caller must already hold the lock of the ring buffer, as GetRecentEvents and
+// GetEventsFromID do. The lock is intentionally not taken again here: read locks must not be
+// acquired recursively, a nested RLock() can deadlock as soon as a writer requests the lock
+// between the two acquisitions.
+func (e *eventRingBuffer) getEventsFromID(id uint64, count uint64) ([]*si.EventRecord, uint64, uint64) {
lowest := e.getLowestID()
pos, idFound := e.id2pos(id)
diff --git a/pkg/events/event_ringbuffer_test.go b/pkg/events/event_ringbuffer_test.go
index 6b390e6..d25fc86 100644
--- a/pkg/events/event_ringbuffer_test.go
+++ b/pkg/events/event_ringbuffer_test.go
@@ -277,6 +277,39 @@
assert.Equal(t, uint64(7), ringBuffer.resizeOffset)
}
+// TestGetRecentEvents verifies GetRecentEvents on an empty buffer and on a buffer populated with
+// five records, with "count" being less than, equal to and greater than the number of records.
+func TestGetRecentEvents(t *testing.T) {
+	ringBuffer := newEventRingBuffer(10)
+
+	// empty
+	assert.Equal(t, 0, len(ringBuffer.GetRecentEvents(5)))
+
+	populate(ringBuffer, 5)
+
+	testCases := []struct {
+		name       string
+		count      uint64
+		timestamps []int64
+	}{
+		{"count < elements", 2, []int64{3, 4}},
+		{"count = elements", 5, []int64{0, 1, 2, 3, 4}},
+		{"count > elements", 15, []int64{0, 1, 2, 3, 4}},
+	}
+	for _, tc := range testCases {
+		events := ringBuffer.GetRecentEvents(tc.count)
+		assert.Equal(t, len(tc.timestamps), len(events), tc.name)
+		for i, expected := range tc.timestamps {
+			assert.Equal(t, expected, events[i].TimestampNano, tc.name)
+		}
+	}
+}
+
func populate(buffer *eventRingBuffer, count int) {
for i := 0; i < count; i++ {
buffer.Add(&si.EventRecord{
diff --git a/pkg/events/event_streaming.go b/pkg/events/event_streaming.go
new file mode 100644
index 0000000..4f7b9d2
--- /dev/null
+++ b/pkg/events/event_streaming.go
@@ -0,0 +1,179 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package events
+
+import (
+ "sync"
+ "time"
+
+ "go.uber.org/zap"
+
+ "github.com/apache/yunikorn-core/pkg/log"
+ "github.com/apache/yunikorn-scheduler-interface/lib/go/si"
+)
+
+const defaultChannelBufSize = 1000
+
+// EventStreaming implements the event streaming logic.
+// New events are immediately forwarded to all active consumers.
+type EventStreaming struct {
+ buffer *eventRingBuffer // ring buffer queried for the historical events of a new stream
+ stopCh chan struct{} // closed by Close(); watched by the forwarding goroutine of every stream
+ eventStreams map[*EventStream]eventConsumerDetails // registered consumers, keyed by the handle returned to the client
+ sync.Mutex // taken by the methods that read or modify eventStreams
+}
+
+// eventConsumerDetails holds the per-consumer state of an event stream.
+type eventConsumerDetails struct {
+ local chan *si.EventRecord // written by PublishEvent(), drained by the forwarding goroutine of the stream
+ consumer chan<- *si.EventRecord // send side of the channel which the client reads through EventStream.Events
+ stopCh chan struct{} // closed when the stream is removed
+ name string // arbitrary consumer name, used for logging
+ createdAt time.Time // registration time, used for logging
+}
+
+// EventStream is the handle type returned to the client that wants to capture the stream of events.
+// The same handle is passed to RemoveEventStream when the client is finished.
+type EventStream struct {
+ Events <-chan *si.EventRecord // receive-only channel delivering the event records
+}
+
+// PublishEvent hands over an event to every registered event stream consumer.
+//
+// The event is not written directly to the channel read by the client: it is sent to the "local"
+// channel of the consumer first, from where it is bridged to the "consumer" channel. This ensures
+// the proper ordering of existing and new events.
+//
+// A "local" channel which is already full means that the consumer side has not processed the
+// events at an appropriate pace. Such a consumer is removed and the related channels are closed.
+func (e *EventStreaming) PublishEvent(event *si.EventRecord) {
+	e.Lock()
+	defer e.Unlock()
+
+	for stream, details := range e.eventStreams {
+		if len(details.local) != defaultChannelBufSize {
+			details.local <- event
+			continue
+		}
+
+		log.Log(log.Events).Warn("Listener buffer full due to potentially slow consumer, removing it")
+		e.removeEventStream(stream)
+	}
+}
+
+// CreateEventStream sets up event streaming for a consumer. The returned EventStream object
+// contains a channel that can be used for reading.
+//
+// When a consumer is finished, it must call RemoveEventStream to free up resources. The channel
+// is closed once the stream is removed or the streaming is stopped with Close().
+//
+// Consumers have an arbitrary name for logging purposes. The "count" parameter defines the number
+// of maximum historical events from the ring buffer. "0" is a valid value and means no past events.
+func (e *EventStreaming) CreateEventStream(name string, count uint64) *EventStream {
+	consumer := make(chan *si.EventRecord, defaultChannelBufSize)
+	stream := &EventStream{
+		Events: consumer,
+	}
+	local := make(chan *si.EventRecord, defaultChannelBufSize)
+	stop := make(chan struct{})
+	e.createEventStreamInternal(stream, local, consumer, stop, name, count)
+	history := e.buffer.GetRecentEvents(count)
+
+	go func(consumer chan<- *si.EventRecord, local <-chan *si.EventRecord, stop <-chan struct{}) {
+		// this goroutine is the only sender on "consumer", so it closes it on every exit path
+		defer close(consumer)
+
+		// forward delivers a single event unless the streaming is stopped first. A plain
+		// blocking send would leak this goroutine if the client stops reading a full channel.
+		forward := func(event *si.EventRecord) bool {
+			select {
+			case consumer <- event:
+				return true
+			case <-stop:
+			case <-e.stopCh:
+			}
+			return false
+		}
+
+		// Store the refs of historical events; it's possible that some events are added to the
+		// ring buffer and also to "local" channel.
+		// It is because we use two separate locks, so event updates are not atomic.
+		// Example: an event has been just added to the ring buffer (before createEventStreamInternal()),
+		// and execution is about to enter PublishEvent(); at this point we have an updated "eventStreams"
+		// map, so "local" will also contain the new event.
+		seen := make(map[*si.EventRecord]bool, len(history))
+		for _, event := range history {
+			if !forward(event) {
+				return
+			}
+			seen[event] = true
+		}
+		for {
+			select {
+			case <-e.stopCh:
+				return
+			case <-stop:
+				return
+			case event, ok := <-local:
+				if !ok {
+					// "local" is closed when the stream is removed; the nil value yielded by
+					// a closed channel must not be forwarded as if it was an event
+					return
+				}
+				if seen[event] {
+					continue
+				}
+				// since events are processed in a single goroutine, doubling is no longer
+				// possible at this point (lookups in a nil map simply return false)
+				seen = nil
+				if !forward(event) {
+					return
+				}
+			}
+		}
+	}(consumer, local, stop)
+
+	log.Log(log.Events).Info("Created event stream", zap.String("consumer name", name))
+	return stream
+}
+
+// createEventStreamInternal registers the channels and the metadata of a new consumer under the
+// given stream handle; the time of the registration is stored as the creation time.
+// The "count" argument is not used by this method.
+func (e *EventStreaming) createEventStreamInternal(stream *EventStream,
+	local chan *si.EventRecord,
+	consumer chan *si.EventRecord,
+	stop chan struct{},
+	name string,
+	count uint64) {
+	// the consumer map is also accessed by PublishEvent() and by the removal methods
+	e.Lock()
+	defer e.Unlock()
+
+	details := eventConsumerDetails{
+		createdAt: time.Now(),
+		name:      name,
+		stopCh:    stop,
+		consumer:  consumer,
+		local:     local,
+	}
+	e.eventStreams[stream] = details
+}
+
+// RemoveEventStream stops the streaming for a given consumer. Must be called to avoid resource leaks.
+// The lock of the streaming object is held during the removal. A handle which is not registered
+// (for example because it has already been removed) is ignored by removeEventStream().
+func (e *EventStreaming) RemoveEventStream(consumer *EventStream) {
+ e.Lock()
+ defer e.Unlock()
+
+ e.removeEventStream(consumer)
+}
+
+// removeEventStream closes the stop channel and the "local" channel of a registered consumer,
+// then deletes it from the consumer map. Nothing happens if the handle is not registered.
+// It does not lock: the callers are expected to hold the lock of the streaming object.
+func (e *EventStreaming) removeEventStream(consumer *EventStream) {
+	details, registered := e.eventStreams[consumer]
+	if !registered {
+		return
+	}
+
+	log.Log(log.Events).Info("Removing event stream consumer", zap.String("name", details.name),
+		zap.Time("creation time", details.createdAt))
+	close(details.stopCh)
+	close(details.local)
+	delete(e.eventStreams, consumer)
+}
+
+// Close stops event streaming completely by closing the shared stop channel, which is watched by
+// the forwarding goroutine of every stream.
+// NOTE(review): registered consumers are not deleted from the consumer map here, and a second
+// call would close an already closed channel, which panics.
+func (e *EventStreaming) Close() {
+ close(e.stopCh)
+}
+
+// NewEventStreaming creates the event streaming infrastructure on top of the given ring buffer.
+// The returned object has no registered consumers.
+func NewEventStreaming(eventBuffer *eventRingBuffer) *EventStreaming {
+	streaming := &EventStreaming{
+		eventStreams: make(map[*EventStream]eventConsumerDetails),
+		stopCh:       make(chan struct{}),
+	}
+	streaming.buffer = eventBuffer
+	return streaming
+}
diff --git a/pkg/events/event_streaming_test.go b/pkg/events/event_streaming_test.go
new file mode 100644
index 0000000..8afc770
--- /dev/null
+++ b/pkg/events/event_streaming_test.go
@@ -0,0 +1,145 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package events
+
+import (
+ "testing"
+ "time"
+
+ "gotest.tools/v3/assert"
+
+ "github.com/apache/yunikorn-scheduler-interface/lib/go/si"
+)
+
+var defaultCount = uint64(10000)
+
+// TestEventStreaming_WithoutHistory checks that an event published after the creation of a stream
+// reaches the consumer when the ring buffer is empty, and that removing the stream unregisters
+// the consumer.
+func TestEventStreaming_WithoutHistory(t *testing.T) {
+	buffer := newEventRingBuffer(10)
+	streaming := NewEventStreaming(buffer)
+	defer streaming.Close()
+	es := streaming.CreateEventStream("test", defaultCount)
+
+	sent := &si.EventRecord{
+		Message: "testMessage",
+	}
+	streaming.PublishEvent(sent)
+	received := receive(t, es.Events)
+	// The channels must be inspected while the consumer is still registered: after the removal
+	// the map lookup yields a zero value with nil channels, so the checks could never fail.
+	assert.Equal(t, 0, len(streaming.eventStreams[es].local))
+	assert.Equal(t, 0, len(streaming.eventStreams[es].consumer))
+	assert.Equal(t, sent.Message, received.Message)
+	streaming.RemoveEventStream(es)
+	assert.Equal(t, 0, len(streaming.eventStreams))
+}
+
+func TestEventStreaming_WithHistory(t *testing.T) {
+ buffer := newEventRingBuffer(10)
+ streaming := NewEventStreaming(buffer)
+ defer streaming.Close()
+
+ buffer.Add(&si.EventRecord{TimestampNano: 1})
+ buffer.Add(&si.EventRecord{TimestampNano: 5})
+ buffer.Add(&si.EventRecord{TimestampNano: 6})
+ buffer.Add(&si.EventRecord{TimestampNano: 9})
+ es := streaming.CreateEventStream("test", defaultCount)
+
+ streaming.PublishEvent(&si.EventRecord{TimestampNano: 10})
+
+ received1 := receive(t, es.Events)
+ received2 := receive(t, es.Events)
+ received3 := receive(t, es.Events)
+ received4 := receive(t, es.Events)
+ received5 := receive(t, es.Events)
+ assert.Equal(t, 0, len(streaming.eventStreams[es].local))
+ assert.Equal(t, 0, len(streaming.eventStreams[es].consumer))
+ assert.Equal(t, int64(1), received1.TimestampNano)
+ assert.Equal(t, int64(5), received2.TimestampNano)
+ assert.Equal(t, int64(6), received3.TimestampNano)
+ assert.Equal(t, int64(9), received4.TimestampNano)
+ assert.Equal(t, int64(10), received5.TimestampNano)
+ streaming.RemoveEventStream(es)
+ assert.Equal(t, 0, len(streaming.eventStreams))
+}
+
+func TestEventStreaming_WithHistoryCount(t *testing.T) {
+ buffer := newEventRingBuffer(10)
+ streaming := NewEventStreaming(buffer)
+ defer streaming.Close()
+
+ buffer.Add(&si.EventRecord{TimestampNano: 1})
+ buffer.Add(&si.EventRecord{TimestampNano: 5})
+ buffer.Add(&si.EventRecord{TimestampNano: 6})
+ buffer.Add(&si.EventRecord{TimestampNano: 9})
+ es := streaming.CreateEventStream("test", 2)
+
+ streaming.PublishEvent(&si.EventRecord{TimestampNano: 10})
+
+ received1 := receive(t, es.Events)
+ received2 := receive(t, es.Events)
+ received3 := receive(t, es.Events)
+ assert.Equal(t, 0, len(streaming.eventStreams[es].local))
+ assert.Equal(t, 0, len(streaming.eventStreams[es].consumer))
+ assert.Equal(t, int64(6), received1.TimestampNano)
+ assert.Equal(t, int64(9), received2.TimestampNano)
+ assert.Equal(t, int64(10), received3.TimestampNano)
+}
+
+func TestEventStreaming_TwoConsumers(t *testing.T) {
+ buffer := newEventRingBuffer(10)
+ streaming := NewEventStreaming(buffer)
+ defer streaming.Close()
+
+ es1 := streaming.CreateEventStream("stream1", defaultCount)
+ es2 := streaming.CreateEventStream("stream2", defaultCount)
+ for i := 0; i < 5; i++ {
+ streaming.PublishEvent(&si.EventRecord{TimestampNano: int64(i)})
+ }
+
+ for i := 0; i < 5; i++ {
+ assert.Equal(t, int64(i), receive(t, es1.Events).TimestampNano)
+ assert.Equal(t, int64(i), receive(t, es2.Events).TimestampNano)
+ }
+ assert.Equal(t, 0, len(streaming.eventStreams[es1].local))
+ assert.Equal(t, 0, len(streaming.eventStreams[es1].consumer))
+ assert.Equal(t, 0, len(streaming.eventStreams[es2].local))
+ assert.Equal(t, 0, len(streaming.eventStreams[es2].consumer))
+}
+
+// TestEventStreaming_SlowConsumer checks that a consumer which never reads its channel is removed.
+// Both the "consumer" and the "local" channel buffer defaultChannelBufSize (1000) events, so
+// publishing 2500 events without reading anything is enough to fill up "local".
+func TestEventStreaming_SlowConsumer(t *testing.T) {
+ // simulating a slow event consumer by ignoring events
+ buffer := newEventRingBuffer(10)
+ streaming := NewEventStreaming(buffer)
+ defer streaming.Close()
+ streaming.CreateEventStream("test", 10000)
+
+ for i := 0; i < 2500; i++ {
+ streaming.PublishEvent(&si.EventRecord{TimestampNano: int64(i)})
+ }
+
+ assert.Equal(t, 0, len(streaming.eventStreams))
+}
+
+// receive reads a single event from the input channel and returns it. The test is failed if
+// nothing can be read within one second.
+func receive(t *testing.T, input <-chan *si.EventRecord) *si.EventRecord {
+	// report a failure at the line of the calling test, not inside this helper
+	t.Helper()
+	select {
+	case event := <-input:
+		return event
+	case <-time.After(time.Second):
+		t.Fatal("receive failed")
+		return nil
+	}
+}
diff --git a/pkg/events/event_system.go b/pkg/events/event_system.go
index ef8f7c0..e89f628 100644
--- a/pkg/events/event_system.go
+++ b/pkg/events/event_system.go
@@ -38,18 +38,50 @@
var ev EventSystem
type EventSystem interface {
+ // AddEvent adds an event record to the event system for processing:
+ // 1. It is added to a slice from where it is periodically read by the shim publisher.
+ // 2. It is added to an internal ring buffer so that clients can retrieve the event history.
+ // 3. Streaming clients are updated.
AddEvent(event *si.EventRecord)
+
+ // StartService starts the event system.
+ // This method does not block. Events are processed on a separate goroutine.
StartService()
+
+ // Stop stops the event system.
Stop()
+
+ // IsEventTrackingEnabled whether history tracking is currently enabled or not.
IsEventTrackingEnabled() bool
- GetEventsFromID(uint64, uint64) ([]*si.EventRecord, uint64, uint64)
+
+ // GetEventsFromID retrieves "count" number of elements from the history buffer from "id". Every
+ // event has a unique ID inside the ring buffer.
+ // If "id" is not in the buffer, then no record is returned, but the currently available range
+ // [low..high] is set.
+ GetEventsFromID(id, count uint64) ([]*si.EventRecord, uint64, uint64)
+
+ // CreateEventStream creates an event stream (channel) for a consumer.
+ // The "name" argument is an arbitrary string for a consumer, which is used for logging. It does not need to be unique.
+ // The "count" argument defines how many historical elements should be returned on the stream. Zero is a valid value for "count".
+ // The returned type contains a read-only channel which is updated as soon as there is a new event record.
+ // It is also used as a handle to stop the streaming.
+ // Consumers must read the channel and process the event objects as soon as they can to avoid
+ // events piling up inside the channel buffers.
+ CreateEventStream(name string, count uint64) *EventStream
+
+ // RemoveStream stops streaming for a given consumer.
+ // Consumers that no longer wish to be updated (e.g., a remote client
+ // disconnected) *must* call this method to gracefully stop the streaming.
+ RemoveStream(*EventStream)
}
+// EventSystemImpl main implementation of the event system which is used for history tracking.
type EventSystemImpl struct {
eventSystemId string
Store *EventStore // storing eventChannel
publisher *EventPublisher
eventBuffer *eventRingBuffer
+ streaming *EventStreaming
channel chan *si.EventRecord // channelling input eventChannel
stop chan bool // whether the service is stopped
@@ -62,10 +94,22 @@
sync.RWMutex
}
+// CreateEventStream creates an event stream. See the interface for details.
+// The call is delegated to the EventStreaming instance of the event system.
+func (ec *EventSystemImpl) CreateEventStream(name string, count uint64) *EventStream {
+ return ec.streaming.CreateEventStream(name, count)
+}
+
+// RemoveStream gracefully terminates the event streaming for a consumer. See the interface for details.
+// The call is delegated to RemoveEventStream of the EventStreaming instance of the event system.
+func (ec *EventSystemImpl) RemoveStream(consumer *EventStream) {
+ ec.streaming.RemoveEventStream(consumer)
+}
+
+// GetEventsFromID retrieves historical elements. See the interface for details.
func (ec *EventSystemImpl) GetEventsFromID(id, count uint64) ([]*si.EventRecord, uint64, uint64) {
return ec.eventBuffer.GetEventsFromID(id, count)
}
+// GetEventSystem returns the event system instance. Initialization happens during the first call.
func GetEventSystem() EventSystem {
once.Do(func() {
Init()
@@ -73,42 +117,51 @@
return ev
}
+// IsEventTrackingEnabled whether history tracking is currently enabled or not.
func (ec *EventSystemImpl) IsEventTrackingEnabled() bool {
ec.RLock()
defer ec.RUnlock()
return ec.trackingEnabled
}
+// GetRequestCapacity returns the capacity of an intermediate storage which is used by the shim publisher.
func (ec *EventSystemImpl) GetRequestCapacity() int {
ec.RLock()
defer ec.RUnlock()
return ec.requestCapacity
}
+// GetRingBufferCapacity returns the capacity of the buffer which stores historical elements.
func (ec *EventSystemImpl) GetRingBufferCapacity() uint64 {
ec.RLock()
defer ec.RUnlock()
return ec.ringBufferCapacity
}
-// VisibleForTesting
+// Init Initializes the event system.
+// Only exported for testing.
func Init() {
store := newEventStore()
+ buffer := newEventRingBuffer(defaultRingBufferSize)
ev = &EventSystemImpl{
Store: store,
channel: make(chan *si.EventRecord, defaultEventChannelSize),
stop: make(chan bool),
stopped: false,
publisher: CreateShimPublisher(store),
- eventBuffer: newEventRingBuffer(defaultRingBufferSize),
+ eventBuffer: buffer,
eventSystemId: fmt.Sprintf("event-system-%d", time.Now().Unix()),
+ streaming: NewEventStreaming(buffer),
}
}
+// StartService starts the event processing in the background. See the interface for details.
func (ec *EventSystemImpl) StartService() {
ec.StartServiceWithPublisher(true)
}
+// StartServiceWithPublisher starts the event processing background routines.
+// Only exported for testing.
func (ec *EventSystemImpl) StartServiceWithPublisher(withPublisher bool) {
ec.Lock()
defer ec.Unlock()
@@ -134,6 +187,7 @@
if event != nil {
ec.Store.Store(event)
ec.eventBuffer.Add(event)
+ ec.streaming.PublishEvent(event)
metrics.GetEventMetrics().IncEventsProcessed()
}
}
@@ -144,6 +198,7 @@
}
}
+// Stop stops the event system.
func (ec *EventSystemImpl) Stop() {
ec.Lock()
defer ec.Unlock()
@@ -163,6 +218,7 @@
ec.stopped = true
}
+// AddEvent adds an event record to the event system. See the interface for details.
func (ec *EventSystemImpl) AddEvent(event *si.EventRecord) {
metrics.GetEventMetrics().IncEventsCreated()
select {
@@ -192,11 +248,21 @@
return ec.readIsTrackingEnabled() != ec.trackingEnabled
}
+// Restart restarts the event system, used during config update.
func (ec *EventSystemImpl) Restart() {
ec.Stop()
ec.StartServiceWithPublisher(true)
}
+// CloseAllStreams removes every registered event stream consumer while holding the lock of the
+// streaming object.
+// VisibleForTesting
+func (ec *EventSystemImpl) CloseAllStreams() {
+ ec.streaming.Lock()
+ defer ec.streaming.Unlock()
+ for consumer := range ec.streaming.eventStreams {
+ ec.streaming.removeEventStream(consumer)
+ }
+}
+
func (ec *EventSystemImpl) reloadConfig() {
ec.updateRequestCapacity()
diff --git a/pkg/scheduler/objects/common_test.go b/pkg/scheduler/objects/common_test.go
index 72f0e80..290a32f 100644
--- a/pkg/scheduler/objects/common_test.go
+++ b/pkg/scheduler/objects/common_test.go
@@ -20,6 +20,7 @@
"github.com/google/btree"
"github.com/apache/yunikorn-core/pkg/common/resources"
+ "github.com/apache/yunikorn-core/pkg/events"
"github.com/apache/yunikorn-scheduler-interface/lib/go/si"
)
@@ -28,6 +29,13 @@
enabled bool
}
+// CreateEventStream is a no-op in the mock: both arguments are ignored and nil is returned.
+func (m *EventSystemMock) CreateEventStream(_ string, _ uint64) *events.EventStream {
+ return nil
+}
+
+// RemoveStream is a no-op in the mock.
+func (m *EventSystemMock) RemoveStream(_ *events.EventStream) {
+}
+
func (m *EventSystemMock) AddEvent(event *si.EventRecord) {
m.events = append(m.events, event)
}
diff --git a/pkg/webservice/handler_mock_test.go b/pkg/webservice/handler_mock_test.go
index 9759afa..439efe4 100644
--- a/pkg/webservice/handler_mock_test.go
+++ b/pkg/webservice/handler_mock_test.go
@@ -19,6 +19,7 @@
import (
"net/http"
+ "time"
)
// InternalMetricHistory needs resetting between tests
@@ -49,3 +50,7 @@
func (trw *MockResponseWriter) WriteHeader(statusCode int) {
trw.statusCode = statusCode
}
+
+// SetWriteDeadline ignores the deadline and always succeeds.
+func (trw *MockResponseWriter) SetWriteDeadline(deadline time.Time) error {
+ return nil
+}
diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go
index 69c781f..d01f12c 100644
--- a/pkg/webservice/handlers.go
+++ b/pkg/webservice/handlers.go
@@ -28,6 +28,7 @@
"sort"
"strconv"
"strings"
+ "time"
"github.com/julienschmidt/httprouter"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -1084,3 +1085,82 @@
buildJSONErrorResponse(w, err.Error(), http.StatusInternalServerError)
}
}
+
+// getStream streams event records to the client as JSON objects over a persistent HTTP connection.
+//
+// The optional "count" query parameter is the maximum number of historical events to send before
+// the new ones; it defaults to zero. An error response is sent if event tracking is disabled, the
+// writer cannot flush, "count" is not an unsigned integer or the connection deadlines cannot be set.
+//
+// The handler returns when the request context is done, the event system closes the stream or
+// an event cannot be written. The stream is unregistered from the event system in all cases.
+func getStream(w http.ResponseWriter, r *http.Request) {
+	writeHeaders(w)
+	eventSystem := events.GetEventSystem()
+	if !eventSystem.IsEventTrackingEnabled() {
+		buildJSONErrorResponse(w, "Event tracking is disabled", http.StatusInternalServerError)
+		return
+	}
+
+	f, ok := w.(http.Flusher)
+	if !ok {
+		buildJSONErrorResponse(w, "Writer does not implement http.Flusher", http.StatusInternalServerError)
+		return
+	}
+
+	var count uint64
+	if countStr := r.URL.Query().Get("count"); countStr != "" {
+		var err error
+		count, err = strconv.ParseUint(countStr, 10, 64)
+		if err != nil {
+			buildJSONErrorResponse(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+	}
+
+	rc := http.NewResponseController(w)
+	// make sure both deadlines can be set
+	if err := rc.SetWriteDeadline(time.Time{}); err != nil {
+		log.Log(log.REST).Error("Cannot set write deadline", zap.Error(err))
+		buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError)
+		return
+	}
+	if err := rc.SetReadDeadline(time.Time{}); err != nil {
+		log.Log(log.REST).Error("Cannot set read deadline", zap.Error(err))
+		buildJSONErrorResponse(w, fmt.Sprintf("Cannot set read deadline: %v", err), http.StatusInternalServerError)
+		return
+	}
+	enc := json.NewEncoder(w)
+	stream := eventSystem.CreateEventStream(r.Host, count)
+	// Unregister the stream on every exit path, including the one where the producer closed the
+	// channel. Removing a stream which is no longer registered is a no-op.
+	defer eventSystem.RemoveStream(stream)
+
+	// Reading events in an infinite loop until either the client disconnects or Yunikorn closes the channel.
+	// This results in a persistent HTTP connection where the message body is never closed.
+	// Write deadline is adjusted before sending data to the client.
+	for {
+		select {
+		case <-r.Context().Done():
+			log.Log(log.REST).Info("Connection closed for event stream client",
+				zap.String("host", r.Host))
+			return
+		case e, ok := <-stream.Events:
+			if err := rc.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
+				// should not fail at this point
+				log.Log(log.REST).Error("Cannot set write deadline", zap.Error(err))
+				buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError)
+				return
+			}
+
+			if !ok {
+				// the channel was closed by the event system itself
+				msg := "Event stream was closed by the producer"
+				buildJSONErrorResponse(w, msg, http.StatusOK) // status code is 200 at this point, cannot be changed
+				log.Log(log.REST).Error(msg)
+				return
+			}
+
+			if err := enc.Encode(e); err != nil {
+				log.Log(log.REST).Error("Marshalling error",
+					zap.String("host", r.Host), zap.Error(err))
+				buildJSONErrorResponse(w, err.Error(), http.StatusOK) // status code is 200 at this point, cannot be changed
+				return
+			}
+			f.Flush()
+		}
+	}
+}
diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go
index 54bad55..f07b97e 100644
--- a/pkg/webservice/handlers_test.go
+++ b/pkg/webservice/handlers_test.go
@@ -21,7 +21,9 @@
import (
"context"
"encoding/json"
+ "errors"
"fmt"
+ "io"
"net/http"
"net/http/httptest"
"reflect"
@@ -1944,6 +1946,253 @@
readIllegalRequest(t, req, http.StatusInternalServerError, "Event tracking is disabled")
}
+// TestGetStream runs the handler with a cancellable request context. A goroutine adds three
+// events shortly after the handler has started and then cancels the context, which makes
+// getStream() return. The JSON lines written to the response are verified afterwards.
+func TestGetStream(t *testing.T) {
+ ev, req := initEventsAndCreateRequest(t)
+ defer ev.Stop()
+ cancelCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ req = req.Clone(cancelCtx)
+
+ resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher
+
+ go func() {
+ // wait until the handler is streaming, the request does not ask for historical events
+ time.Sleep(200 * time.Millisecond)
+ ev.AddEvent(&si.EventRecord{
+ TimestampNano: 111,
+ ObjectID: "app-1",
+ })
+ ev.AddEvent(&si.EventRecord{
+ TimestampNano: 222,
+ ObjectID: "node-1",
+ })
+ ev.AddEvent(&si.EventRecord{
+ TimestampNano: 333,
+ ObjectID: "app-2",
+ })
+ // let the events reach the handler before the context is cancelled
+ time.Sleep(200 * time.Millisecond)
+ cancel()
+ }()
+ getStream(resp, req)
+
+ output := make([]byte, 256)
+ n, err := resp.Body.Read(output)
+ assert.NilError(t, err, "cannot read response body")
+
+ // one JSON object per line
+ lines := strings.Split(string(output[:n]), "\n")
+ assertEvent(t, lines[0], 111, "app-1")
+ assertEvent(t, lines[1], 222, "node-1")
+ assertEvent(t, lines[2], 333, "app-2")
+}
+
+// TestGetStream_StreamClosedByProducer checks the case when the event system closes the stream:
+// the handler is expected to return after writing an error object, following the single event
+// that was added before all streams were closed.
+func TestGetStream_StreamClosedByProducer(t *testing.T) {
+ ev, req := initEventsAndCreateRequest(t)
+ defer ev.Stop()
+ resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher
+
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ ev.AddEvent(&si.EventRecord{
+ TimestampNano: 111,
+ ObjectID: "app-1",
+ })
+ // let the event reach the handler before the streams are closed
+ time.Sleep(100 * time.Millisecond)
+ ev.CloseAllStreams()
+ }()
+
+ getStream(resp, req)
+
+ output := make([]byte, 256)
+ n, err := resp.Body.Read(output)
+ assert.Equal(t, http.StatusOK, resp.Code)
+ assert.NilError(t, err, "cannot read response body")
+ lines := strings.Split(string(output[:n]), "\n")
+ assertEvent(t, lines[0], 111, "app-1")
+ assertYunikornError(t, lines[1], "Event stream was closed by the producer")
+}
+
+func TestGetStream_NotFlusherImpl(t *testing.T) {
+ var req *http.Request
+ req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
+ assert.NilError(t, err)
+ resp := &MockResponseWriter{}
+
+ getStream(resp, req)
+
+ assert.Assert(t, strings.Contains(string(resp.outputBytes), "Writer does not implement http.Flusher"))
+ assert.Equal(t, http.StatusInternalServerError, resp.statusCode)
+}
+
+// TestGetStream_Count checks the handling of the "count" query parameter with three existing
+// events: no history is sent if it is missing, only the newest two events are sent for "2" and
+// an error object is returned for a value which is not a number.
+// The same response recorder is used by all cases; each case reads what was written to the body.
+func TestGetStream_Count(t *testing.T) {
+	ev, req := initEventsAndCreateRequest(t)
+	defer ev.Stop()
+	cancelCtx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	req = req.Clone(cancelCtx)
+	resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher
+
+	// add some existing events
+	ev.AddEvent(&si.EventRecord{TimestampNano: 0})
+	ev.AddEvent(&si.EventRecord{TimestampNano: 1})
+	ev.AddEvent(&si.EventRecord{TimestampNano: 2})
+	time.Sleep(100 * time.Millisecond) // let the events propagate
+
+	// case #1: "count" not set
+	go func() {
+		time.Sleep(100 * time.Millisecond)
+		cancel()
+	}()
+	getStream(resp, req)
+	output := make([]byte, 256)
+	n, err := resp.Body.Read(output)
+	// nothing was written, so reading the body must report EOF; errors.Is() is also safe to use
+	// if the returned error is unexpectedly nil
+	assert.Assert(t, errors.Is(err, io.EOF), "expected io.EOF, got: %v", err)
+	assert.Equal(t, 0, n)
+
+	// case #2: "count" is set to "2"
+	req, err = http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
+	assert.NilError(t, err)
+	cancelCtx, cancel = context.WithCancel(context.Background())
+	req = req.Clone(cancelCtx)
+	defer cancel()
+	req.URL.RawQuery = "count=2"
+	go func() {
+		time.Sleep(100 * time.Millisecond)
+		cancel()
+	}()
+	getStream(resp, req)
+	output = make([]byte, 256)
+	n, err = resp.Body.Read(output)
+	assert.NilError(t, err)
+	lines := strings.Split(string(output[:n]), "\n")
+	assertEvent(t, lines[0], 1, "")
+	assertEvent(t, lines[1], 2, "")
+
+	// case #3: illegal value
+	req, err = http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
+	assert.NilError(t, err)
+	cancelCtx, cancel = context.WithCancel(context.Background())
+	req = req.Clone(cancelCtx)
+	defer cancel()
+	req.URL.RawQuery = "count=xyz"
+	getStream(resp, req)
+	output = make([]byte, 256)
+	n, err = resp.Body.Read(output)
+	assert.NilError(t, err)
+	line := string(output[:n])
+	assertYunikornError(t, line, `strconv.ParseUint: parsing "xyz": invalid syntax`)
+}
+
+// TestGetStream_TrackingDisabled checks that the handler returns an error if event tracking is
+// disabled in the config map. The original config map is restored when the test is finished.
+func TestGetStream_TrackingDisabled(t *testing.T) {
+ original := configs.GetConfigMap()
+ defer func() {
+ ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck
+ ev.Stop()
+ configs.SetConfigMap(original)
+ }()
+ configMap := map[string]string{
+ configs.CMEventTrackingEnabled: "false",
+ }
+ // the config map is replaced before the event system is initialized and started
+ configs.SetConfigMap(configMap)
+ _, req := initEventsAndCreateRequest(t)
+ resp := httptest.NewRecorder()
+
+ assertGetStreamError(t, req, resp, "Event tracking is disabled")
+}
+
+// TestGetStream_NoWriteDeadline checks that the handler returns an error if the write deadline
+// cannot be set because the response writer has no SetWriteDeadline() method.
+func TestGetStream_NoWriteDeadline(t *testing.T) {
+ ev, req := initEventsAndCreateRequest(t)
+ defer ev.Stop()
+ resp := httptest.NewRecorder() // does not have SetWriteDeadline()
+
+ assertGetStreamError(t, req, resp, "Cannot set write deadline: feature not supported")
+}
+
+// TestGetStream_SetWriteDeadlineFails checks that the handler returns an error if the write
+// deadline cannot be set when an event is about to be sent. The first SetWriteDeadline() call
+// of getStream() is the initial check, the second one precedes the writing of the first event.
+func TestGetStream_SetWriteDeadlineFails(t *testing.T) {
+ ev, req := initEventsAndCreateRequest(t)
+ defer ev.Stop()
+ resp := NewResponseRecorderWithDeadline()
+ resp.setWriteFailsAt = 2 // only the second SetWriteDeadline() will fail
+ resp.setWriteFails = true
+
+ go func() {
+ // the event triggers the second SetWriteDeadline() call inside the handler
+ time.Sleep(200 * time.Millisecond)
+ ev.AddEvent(&si.EventRecord{
+ TimestampNano: 111,
+ ObjectID: "app-1",
+ })
+ }()
+
+ getStream(resp, req)
+ checkGetStreamErrorResult(t, resp.Result(), "Cannot set write deadline: SetWriteDeadline failed")
+}
+
+// TestGetStream_SetReadDeadlineFails checks that the handler returns an error if the read
+// deadline cannot be set on the connection.
+func TestGetStream_SetReadDeadlineFails(t *testing.T) {
+	ev, req := initEventsAndCreateRequest(t)
+	// stop the event system started by the init helper, just like the other tests do
+	defer ev.Stop()
+	resp := NewResponseRecorderWithDeadline()
+	resp.setReadFails = true
+
+	assertGetStreamError(t, req, resp, "Cannot set read deadline: SetReadDeadline failed")
+}
+
+// assertGetStreamError calls getStream() with the given request and response recorder, then
+// verifies that the response is an internal server error carrying the expected message.
+// Only *ResponseRecorderWithDeadline and *httptest.ResponseRecorder are accepted as "resp";
+// the test is failed for any other type.
+func assertGetStreamError(t *testing.T, req *http.Request, resp interface{},
+ expectedMsg string) {
+ t.Helper()
+ var response *http.Response
+
+ // the concrete type is needed to pass the recorder as a writer and to fetch the result
+ switch rec := resp.(type) {
+ case *ResponseRecorderWithDeadline:
+ getStream(rec, req)
+ response = rec.Result()
+ case *httptest.ResponseRecorder:
+ getStream(rec, req)
+ response = rec.Result()
+ default:
+ t.Fatalf("unknown response recorder type")
+ }
+
+ checkGetStreamErrorResult(t, response, expectedMsg)
+}
+
+// checkGetStreamErrorResult reads up to 256 bytes of the response body, verifies that it is an
+// error object with the expected message and that the status code is internal server error.
+func checkGetStreamErrorResult(t *testing.T, response *http.Response, expectedMsg string) {
+ t.Helper()
+ output := make([]byte, 256)
+ n, err := response.Body.Read(output)
+ assert.NilError(t, err)
+ line := string(output[:n])
+ assertYunikornError(t, line, expectedMsg)
+ assert.Equal(t, http.StatusInternalServerError, response.StatusCode)
+}
+
+func initEventsAndCreateRequest(t *testing.T) (*events.EventSystemImpl, *http.Request) {
+ t.Helper()
+ events.Init()
+ ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck
+ ev.StartServiceWithPublisher(false)
+
+ var req *http.Request
+ req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
+ assert.NilError(t, err)
+
+ return ev, req
+}
+
+// assertEvent unmarshals "output" as a JSON encoded event record and verifies its timestamp and
+// object ID.
+func assertEvent(t *testing.T, output string, tsNano int64, objectID string) {
+ t.Helper()
+ var evt si.EventRecord
+ err := json.Unmarshal([]byte(output), &evt)
+ assert.NilError(t, err)
+ assert.Equal(t, tsNano, evt.TimestampNano)
+ assert.Equal(t, objectID, evt.ObjectID)
+}
+
+// assertYunikornError unmarshals "output" as a JSON encoded dao.YAPIError and verifies that both
+// the description and the message are equal to "errMsg".
+func assertYunikornError(t *testing.T, output, errMsg string) {
+ t.Helper()
+ var ykErr dao.YAPIError
+ err := json.Unmarshal([]byte(output), &ykErr)
+ assert.NilError(t, err)
+ assert.Equal(t, errMsg, ykErr.Description)
+ assert.Equal(t, errMsg, ykErr.Message)
+}
+
func addEvents(t *testing.T) (appEvent, nodeEvent, queueEvent *si.EventRecord) {
t.Helper()
events.Init()
@@ -2186,3 +2435,32 @@
assert.Equal(t, expectedHealthCheck.DiagnosisMessage, actualHealthCheck.DiagnosisMessage)
}
}
+
+// ResponseRecorderWithDeadline extends httptest.ResponseRecorder with SetWriteDeadline() and
+// SetReadDeadline() methods which can be configured to fail.
+type ResponseRecorderWithDeadline struct {
+ *httptest.ResponseRecorder
+ setWriteFails bool // enables the failure of SetWriteDeadline()
+ setWriteFailsAt int // number of the SetWriteDeadline() call which fails, the first call being 1
+ setWriteCalls int // how many times SetWriteDeadline() has been called
+ setReadFails bool // makes every SetReadDeadline() call fail
+}
+
+// SetWriteDeadline counts the calls and ignores the deadline. It only returns an error if
+// failures are enabled and the current call is the one selected by setWriteFailsAt.
+func (rrd *ResponseRecorderWithDeadline) SetWriteDeadline(_ time.Time) error {
+	rrd.setWriteCalls++
+	if !rrd.setWriteFails || rrd.setWriteCalls != rrd.setWriteFailsAt {
+		return nil
+	}
+	return errors.New("SetWriteDeadline failed")
+}
+
+// SetReadDeadline ignores the deadline. It returns an error if setReadFails is set, nil otherwise.
+func (rrd *ResponseRecorderWithDeadline) SetReadDeadline(_ time.Time) error {
+ if rrd.setReadFails {
+ return errors.New("SetReadDeadline failed")
+ }
+ return nil
+}
+
+// NewResponseRecorderWithDeadline creates a recorder backed by a new httptest.ResponseRecorder.
+// Failures are not enabled for any of the deadline methods.
+func NewResponseRecorderWithDeadline() *ResponseRecorderWithDeadline {
+ return &ResponseRecorderWithDeadline{
+ ResponseRecorder: httptest.NewRecorder(),
+ }
+}
diff --git a/pkg/webservice/routes.go b/pkg/webservice/routes.go
index 633fa78..6871cc7 100644
--- a/pkg/webservice/routes.go
+++ b/pkg/webservice/routes.go
@@ -188,6 +188,12 @@
"/ws/v1/events/batch",
getEvents,
},
+ route{
+ "Scheduler",
+ "GET",
+ "/ws/v1/events/stream",
+ getStream,
+ },
// endpoint to retrieve CPU, Memory profiling data,
// this works with pprof tool. By default, pprof endpoints
// are only registered to http.DefaultServeMux. Here, we