| // Package zk is a native Go client library for the ZooKeeper coordination service.
| package zk |
| |
| /* |
| TODO: |
| * make sure a ping response comes back in a reasonable time |
| |
| Possible watcher events: |
| * Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err} |
| */ |
| |
| import ( |
| "crypto/rand" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "io" |
| "net" |
| "strconv" |
| "strings" |
| "sync" |
| "sync/atomic" |
| "time" |
| ) |
| |
| // ErrNoServer indicates that an operation cannot be completed |
| // because attempts to connect to all servers in the list failed. |
| var ErrNoServer = errors.New("zk: could not connect to a server") |
| |
| // ErrInvalidPath indicates that an operation was being attempted on |
| // an invalid path. (e.g. empty path) |
| var ErrInvalidPath = errors.New("zk: invalid path") |
| |
| // DefaultLogger uses the stdlib log package for logging. |
| var DefaultLogger Logger = defaultLogger{} |
| |
| const ( |
| bufferSize = 1536 * 1024 |
| eventChanSize = 6 |
| sendChanSize = 16 |
| protectedPrefix = "_c_" |
| ) |
| |
| type watchType int |
| |
| const ( |
| watchTypeData watchType = iota
| watchTypeExist |
| watchTypeChild |
| ) |
| |
| type watchPathType struct { |
| path string |
| wType watchType |
| } |
| |
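| // Dialer is a function used to establish a connection to a single host.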
| type Dialer func(network, address string, timeout time.Duration) (net.Conn, error) |
| |
| // Logger is an interface that can be implemented to provide custom log output. |
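| //
| // The standard library's *log.Logger satisfies this interface, so, for
| // example (a sketch):
| //
| //	conn.SetLogger(log.New(os.Stderr, "zk: ", log.LstdFlags))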
| type Logger interface { |
| Printf(string, ...interface{}) |
| } |
| |
| type authCreds struct { |
| scheme string |
| auth []byte |
| } |
| |
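| // Conn represents a client connection to a ZooKeeper ensemble and tracks the
| // state of the session. Create one with Connect.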
| type Conn struct { |
| lastZxid int64 |
| sessionID int64 |
| state State // must be 32-bit aligned |
| xid uint32 |
| sessionTimeoutMs int32 // session timeout in milliseconds |
| passwd []byte |
| |
| dialer Dialer |
| hostProvider HostProvider |
| serverMu sync.Mutex // protects server |
| server string // remember the address/port of the current server |
| conn net.Conn |
| eventChan chan Event |
| eventCallback EventCallback // may be nil |
| shouldQuit chan struct{} |
| pingInterval time.Duration |
| recvTimeout time.Duration |
| connectTimeout time.Duration |
| maxBufferSize int |
| |
| creds []authCreds |
| credsMu sync.Mutex // protects creds
| |
| sendChan chan *request |
| requests map[int32]*request // Xid -> pending request |
| requestsLock sync.Mutex |
| watchers map[watchPathType][]chan Event |
| watchersLock sync.Mutex |
| closeChan chan struct{} // channel to tell send loop to stop
| |
| // Debug (used by unit tests) |
| reconnectLatch chan struct{} |
| setWatchLimit int |
| setWatchCallback func([]*setWatchesRequest) |
| // Debug (for recurring re-auth hang) |
| debugCloseRecvLoop bool |
| debugReauthDone chan struct{} |
| |
| logger Logger |
| logInfo bool // true if information messages are logged; false if only errors are logged |
| |
| buf []byte |
| } |
| |
| // connOption represents a connection option. |
| type connOption func(c *Conn) |
| |
| type request struct { |
| xid int32 |
| opcode int32 |
| pkt interface{} |
| recvStruct interface{} |
| recvChan chan response |
| |
| // Because sending and receiving happen in separate goroutines, there's
| // a possible race condition when creating watches from outside the read
| // loop. We must ensure that a watcher gets added to the list synchronously
| // with the response from the server on any request that creates a watch.
| // In order to not hard-code the watch logic for each opcode in the recv
| // loop, the caller can use recvFunc to insert some synchronous code
| // after a response.
| recvFunc func(*request, *responseHeader, error) |
| } |
| |
| type response struct { |
| zxid int64 |
| err error |
| } |
| |
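| // Event is a ZooKeeper notification, delivered on the connection's event
| // channel and on watch channels.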
| type Event struct { |
| Type EventType |
| State State |
| Path string // For non-session events, the path of the watched node. |
| Err error |
| Server string // For connection events |
| } |
| |
| // HostProvider is used to represent a set of hosts a ZooKeeper client should connect to. |
| // It is an analog of the Java equivalent: |
| // http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup |
| type HostProvider interface { |
| // Init is called first, with the servers specified in the connection string. |
| Init(servers []string) error |
| // Len returns the number of servers. |
| Len() int |
| // Next returns the next server to connect to. retryStart will be true if we've looped through |
| // all known servers without Connected() being called. |
| Next() (server string, retryStart bool) |
| // Notify the HostProvider of a successful connection. |
| Connected() |
| } |
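| 
| // A minimal static implementation (a sketch; DNSHostProvider is the package's
| // default) could look like:
| //
| //	type staticHostProvider struct {
| //		servers []string
| //		curr    int
| //		last    int // index of the last successfully connected server
| //	}
| //
| //	func (p *staticHostProvider) Init(servers []string) error {
| //		if len(servers) == 0 {
| //			return errors.New("zk: no servers")
| //		}
| //		p.servers = servers
| //		p.curr, p.last = -1, -1
| //		return nil
| //	}
| //
| //	func (p *staticHostProvider) Len() int { return len(p.servers) }
| //
| //	func (p *staticHostProvider) Next() (string, bool) {
| //		p.curr = (p.curr + 1) % len(p.servers)
| //		retryStart := p.curr == p.last // wrapped around without a new success
| //		if p.last == -1 {
| //			p.last = 0 // avoid flagging the very first pass as a retry
| //		}
| //		return p.servers[p.curr], retryStart
| //	}
| //
| //	func (p *staticHostProvider) Connected() { p.last = p.curr }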
| |
| // ConnectWithDialer establishes a new connection to a pool of zookeeper servers
| // using a custom Dialer. See Connect for further information about session timeout.
| //
| // Deprecated: This function is provided for backwards compatibility; use
| // Connect with the WithDialer option instead.
| func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) { |
| return Connect(servers, sessionTimeout, WithDialer(dialer)) |
| } |
| |
| // Connect establishes a new connection to a pool of zookeeper |
| // servers. The provided session timeout sets the amount of time for which |
| // a session is considered valid after losing connection to a server. Within |
| // the session timeout it's possible to reestablish a connection to a different |
| // server and keep the same session. This means any ephemeral nodes and
| // watches are maintained. |
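| //
| // A typical call looks like this (a sketch; the address and timeout are
| // illustrative):
| //
| //	conn, events, err := zk.Connect([]string{"127.0.0.1:2181"}, 10*time.Second)
| //	if err != nil {
| //		// handle error
| //	}
| //	defer conn.Close()
| //	go func() {
| //		for ev := range events {
| //			log.Printf("zk event: %+v", ev)
| //		}
| //	}()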
| func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { |
| if len(servers) == 0 { |
| return nil, nil, errors.New("zk: server list must not be empty") |
| } |
| |
| srvs := make([]string, len(servers)) |
| |
| for i, addr := range servers { |
| if strings.Contains(addr, ":") { |
| srvs[i] = addr |
| } else { |
| srvs[i] = addr + ":" + strconv.Itoa(DefaultPort) |
| } |
| } |
| |
| // Randomize the order of the servers to avoid creating hotspots |
| stringShuffle(srvs) |
| |
| ec := make(chan Event, eventChanSize) |
| conn := &Conn{ |
| dialer: net.DialTimeout, |
| hostProvider: &DNSHostProvider{}, |
| conn: nil, |
| state: StateDisconnected, |
| eventChan: ec, |
| shouldQuit: make(chan struct{}), |
| connectTimeout: 1 * time.Second, |
| sendChan: make(chan *request, sendChanSize), |
| requests: make(map[int32]*request), |
| watchers: make(map[watchPathType][]chan Event), |
| passwd: emptyPassword, |
| logger: DefaultLogger, |
| logInfo: true, // default is true for backwards compatibility
| buf: make([]byte, bufferSize), |
| } |
| |
| // Set provided options. |
| for _, option := range options { |
| option(conn) |
| } |
| |
| if err := conn.hostProvider.Init(srvs); err != nil { |
| return nil, nil, err |
| } |
| |
| conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) |
| |
| go func() { |
| conn.loop() |
| conn.flushRequests(ErrClosing) |
| conn.invalidateWatches(ErrClosing) |
| close(conn.eventChan) |
| }() |
| return conn, ec, nil |
| } |
| |
| // WithDialer returns a connection option specifying a non-default Dialer. |
| func WithDialer(dialer Dialer) connOption { |
| return func(c *Conn) { |
| c.dialer = dialer |
| } |
| } |
| |
| // WithHostProvider returns a connection option specifying a non-default HostProvider. |
| func WithHostProvider(hostProvider HostProvider) connOption { |
| return func(c *Conn) { |
| c.hostProvider = hostProvider |
| } |
| } |
| |
| // WithLogger returns a connection option specifying a non-default Logger.
| func WithLogger(logger Logger) connOption { |
| return func(c *Conn) { |
| c.logger = logger |
| } |
| } |
| |
| // WithLogInfo returns a connection option specifying whether or not information
| // messages should be logged.
| func WithLogInfo(logInfo bool) connOption { |
| return func(c *Conn) { |
| c.logInfo = logInfo |
| } |
| } |
| |
| // EventCallback is a function that is called when an Event occurs. |
| type EventCallback func(Event) |
| |
| // WithEventCallback returns a connection option that specifies an event |
| // callback. |
| // The callback must not block - doing so would delay the ZK goroutines.
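| //
| // For example, a callback that forwards events without ever blocking (a
| // sketch; events is a hypothetical application channel):
| //
| //	cb := func(ev zk.Event) {
| //		select {
| //		case events <- ev:
| //		default: // drop rather than stall the connection's goroutines
| //		}
| //	}
| //	conn, _, err := zk.Connect(servers, timeout, zk.WithEventCallback(cb))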
| func WithEventCallback(cb EventCallback) connOption { |
| return func(c *Conn) { |
| c.eventCallback = cb |
| } |
| } |
| |
| // WithMaxBufferSize sets the maximum buffer size used to read and decode
| // packets received from the ZooKeeper server. The standard ZooKeeper client for
| // Java defaults to a limit of 1mb. For backwards compatibility, this Go client
| // defaults to unbounded unless overridden via this option. A value that is zero
| // or negative indicates that no limit is enforced.
| // |
| // This is meant to prevent resource exhaustion in the face of potentially |
| // malicious data in ZK. It should generally match the server setting (which
| // also defaults to 1mb) so that clients and servers agree on the limits for
| // things like the size of data in an individual znode and the total size of a |
| // transaction. |
| // |
| // For production systems, this should be set to a reasonable value (ideally |
| // that matches the server configuration). For ops tooling, it is handy to use a |
| // much larger limit, in order to do things like clean-up problematic state in |
| // the ZK tree. For example, if a single znode has a huge number of children, it |
| // is possible for the response to a "list children" operation to exceed this |
| // buffer size and cause errors in clients. The only way to subsequently clean |
| // up the tree (by removing superfluous children) is to use a client configured |
| // with a larger buffer size that can successfully query for all of the child |
| // names and then remove them. (Note there are other tools that can list all of |
| // the child names without an increased buffer size in the client, but they work |
| // by inspecting the servers' transaction logs to enumerate children instead of |
| // sending an online request to a server.)
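| //
| // For example, to match a server configured with the default 1mb limit (a
| // sketch; the option value is illustrative):
| //
| //	conn, _, err := zk.Connect(servers, timeout, zk.WithMaxBufferSize(1024*1024))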
| func WithMaxBufferSize(maxBufferSize int) connOption { |
| return func(c *Conn) { |
| c.maxBufferSize = maxBufferSize |
| } |
| } |
| |
| // WithMaxConnBufferSize sets the maximum buffer size used to send and encode
| // packets to the ZooKeeper server. The standard ZooKeeper client for Java
| // defaults to a limit of 1mb. This option should be used for non-standard
| // server setups where znodes are bigger than the default of 1mb.
| func WithMaxConnBufferSize(maxBufferSize int) connOption { |
| return func(c *Conn) { |
| c.buf = make([]byte, maxBufferSize) |
| } |
| } |
| |
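| // Close submits a close request to the server and signals the connection to
| // shut down. It waits up to one second for the close to be acknowledged.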
| func (c *Conn) Close() { |
| close(c.shouldQuit) |
| |
| select { |
| case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): |
| case <-time.After(time.Second): |
| } |
| } |
| |
| // State returns the current state of the connection. |
| func (c *Conn) State() State { |
| return State(atomic.LoadInt32((*int32)(&c.state))) |
| } |
| |
| // SessionID returns the current session id of the connection. |
| func (c *Conn) SessionID() int64 { |
| return atomic.LoadInt64(&c.sessionID) |
| } |
| |
| // SetLogger sets the logger to be used for printing errors. |
| // Logger is an interface provided by this package. |
| func (c *Conn) SetLogger(l Logger) { |
| c.logger = l |
| } |
| |
| func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
| c.sessionTimeoutMs = sessionTimeoutMs
| sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond
| // Receive within two-thirds of the session timeout, and ping at half the
| // receive timeout, so a ping round-trip completes well before the session
| // can expire.
| c.recvTimeout = sessionTimeout * 2 / 3
| c.pingInterval = c.recvTimeout / 2
| }
| |
| func (c *Conn) setState(state State) { |
| atomic.StoreInt32((*int32)(&c.state), int32(state)) |
| c.sendEvent(Event{Type: EventSession, State: state, Server: c.Server()}) |
| } |
| |
| func (c *Conn) sendEvent(evt Event) { |
| if c.eventCallback != nil { |
| c.eventCallback(evt) |
| } |
| |
| select { |
| case c.eventChan <- evt: |
| default: |
| // panic("zk: event channel full - it must be monitored and never allowed to be full") |
| } |
| } |
| |
| func (c *Conn) connect() error { |
| var retryStart bool |
| for { |
| c.serverMu.Lock() |
| c.server, retryStart = c.hostProvider.Next() |
| c.serverMu.Unlock() |
| c.setState(StateConnecting) |
| if retryStart { |
| c.flushUnsentRequests(ErrNoServer) |
| select { |
| case <-time.After(time.Second): |
| // pass |
| case <-c.shouldQuit: |
| c.setState(StateDisconnected) |
| c.flushUnsentRequests(ErrClosing) |
| return ErrClosing |
| } |
| } |
| |
| zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) |
| if err == nil { |
| c.conn = zkConn |
| c.setState(StateConnected) |
| if c.logInfo { |
| c.logger.Printf("Connected to %s", c.Server()) |
| } |
| return nil |
| } |
| |
| c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err) |
| } |
| } |
| |
| func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) { |
| shouldCancel := func() bool { |
| select { |
| case <-c.shouldQuit: |
| return true |
| case <-c.closeChan: |
| return true |
| default: |
| return false |
| } |
| } |
| |
| c.credsMu.Lock() |
| defer c.credsMu.Unlock() |
| |
| defer close(reauthReadyChan) |
| |
| if c.logInfo { |
| c.logger.Printf("Re-submitting `%d` credentials after reconnect", |
| len(c.creds)) |
| } |
| |
| for _, cred := range c.creds { |
| if shouldCancel() { |
| c.logger.Printf("Cancel rer-submitting credentials") |
| return |
| } |
| resChan, err := c.sendRequest( |
| opSetAuth, |
| &setAuthRequest{Type: 0, |
| Scheme: cred.scheme, |
| Auth: cred.auth, |
| }, |
| &setAuthResponse{}, |
| nil) |
| |
| if err != nil { |
| c.logger.Printf("Call to sendRequest failed during credential resubmit: %s", err) |
| // FIXME(prozlach): let's ignore errors for now
| continue |
| } |
| |
| var res response |
| select { |
| case res = <-resChan: |
| case <-c.closeChan: |
| c.logger.Printf("Recv closed, cancel re-submitting credentials") |
| return |
| case <-c.shouldQuit: |
| c.logger.Printf("Should quit, cancel re-submitting credentials") |
| return |
| } |
| if res.err != nil { |
| c.logger.Printf("Credential re-submit failed: %s", res.err) |
| // FIXME(prozlach): let's ignore errors for now
| continue |
| } |
| } |
| } |
| |
| func (c *Conn) sendRequest( |
| opcode int32, |
| req interface{}, |
| res interface{}, |
| recvFunc func(*request, *responseHeader, error), |
| ) ( |
| <-chan response, |
| error, |
| ) { |
| rq := &request{ |
| xid: c.nextXid(), |
| opcode: opcode, |
| pkt: req, |
| recvStruct: res, |
| recvChan: make(chan response, 1), |
| recvFunc: recvFunc, |
| } |
| |
| if err := c.sendData(rq); err != nil { |
| return nil, err |
| } |
| |
| return rq.recvChan, nil |
| } |
| |
| func (c *Conn) loop() { |
| for { |
| if err := c.connect(); err != nil { |
| // c.Close() was called |
| return |
| } |
| |
| err := c.authenticate() |
| switch { |
| case err == ErrSessionExpired: |
| c.logger.Printf("Authentication failed: %s", err) |
| c.invalidateWatches(err) |
| case err != nil && c.conn != nil: |
| c.logger.Printf("Authentication failed: %s", err) |
| c.conn.Close() |
| case err == nil: |
| if c.logInfo { |
| c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) |
| } |
| c.hostProvider.Connected() // mark success |
| c.closeChan = make(chan struct{}) // channel to tell send loop to stop
| reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted |
| |
| var wg sync.WaitGroup |
| wg.Add(1) |
| go func() { |
| <-reauthChan |
| if c.debugCloseRecvLoop { |
| close(c.debugReauthDone) |
| } |
| err := c.sendLoop() |
| if err != nil || c.logInfo { |
| c.logger.Printf("Send loop terminated: err=%v", err) |
| } |
| c.conn.Close() // causes recv loop to EOF/exit |
| wg.Done() |
| }() |
| |
| wg.Add(1) |
| go func() { |
| var err error |
| if c.debugCloseRecvLoop { |
| err = errors.New("DEBUG: close recv loop") |
| } else { |
| err = c.recvLoop(c.conn) |
| } |
| if err != io.EOF || c.logInfo { |
| c.logger.Printf("Recv loop terminated: err=%v", err) |
| } |
| if err == nil { |
| panic("zk: recvLoop should never return nil error") |
| } |
| close(c.closeChan) // tell send loop to exit |
| wg.Done() |
| }() |
| |
| c.resendZkAuth(reauthChan) |
| |
| c.sendSetWatches() |
| wg.Wait() |
| } |
| |
| c.setState(StateDisconnected) |
| |
| select { |
| case <-c.shouldQuit: |
| c.flushRequests(ErrClosing) |
| return |
| default: |
| } |
| |
| if err != ErrSessionExpired { |
| err = ErrConnectionClosed |
| } |
| c.flushRequests(err) |
| |
| if c.reconnectLatch != nil { |
| select { |
| case <-c.shouldQuit: |
| return |
| case <-c.reconnectLatch: |
| } |
| } |
| } |
| } |
| |
| func (c *Conn) flushUnsentRequests(err error) { |
| for { |
| select { |
| default: |
| return |
| case req := <-c.sendChan: |
| req.recvChan <- response{-1, err} |
| } |
| } |
| } |
| |
| // Send error to all pending requests and clear request map |
| func (c *Conn) flushRequests(err error) { |
| c.requestsLock.Lock() |
| for _, req := range c.requests { |
| req.recvChan <- response{-1, err} |
| } |
| c.requests = make(map[int32]*request) |
| c.requestsLock.Unlock() |
| } |
| |
| // Send error to all watchers and clear watchers map |
| func (c *Conn) invalidateWatches(err error) { |
| c.watchersLock.Lock() |
| defer c.watchersLock.Unlock() |
| |
| if len(c.watchers) > 0 {
| for pathType, watchers := range c.watchers { |
| ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} |
| for _, ch := range watchers { |
| ch <- ev |
| close(ch) |
| } |
| } |
| c.watchers = make(map[watchPathType][]chan Event) |
| } |
| } |
| |
| func (c *Conn) sendSetWatches() { |
| c.watchersLock.Lock() |
| defer c.watchersLock.Unlock() |
| |
| if len(c.watchers) == 0 { |
| return |
| } |
| |
| // NB: A ZK server, by default, rejects packets >1mb. So, if we have too
| // many watches to reset, we need to break this up into multiple packets
| // to avoid hitting that limit. Mirroring the Java client behavior, we are
| // conservative in that we limit requests to 128kb (since the server limit
| // is actually configurable and could conceivably be configured smaller
| // than the default of 1mb).
| limit := 128 * 1024 |
| if c.setWatchLimit > 0 { |
| limit = c.setWatchLimit |
| } |
| |
| var reqs []*setWatchesRequest |
| var req *setWatchesRequest |
| var sizeSoFar int |
| |
| n := 0 |
| for pathType, watchers := range c.watchers { |
| if len(watchers) == 0 { |
| continue |
| } |
| addlLen := 4 + len(pathType.path) |
| if req == nil || sizeSoFar+addlLen > limit { |
| if req != nil { |
| // add to set of requests that we'll send |
| reqs = append(reqs, req) |
| } |
| sizeSoFar = 28 // fixed overhead of a set-watches packet |
| req = &setWatchesRequest{ |
| RelativeZxid: c.lastZxid, |
| DataWatches: make([]string, 0), |
| ExistWatches: make([]string, 0), |
| ChildWatches: make([]string, 0), |
| } |
| } |
| sizeSoFar += addlLen |
| switch pathType.wType { |
| case watchTypeData: |
| req.DataWatches = append(req.DataWatches, pathType.path) |
| case watchTypeExist: |
| req.ExistWatches = append(req.ExistWatches, pathType.path) |
| case watchTypeChild: |
| req.ChildWatches = append(req.ChildWatches, pathType.path) |
| } |
| n++ |
| } |
| if n == 0 { |
| return |
| } |
| if req != nil { // don't forget any trailing packet we were building |
| reqs = append(reqs, req) |
| } |
| |
| if c.setWatchCallback != nil { |
| c.setWatchCallback(reqs) |
| } |
| |
| go func() { |
| res := &setWatchesResponse{} |
| // TODO: Pipeline these so we queue all of them up before waiting on any
| // response. That will require some investigation to make sure there
| // aren't failure modes where a blocking write to the channel of requests
| // could hang indefinitely and cause this goroutine to leak...
| for _, req := range reqs { |
| _, err := c.request(opSetWatches, req, res, nil) |
| if err != nil { |
| c.logger.Printf("Failed to set previous watches: %s", err.Error()) |
| break |
| } |
| } |
| }() |
| } |
| |
| func (c *Conn) authenticate() error { |
| buf := make([]byte, 256) |
| |
| // Encode and send a connect request. |
| n, err := encodePacket(buf[4:], &connectRequest{ |
| ProtocolVersion: protocolVersion, |
| LastZxidSeen: c.lastZxid, |
| TimeOut: c.sessionTimeoutMs, |
| SessionID: c.SessionID(), |
| Passwd: c.passwd, |
| }) |
| if err != nil { |
| return err |
| } |
| |
| binary.BigEndian.PutUint32(buf[:4], uint32(n)) |
| |
| c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10)) |
| _, err = c.conn.Write(buf[:n+4]) |
| c.conn.SetWriteDeadline(time.Time{}) |
| if err != nil { |
| return err |
| } |
| |
| // Receive and decode a connect response. |
| c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10)) |
| _, err = io.ReadFull(c.conn, buf[:4]) |
| c.conn.SetReadDeadline(time.Time{}) |
| if err != nil { |
| return err |
| } |
| |
| blen := int(binary.BigEndian.Uint32(buf[:4])) |
| if cap(buf) < blen { |
| buf = make([]byte, blen) |
| } |
| |
| _, err = io.ReadFull(c.conn, buf[:blen]) |
| if err != nil { |
| return err |
| } |
| |
| r := connectResponse{} |
| _, err = decodePacket(buf[:blen], &r) |
| if err != nil { |
| return err |
| } |
| if r.SessionID == 0 { |
| atomic.StoreInt64(&c.sessionID, int64(0)) |
| c.passwd = emptyPassword |
| c.lastZxid = 0 |
| c.setState(StateExpired) |
| return ErrSessionExpired |
| } |
| |
| atomic.StoreInt64(&c.sessionID, r.SessionID) |
| c.setTimeouts(r.TimeOut) |
| c.passwd = r.Passwd |
| c.setState(StateHasSession) |
| |
| return nil |
| } |
| |
| func (c *Conn) sendData(req *request) error { |
| header := &requestHeader{req.xid, req.opcode} |
| n, err := encodePacket(c.buf[4:], header) |
| if err != nil { |
| req.recvChan <- response{-1, err} |
| return nil |
| } |
| |
| n2, err := encodePacket(c.buf[4+n:], req.pkt) |
| if err != nil { |
| req.recvChan <- response{-1, err} |
| return nil |
| } |
| |
| n += n2 |
| |
| binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) |
| |
| c.requestsLock.Lock() |
| select { |
| case <-c.closeChan: |
| req.recvChan <- response{-1, ErrConnectionClosed} |
| c.requestsLock.Unlock() |
| return ErrConnectionClosed |
| default: |
| } |
| c.requests[req.xid] = req |
| c.requestsLock.Unlock() |
| |
| c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) |
| _, err = c.conn.Write(c.buf[:n+4]) |
| c.conn.SetWriteDeadline(time.Time{}) |
| if err != nil { |
| req.recvChan <- response{-1, err} |
| c.conn.Close() |
| return err |
| } |
| |
| return nil |
| } |
| |
| func (c *Conn) sendLoop() error { |
| pingTicker := time.NewTicker(c.pingInterval) |
| defer pingTicker.Stop() |
| |
| for { |
| select { |
| case req := <-c.sendChan: |
| if err := c.sendData(req); err != nil { |
| return err |
| } |
| case <-pingTicker.C: |
| n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) |
| if err != nil { |
| panic("zk: opPing should never fail to serialize") |
| } |
| |
| binary.BigEndian.PutUint32(c.buf[:4], uint32(n)) |
| |
| c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) |
| _, err = c.conn.Write(c.buf[:n+4]) |
| c.conn.SetWriteDeadline(time.Time{}) |
| if err != nil { |
| c.conn.Close() |
| return err |
| } |
| case <-c.closeChan: |
| return nil |
| } |
| } |
| } |
| |
| func (c *Conn) recvLoop(conn net.Conn) error { |
| sz := bufferSize |
| if c.maxBufferSize > 0 && sz > c.maxBufferSize { |
| sz = c.maxBufferSize |
| } |
| buf := make([]byte, sz) |
| for { |
| // read the 4-byte packet length
| conn.SetReadDeadline(time.Now().Add(c.recvTimeout)) |
| _, err := io.ReadFull(conn, buf[:4]) |
| if err != nil { |
| return err |
| } |
| |
| blen := int(binary.BigEndian.Uint32(buf[:4])) |
| if cap(buf) < blen { |
| if c.maxBufferSize > 0 && blen > c.maxBufferSize { |
| return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize) |
| } |
| buf = make([]byte, blen) |
| } |
| |
| _, err = io.ReadFull(conn, buf[:blen]) |
| conn.SetReadDeadline(time.Time{}) |
| if err != nil { |
| return err |
| } |
| |
| res := responseHeader{} |
| _, err = decodePacket(buf[:16], &res) |
| if err != nil { |
| return err |
| } |
| |
| if res.Xid == -1 { |
| res := &watcherEvent{} |
| _, err := decodePacket(buf[16:blen], res) |
| if err != nil { |
| return err |
| } |
| ev := Event{ |
| Type: res.Type, |
| State: res.State, |
| Path: res.Path, |
| Err: nil, |
| } |
| c.sendEvent(ev) |
| wTypes := make([]watchType, 0, 2) |
| switch res.Type { |
| case EventNodeCreated: |
| wTypes = append(wTypes, watchTypeExist) |
| case EventNodeDeleted, EventNodeDataChanged: |
| wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild) |
| case EventNodeChildrenChanged: |
| wTypes = append(wTypes, watchTypeChild) |
| } |
| c.watchersLock.Lock() |
| for _, t := range wTypes { |
| wpt := watchPathType{res.Path, t} |
| if watchers := c.watchers[wpt]; len(watchers) > 0 {
| for _, ch := range watchers { |
| ch <- ev |
| close(ch) |
| } |
| delete(c.watchers, wpt) |
| } |
| } |
| c.watchersLock.Unlock() |
| } else if res.Xid == -2 { |
| // Ping response. Ignore. |
| } else if res.Xid < 0 { |
| c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) |
| } else { |
| if res.Zxid > 0 { |
| c.lastZxid = res.Zxid |
| } |
| |
| c.requestsLock.Lock() |
| req, ok := c.requests[res.Xid] |
| if ok { |
| delete(c.requests, res.Xid) |
| } |
| c.requestsLock.Unlock() |
| |
| if !ok { |
| c.logger.Printf("Response for unknown request with xid %d", res.Xid) |
| } else { |
| if res.Err != 0 { |
| err = res.Err.toError() |
| } else { |
| _, err = decodePacket(buf[16:blen], req.recvStruct) |
| } |
| if req.recvFunc != nil { |
| req.recvFunc(req, &res, err) |
| } |
| req.recvChan <- response{res.Zxid, err} |
| if req.opcode == opClose { |
| return io.EOF |
| } |
| } |
| } |
| } |
| } |
| |
| func (c *Conn) nextXid() int32 { |
| return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) |
| } |
| |
| func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { |
| c.watchersLock.Lock() |
| defer c.watchersLock.Unlock() |
| |
| ch := make(chan Event, 1) |
| wpt := watchPathType{path, watchType} |
| c.watchers[wpt] = append(c.watchers[wpt], ch) |
| return ch |
| } |
| |
| func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { |
| rq := &request{ |
| xid: c.nextXid(), |
| opcode: opcode, |
| pkt: req, |
| recvStruct: res, |
| recvChan: make(chan response, 1), |
| recvFunc: recvFunc, |
| } |
| c.sendChan <- rq |
| return rq.recvChan |
| } |
| |
| func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { |
| r := <-c.queueRequest(opcode, req, res, recvFunc) |
| return r.zxid, r.err |
| } |
| |
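| // AddAuth adds an authentication config to the connection. Credentials are
| // remembered and re-submitted after a reconnect. For example, with the digest
| // scheme (a sketch; the credentials are illustrative):
| //
| //	if err := conn.AddAuth("digest", []byte("user:secret")); err != nil {
| //		// handle error
| //	}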
| func (c *Conn) AddAuth(scheme string, auth []byte) error { |
| _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) |
| |
| if err != nil { |
| return err |
| } |
| |
| // Remember authdata so that it can be re-submitted on reconnect |
| // |
| // FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2"
| // as two different entries, which will both be re-submitted on reconnect. Some
| // research is needed on how ZK treats these cases and then we can maybe
| // switch to something like "map[username] = password" to allow only a single
| // password per user, with users being unique.
| obj := authCreds{ |
| scheme: scheme, |
| auth: auth, |
| } |
| |
| c.credsMu.Lock() |
| c.creds = append(c.creds, obj) |
| c.credsMu.Unlock() |
| |
| return nil |
| } |
| |
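| // Children returns the names of the children of the znode at the given path,
| // along with the znode's Stat.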
| func (c *Conn) Children(path string) ([]string, *Stat, error) { |
| if err := validatePath(path, false); err != nil { |
| return nil, nil, err |
| } |
| |
| res := &getChildren2Response{} |
| _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) |
| return res.Children, &res.Stat, err |
| } |
| |
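| // ChildrenW works like Children but also sets a child watch, returning a
| // channel that receives a single Event when the watch fires.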
| func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { |
| if err := validatePath(path, false); err != nil { |
| return nil, nil, nil, err |
| } |
| |
| var ech <-chan Event |
| res := &getChildren2Response{} |
| _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { |
| if err == nil { |
| ech = c.addWatcher(path, watchTypeChild) |
| } |
| }) |
| if err != nil { |
| return nil, nil, nil, err |
| } |
| return res.Children, &res.Stat, ech, err |
| } |
| |
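| // Get returns the data of the znode at the given path, along with its Stat.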
| func (c *Conn) Get(path string) ([]byte, *Stat, error) { |
| if err := validatePath(path, false); err != nil { |
| return nil, nil, err |
| } |
| |
| res := &getDataResponse{} |
| _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil) |
| return res.Data, &res.Stat, err |
| } |
| |
| // GetW returns the contents of a znode and sets a watch |
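| // that fires once when the data changes. Watches are one-shot: to keep
| // watching, call GetW again after the event arrives (a sketch; use is a
| // hypothetical consumer):
| //
| //	for {
| //		data, _, ch, err := conn.GetW("/path")
| //		if err != nil {
| //			return err
| //		}
| //		use(data)
| //		<-ch // blocks until the node changes (or the watch is invalidated)
| //	}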
| func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { |
| if err := validatePath(path, false); err != nil { |
| return nil, nil, nil, err |
| } |
| |
| var ech <-chan Event |
| res := &getDataResponse{} |
| _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { |
| if err == nil { |
| ech = c.addWatcher(path, watchTypeData) |
| } |
| }) |
| if err != nil { |
| return nil, nil, nil, err |
| } |
| return res.Data, &res.Stat, ech, err |
| } |
| |
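| // Set updates the data of the znode at the given path. The write fails unless
| // version matches the node's current version; a version of -1 matches any
| // version.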
| func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { |
| if err := validatePath(path, false); err != nil { |
| return nil, err |
| } |
| |
| res := &setDataResponse{} |
| _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil) |
| return &res.Stat, err |
| } |
| |
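| // Create creates a znode at the given path with the given data, flags
| // (e.g. FlagEphemeral, FlagSequence), and ACL, and returns the path of the
| // created node, which differs from the requested path when FlagSequence is
| // set. For example (a sketch; the path and data are illustrative):
| //
| //	created, err := conn.Create("/myapp/node", []byte("payload"), 0, zk.WorldACL(zk.PermAll))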
| func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) { |
| if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { |
| return "", err |
| } |
| |
| res := &createResponse{} |
| _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) |
| return res.Path, err |
| } |
| |
| // CreateProtectedEphemeralSequential fixes a race condition where the server
| // crashes after creating the node but before returning its name to the client.
| // On reconnect the session may still be valid, so the ephemeral node may still
| // exist. Therefore, on reconnect we check for an existing node whose name
| // embeds the GUID generated for this create.
| func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) { |
| if err := validatePath(path, true); err != nil { |
| return "", err |
| } |
| |
| var guid [16]byte |
| _, err := io.ReadFull(rand.Reader, guid[:16]) |
| if err != nil { |
| return "", err |
| } |
| guidStr := fmt.Sprintf("%x", guid) |
| |
| parts := strings.Split(path, "/") |
| parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1]) |
| rootPath := strings.Join(parts[:len(parts)-1], "/") |
| protectedPath := strings.Join(parts, "/") |
| |
| var newPath string |
| for i := 0; i < 3; i++ { |
| newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl) |
| switch err { |
| case ErrSessionExpired: |
| // No need to search for the node since it can't exist. Just try again. |
| case ErrConnectionClosed: |
| children, _, err := c.Children(rootPath) |
| if err != nil { |
| return "", err |
| } |
| for _, p := range children { |
| parts := strings.Split(p, "/") |
| if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) { |
| if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr { |
| return rootPath + "/" + p, nil |
| } |
| } |
| } |
| case nil: |
| return newPath, nil |
| default: |
| return "", err |
| } |
| } |
| return "", err |
| } |
| |
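| // Delete removes the znode at the given path. The delete fails unless version
| // matches the node's current version; a version of -1 matches any version.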
| func (c *Conn) Delete(path string, version int32) error { |
| if err := validatePath(path, false); err != nil { |
| return err |
| } |
| |
| _, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil) |
| return err |
| } |
| |
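| // Exists reports whether the znode at the given path exists, along with its
| // Stat.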
| func (c *Conn) Exists(path string) (bool, *Stat, error) { |
| if err := validatePath(path, false); err != nil { |
| return false, nil, err |
| } |
| |
| res := &existsResponse{} |
| _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil) |
| exists := true |
| if err == ErrNoNode { |
| exists = false |
| err = nil |
| } |
| return exists, &res.Stat, err |
| } |
| |
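| // ExistsW works like Exists but also sets a watch: a data watch if the node
| // exists, or an exist watch (fired on creation) if it does not.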
| func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { |
| if err := validatePath(path, false); err != nil { |
| return false, nil, nil, err |
| } |
| |
| var ech <-chan Event |
| res := &existsResponse{} |
| _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { |
| if err == nil { |
| ech = c.addWatcher(path, watchTypeData) |
| } else if err == ErrNoNode { |
| ech = c.addWatcher(path, watchTypeExist) |
| } |
| }) |
| exists := true |
| if err == ErrNoNode { |
| exists = false |
| err = nil |
| } |
| if err != nil { |
| return false, nil, nil, err |
| } |
| return exists, &res.Stat, ech, err |
| } |
| |
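| // GetACL returns the ACL and Stat of the znode at the given path.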
| func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { |
| if err := validatePath(path, false); err != nil { |
| return nil, nil, err |
| } |
| |
| res := &getAclResponse{} |
| _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil) |
| return res.Acl, &res.Stat, err |
| } |
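| 
| // SetACL sets the ACL of the znode at the given path. The update fails unless
| // version matches the node's current ACL version.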
| func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { |
| if err := validatePath(path, false); err != nil { |
| return nil, err |
| } |
| |
| res := &setAclResponse{} |
| _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil) |
| return &res.Stat, err |
| } |
| |
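| // Sync forces the connected server to sync the given path with the leader, so
| // that subsequent reads on this connection observe all prior updates. It
| // returns the path that was synced.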
| func (c *Conn) Sync(path string) (string, error) { |
| if err := validatePath(path, false); err != nil { |
| return "", err |
| } |
| |
| res := &syncResponse{} |
| _, err := c.request(opSync, &syncRequest{Path: path}, res, nil) |
| return res.Path, err |
| } |
| |
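| // MultiResponse is the result of a single operation within a Multi call.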
| type MultiResponse struct { |
| Stat *Stat |
| String string |
| Error error |
| } |
| |
| // Multi executes multiple ZooKeeper operations or none of them. The provided |
| // ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or |
| // *CheckVersionRequest. |
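| //
| // For example, to create one node and delete another atomically (a sketch;
| // the paths, data, and ACLs are illustrative):
| //
| //	ops := []interface{}{
| //		&zk.CreateRequest{Path: "/a", Data: []byte("x"), Acl: zk.WorldACL(zk.PermAll)},
| //		&zk.DeleteRequest{Path: "/b", Version: -1},
| //	}
| //	res, err := conn.Multi(ops...)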
| func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { |
| req := &multiRequest{ |
| Ops: make([]multiRequestOp, 0, len(ops)), |
| DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, |
| } |
| for _, op := range ops { |
| var opCode int32 |
| switch op.(type) { |
| case *CreateRequest: |
| opCode = opCreate |
| case *SetDataRequest: |
| opCode = opSetData |
| case *DeleteRequest: |
| opCode = opDelete |
| case *CheckVersionRequest: |
| opCode = opCheck |
| default: |
| return nil, fmt.Errorf("unknown operation type %T", op) |
| } |
| req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op}) |
| } |
| res := &multiResponse{} |
| _, err := c.request(opMulti, req, res, nil) |
| mr := make([]MultiResponse, len(res.Ops)) |
| for i, op := range res.Ops { |
| mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()} |
| } |
| return mr, err |
| } |
| |
| // Server returns the current or last-connected server name. |
| func (c *Conn) Server() string { |
| c.serverMu.Lock() |
| defer c.serverMu.Unlock() |
| return c.server |
| } |