| package plugin |
| |
| import ( |
| "encoding/binary" |
| "fmt" |
| "log" |
| "net" |
| "sync" |
| "sync/atomic" |
| "time" |
| |
| "github.com/hashicorp/yamux" |
| ) |
| |
| // MuxBroker is responsible for brokering multiplexed connections by unique ID. |
| // |
| // It is used by plugins to multiplex multiple RPC connections and data |
| // streams on top of a single connection between the plugin process and the |
| // host process. |
| // |
| // This allows a plugin to request a channel with a specific ID to connect to |
| // or accept a connection from, and the broker handles the details of |
| // holding these channels open while they're being negotiated. |
| // |
| // The Plugin interface has access to these for both Server and Client. |
| // The broker can be used by either (optionally) to reserve and connect to |
| // new multiplexed streams. This is useful for complex args and return values, |
| // or anything else you might need a data stream for. |
| type MuxBroker struct { |
| nextId uint32 |
| session *yamux.Session |
| streams map[uint32]*muxBrokerPending |
| |
| sync.Mutex |
| } |
| |
// muxBrokerPending is the rendezvous state for one stream ID while the
// incoming stream is waiting to be picked up by Accept.
type muxBrokerPending struct {
	// ch carries the accepted stream from Run to Accept. getStream
	// creates it with a buffer of one.
	ch chan net.Conn

	// doneCh is closed by Accept once it has taken the stream, which
	// releases the timeoutWait goroutine watching this entry.
	doneCh chan struct{}
}
| |
| func newMuxBroker(s *yamux.Session) *MuxBroker { |
| return &MuxBroker{ |
| session: s, |
| streams: make(map[uint32]*muxBrokerPending), |
| } |
| } |
| |
| // Accept accepts a connection by ID. |
| // |
| // This should not be called multiple times with the same ID at one time. |
| func (m *MuxBroker) Accept(id uint32) (net.Conn, error) { |
| var c net.Conn |
| p := m.getStream(id) |
| select { |
| case c = <-p.ch: |
| close(p.doneCh) |
| case <-time.After(5 * time.Second): |
| m.Lock() |
| defer m.Unlock() |
| delete(m.streams, id) |
| |
| return nil, fmt.Errorf("timeout waiting for accept") |
| } |
| |
| // Ack our connection |
| if err := binary.Write(c, binary.LittleEndian, id); err != nil { |
| c.Close() |
| return nil, err |
| } |
| |
| return c, nil |
| } |
| |
| // AcceptAndServe is used to accept a specific stream ID and immediately |
| // serve an RPC server on that stream ID. This is used to easily serve |
| // complex arguments. |
| // |
| // The served interface is always registered to the "Plugin" name. |
| func (m *MuxBroker) AcceptAndServe(id uint32, v interface{}) { |
| conn, err := m.Accept(id) |
| if err != nil { |
| log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err) |
| return |
| } |
| |
| serve(conn, "Plugin", v) |
| } |
| |
| // Close closes the connection and all sub-connections. |
| func (m *MuxBroker) Close() error { |
| return m.session.Close() |
| } |
| |
| // Dial opens a connection by ID. |
| func (m *MuxBroker) Dial(id uint32) (net.Conn, error) { |
| // Open the stream |
| stream, err := m.session.OpenStream() |
| if err != nil { |
| return nil, err |
| } |
| |
| // Write the stream ID onto the wire. |
| if err := binary.Write(stream, binary.LittleEndian, id); err != nil { |
| stream.Close() |
| return nil, err |
| } |
| |
| // Read the ack that we connected. Then we're off! |
| var ack uint32 |
| if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil { |
| stream.Close() |
| return nil, err |
| } |
| if ack != id { |
| stream.Close() |
| return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id) |
| } |
| |
| return stream, nil |
| } |
| |
| // NextId returns a unique ID to use next. |
| // |
| // It is possible for very long-running plugin hosts to wrap this value, |
| // though it would require a very large amount of RPC calls. In practice |
| // we've never seen it happen. |
| func (m *MuxBroker) NextId() uint32 { |
| return atomic.AddUint32(&m.nextId, 1) |
| } |
| |
| // Run starts the brokering and should be executed in a goroutine, since it |
| // blocks forever, or until the session closes. |
| // |
| // Uses of MuxBroker never need to call this. It is called internally by |
| // the plugin host/client. |
| func (m *MuxBroker) Run() { |
| for { |
| stream, err := m.session.AcceptStream() |
| if err != nil { |
| // Once we receive an error, just exit |
| break |
| } |
| |
| // Read the stream ID from the stream |
| var id uint32 |
| if err := binary.Read(stream, binary.LittleEndian, &id); err != nil { |
| stream.Close() |
| continue |
| } |
| |
| // Initialize the waiter |
| p := m.getStream(id) |
| select { |
| case p.ch <- stream: |
| default: |
| } |
| |
| // Wait for a timeout |
| go m.timeoutWait(id, p) |
| } |
| } |
| |
| func (m *MuxBroker) getStream(id uint32) *muxBrokerPending { |
| m.Lock() |
| defer m.Unlock() |
| |
| p, ok := m.streams[id] |
| if ok { |
| return p |
| } |
| |
| m.streams[id] = &muxBrokerPending{ |
| ch: make(chan net.Conn, 1), |
| doneCh: make(chan struct{}), |
| } |
| return m.streams[id] |
| } |
| |
| func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) { |
| // Wait for the stream to either be picked up and connected, or |
| // for a timeout. |
| timeout := false |
| select { |
| case <-p.doneCh: |
| case <-time.After(5 * time.Second): |
| timeout = true |
| } |
| |
| m.Lock() |
| defer m.Unlock() |
| |
| // Delete the stream so no one else can grab it |
| delete(m.streams, id) |
| |
| // If we timed out, then check if we have a channel in the buffer, |
| // and if so, close it. |
| if timeout { |
| select { |
| case s := <-p.ch: |
| s.Close() |
| } |
| } |
| } |