| package fasthttp |
| |
| import ( |
| "errors" |
| "net" |
| "strconv" |
| "sync" |
| "sync/atomic" |
| "time" |
| ) |
| |
| // Dial dials the given TCP addr using tcp4. |
| // |
| // This function has the following additional features comparing to net.Dial: |
| // |
| // * It reduces load on DNS resolver by caching resolved TCP addressed |
| // for DefaultDNSCacheDuration. |
| // * It dials all the resolved TCP addresses in round-robin manner until |
| // connection is established. This may be useful if certain addresses |
| // are temporarily unreachable. |
| // * It returns ErrDialTimeout if connection cannot be established during |
| // DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout. |
| // |
| // This dialer is intended for custom code wrapping before passing |
| // to Client.Dial or HostClient.Dial. |
| // |
| // For instance, per-host counters and/or limits may be implemented |
| // by such wrappers. |
| // |
| // The addr passed to the function must contain port. Example addr values: |
| // |
| // * foobar.baz:443 |
| // * foo.bar:80 |
| // * aaa.com:8080 |
| func Dial(addr string) (net.Conn, error) { |
| return getDialer(DefaultDialTimeout, false)(addr) |
| } |
| |
| // DialTimeout dials the given TCP addr using tcp4 using the given timeout. |
| // |
| // This function has the following additional features comparing to net.Dial: |
| // |
| // * It reduces load on DNS resolver by caching resolved TCP addressed |
| // for DefaultDNSCacheDuration. |
| // * It dials all the resolved TCP addresses in round-robin manner until |
| // connection is established. This may be useful if certain addresses |
| // are temporarily unreachable. |
| // |
| // This dialer is intended for custom code wrapping before passing |
| // to Client.Dial or HostClient.Dial. |
| // |
| // For instance, per-host counters and/or limits may be implemented |
| // by such wrappers. |
| // |
| // The addr passed to the function must contain port. Example addr values: |
| // |
| // * foobar.baz:443 |
| // * foo.bar:80 |
| // * aaa.com:8080 |
| func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { |
| return getDialer(timeout, false)(addr) |
| } |
| |
| // DialDualStack dials the given TCP addr using both tcp4 and tcp6. |
| // |
| // This function has the following additional features comparing to net.Dial: |
| // |
| // * It reduces load on DNS resolver by caching resolved TCP addressed |
| // for DefaultDNSCacheDuration. |
| // * It dials all the resolved TCP addresses in round-robin manner until |
| // connection is established. This may be useful if certain addresses |
| // are temporarily unreachable. |
| // * It returns ErrDialTimeout if connection cannot be established during |
| // DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial |
| // timeout. |
| // |
| // This dialer is intended for custom code wrapping before passing |
| // to Client.Dial or HostClient.Dial. |
| // |
| // For instance, per-host counters and/or limits may be implemented |
| // by such wrappers. |
| // |
| // The addr passed to the function must contain port. Example addr values: |
| // |
| // * foobar.baz:443 |
| // * foo.bar:80 |
| // * aaa.com:8080 |
| func DialDualStack(addr string) (net.Conn, error) { |
| return getDialer(DefaultDialTimeout, true)(addr) |
| } |
| |
| // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6 |
| // using the given timeout. |
| // |
| // This function has the following additional features comparing to net.Dial: |
| // |
| // * It reduces load on DNS resolver by caching resolved TCP addressed |
| // for DefaultDNSCacheDuration. |
| // * It dials all the resolved TCP addresses in round-robin manner until |
| // connection is established. This may be useful if certain addresses |
| // are temporarily unreachable. |
| // |
| // This dialer is intended for custom code wrapping before passing |
| // to Client.Dial or HostClient.Dial. |
| // |
| // For instance, per-host counters and/or limits may be implemented |
| // by such wrappers. |
| // |
| // The addr passed to the function must contain port. Example addr values: |
| // |
| // * foobar.baz:443 |
| // * foo.bar:80 |
| // * aaa.com:8080 |
| func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) { |
| return getDialer(timeout, true)(addr) |
| } |
| |
| func getDialer(timeout time.Duration, dualStack bool) DialFunc { |
| if timeout <= 0 { |
| timeout = DefaultDialTimeout |
| } |
| timeoutRounded := int(timeout.Seconds()*10 + 9) |
| |
| m := dialMap |
| if dualStack { |
| m = dialDualStackMap |
| } |
| |
| dialMapLock.Lock() |
| d := m[timeoutRounded] |
| if d == nil { |
| dialer := dialerStd |
| if dualStack { |
| dialer = dialerDualStack |
| } |
| d = dialer.NewDial(timeout) |
| m[timeoutRounded] = d |
| } |
| dialMapLock.Unlock() |
| return d |
| } |
| |
var (
	// Shared dialers behind the package-level Dial* functions: dialerStd is
	// tcp4-only, dialerDualStack has DualStack enabled.
	dialerStd       = &tcpDialer{}
	dialerDualStack = &tcpDialer{DualStack: true}

	// Per-stack caches of dial functions used by getDialer, keyed by the
	// bucketed timeout. Both maps are guarded by dialMapLock.
	dialMap          = make(map[int]DialFunc)
	dialDualStackMap = make(map[int]DialFunc)
	dialMapLock      sync.Mutex
)
| |
// tcpDialer produces dial functions that share one cache of resolved
// addresses and one limit on concurrent dials.
type tcpDialer struct {
	// DualStack makes the dialer use the "tcp" network and keep IPv6
	// addresses; otherwise it dials "tcp4" and drops non-IPv4 addresses.
	DualStack bool

	// tcpAddrsLock guards tcpAddrsMap and the pending flag of its entries.
	tcpAddrsLock sync.Mutex
	// tcpAddrsMap caches resolved addresses keyed by the addr string passed
	// to the dial function.
	tcpAddrsMap map[string]*tcpAddrEntry

	// concurrencyCh is a counting semaphore: tryDial sends to it before
	// dialing and the dialing goroutine receives from it when done.
	concurrencyCh chan struct{}

	// once guards the lazy initialization performed by NewDial.
	once sync.Once
}
| |
// maxDialConcurrency is the capacity of each tcpDialer's concurrencyCh, i.e.
// the maximum number of net.DialTCP calls a single dialer keeps in flight.
const maxDialConcurrency = 1000
| |
| func (d *tcpDialer) NewDial(timeout time.Duration) DialFunc { |
| d.once.Do(func() { |
| d.concurrencyCh = make(chan struct{}, maxDialConcurrency) |
| d.tcpAddrsMap = make(map[string]*tcpAddrEntry) |
| go d.tcpAddrsClean() |
| }) |
| |
| return func(addr string) (net.Conn, error) { |
| addrs, idx, err := d.getTCPAddrs(addr) |
| if err != nil { |
| return nil, err |
| } |
| network := "tcp4" |
| if d.DualStack { |
| network = "tcp" |
| } |
| |
| var conn net.Conn |
| n := uint32(len(addrs)) |
| deadline := time.Now().Add(timeout) |
| for n > 0 { |
| conn, err = tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) |
| if err == nil { |
| return conn, nil |
| } |
| if err == ErrDialTimeout { |
| return nil, err |
| } |
| idx++ |
| n-- |
| } |
| return nil, err |
| } |
| } |
| |
| func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) { |
| timeout := -time.Since(deadline) |
| if timeout <= 0 { |
| return nil, ErrDialTimeout |
| } |
| |
| select { |
| case concurrencyCh <- struct{}{}: |
| default: |
| tc := acquireTimer(timeout) |
| isTimeout := false |
| select { |
| case concurrencyCh <- struct{}{}: |
| case <-tc.C: |
| isTimeout = true |
| } |
| releaseTimer(tc) |
| if isTimeout { |
| return nil, ErrDialTimeout |
| } |
| } |
| |
| timeout = -time.Since(deadline) |
| if timeout <= 0 { |
| <-concurrencyCh |
| return nil, ErrDialTimeout |
| } |
| |
| chv := dialResultChanPool.Get() |
| if chv == nil { |
| chv = make(chan dialResult, 1) |
| } |
| ch := chv.(chan dialResult) |
| go func() { |
| var dr dialResult |
| dr.conn, dr.err = net.DialTCP(network, nil, addr) |
| ch <- dr |
| <-concurrencyCh |
| }() |
| |
| var ( |
| conn net.Conn |
| err error |
| ) |
| |
| tc := acquireTimer(timeout) |
| select { |
| case dr := <-ch: |
| conn = dr.conn |
| err = dr.err |
| dialResultChanPool.Put(ch) |
| case <-tc.C: |
| err = ErrDialTimeout |
| } |
| releaseTimer(tc) |
| |
| return conn, err |
| } |
| |
// dialResultChanPool recycles the buffered (capacity 1) result channels that
// tryDial uses to receive the outcome of a dialing goroutine.
var dialResultChanPool sync.Pool
| |
// dialResult carries the outcome of one net.DialTCP call from the dialing
// goroutine back to tryDial.
type dialResult struct {
	conn net.Conn
	err  error
}
| |
// ErrDialTimeout is returned when TCP dialing is timed out. This covers both
// the deadline expiring while waiting for a free dial slot and the dial
// itself not finishing in time.
var ErrDialTimeout = errors.New("dialing to the given TCP address timed out")
| |
// DefaultDialTimeout is the timeout used by Dial and DialDualStack
// for establishing TCP connections. It also replaces non-positive timeouts
// passed to DialTimeout and DialDualStackTimeout.
const DefaultDialTimeout = 3 * time.Second
| |
// tcpAddrEntry is one cached DNS resolution result held in
// tcpDialer.tcpAddrsMap.
type tcpAddrEntry struct {
	// addrs holds the resolved addresses; it is set once when the entry is
	// created.
	addrs []net.TCPAddr
	// addrsIdx is incremented atomically on every lookup and provides the
	// rotating start offset into addrs.
	addrsIdx uint32

	// resolveTime is when addrs was resolved; it drives expiry and cleanup.
	resolveTime time.Time
	// pending is set (under tcpAddrsLock) while one caller re-resolves an
	// expired entry, so that other callers keep using the stale addresses
	// instead of resolving concurrently.
	pending bool
}
| |
// DefaultDNSCacheDuration is the duration for caching resolved TCP addresses
// by Dial* functions. Entries older than this are re-resolved on the next
// lookup; entries older than twice this duration are dropped by the
// background cleaner.
const DefaultDNSCacheDuration = time.Minute
| |
| func (d *tcpDialer) tcpAddrsClean() { |
| expireDuration := 2 * DefaultDNSCacheDuration |
| for { |
| time.Sleep(time.Second) |
| t := time.Now() |
| |
| d.tcpAddrsLock.Lock() |
| for k, e := range d.tcpAddrsMap { |
| if t.Sub(e.resolveTime) > expireDuration { |
| delete(d.tcpAddrsMap, k) |
| } |
| } |
| d.tcpAddrsLock.Unlock() |
| } |
| } |
| |
| func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) { |
| d.tcpAddrsLock.Lock() |
| e := d.tcpAddrsMap[addr] |
| if e != nil && !e.pending && time.Since(e.resolveTime) > DefaultDNSCacheDuration { |
| e.pending = true |
| e = nil |
| } |
| d.tcpAddrsLock.Unlock() |
| |
| if e == nil { |
| addrs, err := resolveTCPAddrs(addr, d.DualStack) |
| if err != nil { |
| d.tcpAddrsLock.Lock() |
| e = d.tcpAddrsMap[addr] |
| if e != nil && e.pending { |
| e.pending = false |
| } |
| d.tcpAddrsLock.Unlock() |
| return nil, 0, err |
| } |
| |
| e = &tcpAddrEntry{ |
| addrs: addrs, |
| resolveTime: time.Now(), |
| } |
| |
| d.tcpAddrsLock.Lock() |
| d.tcpAddrsMap[addr] = e |
| d.tcpAddrsLock.Unlock() |
| } |
| |
| idx := atomic.AddUint32(&e.addrsIdx, 1) |
| return e.addrs, idx, nil |
| } |
| |
// resolveTCPAddrs splits addr into host and port, looks up the host's IP
// addresses and returns one net.TCPAddr per usable IP.
//
// The port must be numeric and within 0..65535. Out-of-range ports are
// rejected here with a *net.AddrError, before any DNS lookup, instead of
// surfacing later as an opaque dial failure. When dualStack is false,
// addresses that are not IPv4 are skipped.
//
// It returns errNoDNSEntries if no usable address remains, and otherwise
// passes through the error from net.SplitHostPort, strconv.Atoi or
// net.LookupIP.
func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) {
	host, portS, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	port, err := strconv.Atoi(portS)
	if err != nil {
		return nil, err
	}
	if port < 0 || port > 65535 {
		return nil, &net.AddrError{Err: "invalid port", Addr: portS}
	}

	ips, err := net.LookupIP(host)
	if err != nil {
		return nil, err
	}

	addrs := make([]net.TCPAddr, 0, len(ips))
	for _, ip := range ips {
		if !dualStack && ip.To4() == nil {
			continue
		}
		addrs = append(addrs, net.TCPAddr{
			IP:   ip,
			Port: port,
		})
	}
	if len(addrs) == 0 {
		return nil, errNoDNSEntries
	}
	return addrs, nil
}

// errNoDNSEntries is returned by resolveTCPAddrs when the lookup yields no
// address usable for the requested IP stack.
var errNoDNSEntries = errors.New("couldn't find DNS entries for the given domain. Try using DialDualStack")