[Issue 140] Consumer should not block on received if closed. (#142)
* [Issue 140] Consumer should not block on received if closed.
* [Issue #140] Consumer should not block on received if closed.
diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index 61c27a7..4a1f85e 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -48,8 +48,9 @@
// channel used to deliver message to clients
messageCh chan ConsumerMessage
- closeCh chan struct{}
- errorCh chan error
+ closeOnce sync.Once
+ closeCh chan struct{}
+ errorCh chan error
log *log.Entry
}
@@ -116,6 +117,7 @@
consumer := &consumer{
options: options,
messageCh: messageCh,
+ closeCh: make(chan struct{}),
errorCh: make(chan error),
log: log.WithField("topic", topic),
}
@@ -226,6 +228,8 @@
func (c *consumer) Receive(ctx context.Context) (message Message, err error) {
for {
select {
+ case <-c.closeCh:
+ return nil, ErrConsumerClosed
case cm, ok := <-c.messageCh:
if !ok {
return nil, ErrConsumerClosed
@@ -298,15 +302,18 @@
}
func (c *consumer) Close() {
- var wg sync.WaitGroup
- for i := range c.consumers {
- wg.Add(1)
- go func(pc *partitionConsumer) {
- defer wg.Done()
- pc.Close()
- }(c.consumers[i])
- }
- wg.Wait()
+ c.closeOnce.Do(func() {
+ var wg sync.WaitGroup
+ for i := range c.consumers {
+ wg.Add(1)
+ go func(pc *partitionConsumer) {
+ defer wg.Done()
+ pc.Close()
+ }(c.consumers[i])
+ }
+ wg.Wait()
+ close(c.closeCh)
+ })
}
var r = &random{
diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go
index 11401ba..9bc2bd6 100644
--- a/pulsar/consumer_test.go
+++ b/pulsar/consumer_test.go
@@ -794,3 +794,36 @@
assert.Equal(t, v, mv)
}
}
+
+// Test for issue #140
+// Don't block on receive if the consumer has been closed
+func TestConsumerReceiveErrAfterClose(t *testing.T) {
+ client, err := NewClient(ClientOptions{
+ URL: lookupURL,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ topicName := newTopicName()
+ consumer, err := client.Subscribe(ConsumerOptions{
+ Topic: topicName,
+ SubscriptionName: "my-sub",
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ consumer.Close()
+
+ errorCh := make(chan error)
+ go func() {
+ _, err = consumer.Receive(context.Background())
+ errorCh <- err
+ }()
+ select {
+ case <-time.After(200 * time.Millisecond):
+ case err = <-errorCh:
+ }
+ assert.Equal(t, ErrConsumerClosed, err)
+}