// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package grpcproxy

import (
"context"
"io"
"sync"
"sync/atomic"
"time"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
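
// leaseProxy implements pb.LeaseServer by forwarding lease RPCs from proxy
// clients to the etcd cluster behind it.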
type leaseProxy struct {
	// leaseClient handles requests from LeaseGrant() that require a lease ID.
	leaseClient pb.LeaseClient

	lessor clientv3.Lease

	ctx context.Context

	leader *leader

	// mu protects registering outstanding leaseProxyStreams through wg.
	mu sync.RWMutex

	// wg waits until all outstanding leaseProxyStreams quit.
	wg sync.WaitGroup
}
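
// NewLeaseProxy returns a pb.LeaseServer that proxies lease requests through
// the given client. The returned channel closes after the leader watcher
// stops and all outstanding keepalive streams have drained.
//
// A minimal usage sketch (illustrative names; assumes a connected
// *clientv3.Client c and a *grpc.Server srv):
//
//	leaseServer, leasec := NewLeaseProxy(c)
//	pb.RegisterLeaseServer(srv, leaseServer)
//	// ... serve ...
//	<-leasec // proxy has fully shut down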
func NewLeaseProxy(c *clientv3.Client) (pb.LeaseServer, <-chan struct{}) {
cctx, cancel := context.WithCancel(c.Ctx())
lp := &leaseProxy{
leaseClient: pb.NewLeaseClient(c.ActiveConnection()),
lessor: c.Lease,
ctx: cctx,
leader: newLeader(c.Ctx(), c.Watcher),
}
ch := make(chan struct{})
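	// shut down once the leader watcher stops: wait for the proxy context
	// to end, then drain outstanding streams before closing ch.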
go func() {
defer close(ch)
<-lp.leader.stopNotify()
lp.mu.Lock()
select {
case <-lp.ctx.Done():
case <-lp.leader.disconnectNotify():
cancel()
}
<-lp.ctx.Done()
lp.mu.Unlock()
lp.wg.Wait()
}()
return lp, ch
}
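
// LeaseGrant forwards the grant request to the backend lease client so the
// etcd server can assign the lease ID, then signals that a leader is known.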
func (lp *leaseProxy) LeaseGrant(ctx context.Context, cr *pb.LeaseGrantRequest) (*pb.LeaseGrantResponse, error) {
rp, err := lp.leaseClient.LeaseGrant(ctx, cr, grpc.FailFast(false))
if err != nil {
return nil, err
}
lp.leader.gotLeader()
return rp, nil
}
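
// LeaseRevoke revokes the lease through the client lessor and signals that
// a leader is known.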
func (lp *leaseProxy) LeaseRevoke(ctx context.Context, rr *pb.LeaseRevokeRequest) (*pb.LeaseRevokeResponse, error) {
r, err := lp.lessor.Revoke(ctx, clientv3.LeaseID(rr.ID))
if err != nil {
return nil, err
}
lp.leader.gotLeader()
return (*pb.LeaseRevokeResponse)(r), nil
}
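
// LeaseTimeToLive retrieves the remaining TTL of a lease, attaching its keys
// when requested, and converts the clientv3 response into a pb response.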
func (lp *leaseProxy) LeaseTimeToLive(ctx context.Context, rr *pb.LeaseTimeToLiveRequest) (*pb.LeaseTimeToLiveResponse, error) {
var (
r *clientv3.LeaseTimeToLiveResponse
err error
)
if rr.Keys {
r, err = lp.lessor.TimeToLive(ctx, clientv3.LeaseID(rr.ID), clientv3.WithAttachedKeys())
} else {
r, err = lp.lessor.TimeToLive(ctx, clientv3.LeaseID(rr.ID))
}
if err != nil {
return nil, err
}
rp := &pb.LeaseTimeToLiveResponse{
Header: r.ResponseHeader,
ID: int64(r.ID),
TTL: r.TTL,
GrantedTTL: r.GrantedTTL,
Keys: r.Keys,
}
	return rp, nil
}
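
// LeaseLeases lists the IDs of all leases known to the backend.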
func (lp *leaseProxy) LeaseLeases(ctx context.Context, rr *pb.LeaseLeasesRequest) (*pb.LeaseLeasesResponse, error) {
r, err := lp.lessor.Leases(ctx)
if err != nil {
return nil, err
}
leases := make([]*pb.LeaseStatus, len(r.Leases))
for i := range r.Leases {
leases[i] = &pb.LeaseStatus{ID: int64(r.Leases[i].ID)}
}
rp := &pb.LeaseLeasesResponse{
Header: r.ResponseHeader,
Leases: leases,
}
	return rp, nil
}
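
// LeaseKeepAlive proxies a bidirectional keepalive stream. It spawns three
// goroutines: recvLoop reads client requests, sendLoop writes responses, and
// a third tears the stream down on leader loss, stream end, or proxy shutdown.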
func (lp *leaseProxy) LeaseKeepAlive(stream pb.Lease_LeaseKeepAliveServer) error {
lp.mu.Lock()
select {
case <-lp.ctx.Done():
lp.mu.Unlock()
return lp.ctx.Err()
default:
lp.wg.Add(1)
}
lp.mu.Unlock()
ctx, cancel := context.WithCancel(stream.Context())
lps := leaseProxyStream{
stream: stream,
lessor: lp.lessor,
keepAliveLeases: make(map[int64]*atomicCounter),
respc: make(chan *pb.LeaseKeepAliveResponse),
ctx: ctx,
cancel: cancel,
}
errc := make(chan error, 2)
var lostLeaderC <-chan struct{}
	// client metadata arrives on the server's incoming context
	if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
v := md[rpctypes.MetadataRequireLeaderKey]
if len(v) > 0 && v[0] == rpctypes.MetadataHasLeader {
lostLeaderC = lp.leader.lostNotify()
// if leader is known to be lost at creation time, avoid
// letting events through at all
select {
case <-lostLeaderC:
lp.wg.Done()
return rpctypes.ErrNoLeader
default:
}
}
}
	// stopc collects one token from each of the three goroutines below so
	// the cleanup goroutine can join them before releasing the stream.
	stopc := make(chan struct{}, 3)
go func() {
defer func() { stopc <- struct{}{} }()
if err := lps.recvLoop(); err != nil {
errc <- err
}
}()
go func() {
defer func() { stopc <- struct{}{} }()
if err := lps.sendLoop(); err != nil {
errc <- err
}
}()
	// tears down the LeaseKeepAlive stream if the leader goes down or the
	// entire leaseProxy is terminated.
go func() {
defer func() { stopc <- struct{}{} }()
select {
case <-lostLeaderC:
case <-ctx.Done():
case <-lp.ctx.Done():
}
}()
	var err error
	select {
	case <-stopc:
		// one goroutine hit stopc first; put the token back so the cleanup
		// goroutine below still receives all three.
		stopc <- struct{}{}
case err = <-errc:
}
cancel()
	// recv/send may only shut down after this function exits;
	// this goroutine notifies the lease proxy that the stream is finished.
go func() {
<-stopc
<-stopc
<-stopc
lps.close()
close(errc)
lp.wg.Done()
}()
select {
case <-lostLeaderC:
return rpctypes.ErrNoLeader
case <-lp.leader.disconnectNotify():
return grpc.ErrClientConnClosing
default:
if err != nil {
return err
}
return ctx.Err()
}
}
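
// leaseProxyStream tracks the state of one LeaseKeepAlive stream, fanning
// client requests out to per-lease keepAliveLoop goroutines and funneling
// their responses back through respc.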
type leaseProxyStream struct {
	stream pb.Lease_LeaseKeepAliveServer

	lessor clientv3.Lease
	// wg tracks keepAliveLoop goroutines
	wg sync.WaitGroup
	// mu protects keepAliveLeases
	mu sync.RWMutex
	// keepAliveLeases tracks, per lease ID, how many outstanding keepalive
	// requests still need responses.
	keepAliveLeases map[int64]*atomicCounter
	// respc receives lease keepalive responses from the etcd backend
	respc chan *pb.LeaseKeepAliveResponse

	ctx    context.Context
	cancel context.CancelFunc
}
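
// recvLoop reads keepalive requests from the client until EOF, starting a
// keepAliveLoop for each new lease ID and counting the responses it is owed.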
func (lps *leaseProxyStream) recvLoop() error {
for {
rr, err := lps.stream.Recv()
if err == io.EOF {
return nil
}
if err != nil {
return err
}
lps.mu.Lock()
neededResps, ok := lps.keepAliveLeases[rr.ID]
if !ok {
neededResps = &atomicCounter{}
lps.keepAliveLeases[rr.ID] = neededResps
lps.wg.Add(1)
go func() {
defer lps.wg.Done()
if err := lps.keepAliveLoop(rr.ID, neededResps); err != nil {
lps.cancel()
}
}()
}
neededResps.add(1)
lps.mu.Unlock()
}
}
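
// keepAliveLoop relays keepalive responses for a single lease. If the backend
// stream closes while responses are still owed, it answers the client from a
// one-off TimeToLive query; it exits once a full TTL passes with no new
// client requests.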
func (lps *leaseProxyStream) keepAliveLoop(leaseID int64, neededResps *atomicCounter) error {
cctx, ccancel := context.WithCancel(lps.ctx)
defer ccancel()
respc, err := lps.lessor.KeepAlive(cctx, clientv3.LeaseID(leaseID))
if err != nil {
return err
}
	// ticker fires when the loop hasn't received a keepalive within the TTL
var ticker <-chan time.Time
for {
select {
case <-ticker:
lps.mu.Lock()
			// if there are outstanding keepAlive reqs at the moment the ticker
			// fires, don't exit keepAliveLoop(); let it continue processing them.
if neededResps.get() > 0 {
lps.mu.Unlock()
ticker = nil
continue
}
delete(lps.keepAliveLeases, leaseID)
lps.mu.Unlock()
return nil
case rp, ok := <-respc:
if !ok {
lps.mu.Lock()
delete(lps.keepAliveLeases, leaseID)
lps.mu.Unlock()
if neededResps.get() == 0 {
return nil
}
ttlResp, err := lps.lessor.TimeToLive(cctx, clientv3.LeaseID(leaseID))
if err != nil {
return err
}
r := &pb.LeaseKeepAliveResponse{
Header: ttlResp.ResponseHeader,
ID: int64(ttlResp.ID),
TTL: ttlResp.TTL,
}
for neededResps.get() > 0 {
select {
case lps.respc <- r:
neededResps.add(-1)
case <-lps.ctx.Done():
return nil
}
}
return nil
}
if neededResps.get() == 0 {
continue
}
ticker = time.After(time.Duration(rp.TTL) * time.Second)
r := &pb.LeaseKeepAliveResponse{
Header: rp.ResponseHeader,
ID: int64(rp.ID),
TTL: rp.TTL,
}
lps.replyToClient(r, neededResps)
}
}
}
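
// replyToClient sends r once per owed response, giving up after 500ms or
// when the stream shuts down.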
func (lps *leaseProxyStream) replyToClient(r *pb.LeaseKeepAliveResponse, neededResps *atomicCounter) {
timer := time.After(500 * time.Millisecond)
for neededResps.get() > 0 {
select {
case lps.respc <- r:
neededResps.add(-1)
case <-timer:
return
case <-lps.ctx.Done():
return
}
}
}
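
// sendLoop forwards responses from respc to the client stream until the
// channel is closed or the stream context ends.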
func (lps *leaseProxyStream) sendLoop() error {
for {
select {
case lrp, ok := <-lps.respc:
if !ok {
return nil
}
if err := lps.stream.Send(lrp); err != nil {
return err
}
case <-lps.ctx.Done():
return lps.ctx.Err()
}
}
}
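
// close cancels and joins all keepAliveLoop goroutines before closing respc.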
func (lps *leaseProxyStream) close() {
lps.cancel()
lps.wg.Wait()
	// only close the respc channel once all keepAliveLoop() goroutines have
	// finished; this ensures they never send on a closed channel.
close(lps.respc)
}
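
// atomicCounter is a goroutine-safe int64 counter.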
type atomicCounter struct {
counter int64
}
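
// add atomically adds delta to the counter.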
func (ac *atomicCounter) add(delta int64) {
atomic.AddInt64(&ac.counter, delta)
}
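
// get atomically loads the current counter value.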
func (ac *atomicCounter) get() int64 {
return atomic.LoadInt64(&ac.counter)
}