blob: 3388419f194eebaa3db317f2d23a856fcb26f5a8 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package grpc
import (
"sync"
)
import (
"github.com/pkg/errors"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
// ReverseUnaryMessage is a protobuf message that also exposes a request ID.
// The request ID is the key used to match a response received from a client
// with the watch registered for the request the server sent.
type ReverseUnaryMessage interface {
	GetRequestId() string
	proto.Message
}
// ReverseUnaryRPCs helps to implement reverse unary rpcs where server sends requests to a client and receives responses from the client.
type ReverseUnaryRPCs interface {
Send(client string, req ReverseUnaryMessage) error
WatchResponse(client string, reqID string, resp chan ReverseUnaryMessage) error
DeleteWatch(client string, reqID string)
ClientConnected(client string, stream grpc.ServerStream)
ClientDisconnected(client string)
ResponseReceived(client string, resp ReverseUnaryMessage) error
}
// clientStreams is the ReverseUnaryRPCs implementation returned by
// NewReverseUnaryRPCs. It keeps one clientStream per connected client name.
type clientStreams struct {
	sync.Mutex // guards streamForClient

	streamForClient map[string]*clientStream
}
// ResponseReceived delivers resp to the channel registered via WatchResponse
// for client and resp.GetRequestId().
//
// The send blocks until the watcher receives the response or the watch is
// removed. A watch is removed by DeleteWatch, by re-registration of the same
// request ID, or by the client's stream being replaced or disconnected.
// Releasing on removal matters: a watcher that stopped reading (e.g. after a
// timeout) must not leave the calling goroutine blocked forever.
//
// It returns an error if the client is not connected, if no watch is
// registered for the request ID, or if the watch is removed before the
// response could be delivered.
func (x *clientStreams) ResponseReceived(client string, resp ReverseUnaryMessage) error {
	stream, err := x.clientStream(client)
	if err != nil {
		return err
	}
	reqID := resp.GetRequestId()
	stream.Lock()
	watch, ok := stream.watchForRequestId[reqID]
	stream.Unlock()
	if !ok {
		return errors.Errorf("callback for request Id %s not found", reqID)
	}
	// Send outside the lock so a slow watcher does not stall WatchResponse or
	// DeleteWatch for other requests on the same client.
	select {
	case watch.resp <- resp:
		return nil
	case <-watch.done:
		return errors.Errorf("watch for request Id %s was removed before the response was delivered", reqID)
	}
}

// NewReverseUnaryRPCs returns a ReverseUnaryRPCs with no connected clients.
func NewReverseUnaryRPCs() ReverseUnaryRPCs {
	return &clientStreams{
		streamForClient: map[string]*clientStream{},
	}
}

// ClientConnected registers stream for client with an empty set of watches.
// If client already had a stream, that stream is released, which unblocks any
// ResponseReceived call still trying to deliver to one of its watches.
func (x *clientStreams) ClientConnected(client string, stream grpc.ServerStream) {
	x.Lock()
	old := x.streamForClient[client]
	x.streamForClient[client] = &clientStream{
		stream:            stream,
		watchForRequestId: map[string]*responseWatch{},
	}
	x.Unlock()
	// Released after dropping x's lock so the two locks are never nested.
	if old != nil {
		old.release()
	}
}

// clientStream returns the stream registered for client, or an error if the
// client is not connected.
func (x *clientStreams) clientStream(client string) (*clientStream, error) {
	x.Lock()
	defer x.Unlock()
	stream, ok := x.streamForClient[client]
	if !ok {
		return nil, errors.Errorf("client %s is not connected", client)
	}
	return stream, nil
}

// ClientDisconnected removes client's stream and releases all of its watches.
// Without the release, a ResponseReceived blocked on one of them could never
// be freed: DeleteWatch can no longer find the stream once it is removed.
func (x *clientStreams) ClientDisconnected(client string) {
	x.Lock()
	stream, ok := x.streamForClient[client]
	delete(x.streamForClient, client)
	x.Unlock()
	if ok {
		stream.release()
	}
}

// clientStream is the per-client state: the gRPC stream that requests are
// sent on, and the watches waiting for that client's responses.
type clientStream struct {
	stream grpc.ServerStream

	// sendMu serializes SendMsg. grpc-go documents that calling SendMsg on
	// the same stream from multiple goroutines concurrently is not safe, and
	// nothing here prevents Send from being called concurrently.
	sendMu sync.Mutex

	sync.Mutex // protects watchForRequestId and released

	watchForRequestId map[string]*responseWatch

	// released is set once the stream has been replaced or disconnected.
	// After that, WatchResponse refuses new watches.
	released bool
}

// responseWatch is one registered interest in the response to a request.
type responseWatch struct {
	// resp is the caller-supplied channel the response is sent on.
	resp chan ReverseUnaryMessage

	// done is closed exactly once, when the watch is removed from its
	// clientStream (see removeWatchLocked). Closing it releases a pending send.
	done chan struct{}
}

// removeWatchLocked removes the watch for reqID, if present, and closes its
// done channel. The caller must hold s's lock. This is the only place done
// channels are closed, and a watch is removed at most once, so no channel is
// closed twice.
func (s *clientStream) removeWatchLocked(reqID string) {
	if w, ok := s.watchForRequestId[reqID]; ok {
		delete(s.watchForRequestId, reqID)
		close(w.done)
	}
}

// release marks s as no longer reachable from clientStreams and removes every
// watch registered on it.
func (s *clientStream) release() {
	s.Lock()
	defer s.Unlock()
	s.released = true
	for reqID := range s.watchForRequestId {
		s.removeWatchLocked(reqID)
	}
}

// Send sends req to client over its registered stream. It returns an error if
// client is not connected or if SendMsg fails.
func (x *clientStreams) Send(client string, req ReverseUnaryMessage) error {
	stream, err := x.clientStream(client)
	if err != nil {
		return err
	}
	stream.sendMu.Lock()
	defer stream.sendMu.Unlock()
	return stream.stream.SendMsg(req)
}

// WatchResponse registers resp to receive the response with request ID reqID
// from client. It replaces (and releases) any existing watch for reqID. It
// returns an error if client is not connected.
func (x *clientStreams) WatchResponse(client string, reqID string, resp chan ReverseUnaryMessage) error {
	stream, err := x.clientStream(client)
	if err != nil {
		return err
	}
	stream.Lock()
	defer stream.Unlock()
	if stream.released {
		// The stream was replaced or disconnected after the lookup above. No
		// later ResponseReceived or DeleteWatch lookup would find this watch.
		return errors.Errorf("client %s is not connected", client)
	}
	stream.removeWatchLocked(reqID)
	stream.watchForRequestId[reqID] = &responseWatch{
		resp: resp,
		done: make(chan struct{}),
	}
	return nil
}

// DeleteWatch removes the watch for reqID on client. Any ResponseReceived
// still blocked delivering to it returns an error instead of blocking forever.
// It is a no-op if client is not connected: disconnecting already released the
// client's watches.
func (x *clientStreams) DeleteWatch(client string, reqID string) {
	stream, err := x.clientStream(client)
	if err != nil {
		return // client was already deleted
	}
	stream.Lock()
	defer stream.Unlock()
	stream.removeWatchLocked(reqID)
}