blob: 882161e36c7b481e3e0c67625f729df4eccf7f4b [file]
/*
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
*/
package gremlingo
import (
"bytes"
"compress/zlib"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
)
// connectionSettings holds configuration for the connection's HTTP transport
// and default request headers. newConnection replaces any numeric or duration
// field left at zero with the matching default* constant.
type connectionSettings struct {
// tlsConfig is used as the transport's TLS client configuration.
tlsConfig *tls.Config
// connectionTimeout bounds connection setup; newConnection uses it as the
// net.Dialer timeout.
connectionTimeout time.Duration
// maxConnsPerHost caps the total connections opened per host.
maxConnsPerHost int
// maxIdleConnsPerHost caps the idle keep-alive connections retained per host.
maxIdleConnsPerHost int
// idleConnTimeout is how long an idle connection is kept before it is closed.
idleConnTimeout time.Duration
// keepAliveInterval is the TCP keep-alive period for new connections; a
// negative value is passed through to net.Dialer, which disables probes.
keepAliveInterval time.Duration
// enableCompression requests deflate-encoded responses via Accept-Encoding;
// when false the transport's DisableCompression is also set.
enableCompression bool
// enableUserAgentOnConnect adds the driver's User-Agent header to every
// request (despite the name, it is applied per request).
enableUserAgentOnConnect bool
}
// connection handles HTTP request/response for Gremlin queries.
// This is the transport layer for communicating with a Gremlin server.
// It embeds a sync.WaitGroup, so it must only be used through a pointer.
type connection struct {
// url is the endpoint every request is POSTed to.
url string
// httpClient is shared by all requests; as built by newConnection it has no
// overall Timeout so responses can be streamed.
httpClient *http.Client
// connSettings controls the optional User-Agent and Accept-Encoding headers.
connSettings *connectionSettings
// logHandler receives error and debug messages from request processing.
logHandler *logHandler
// serializer encodes the request body when it is still a *RequestMessage
// after all interceptors have run.
serializer *GraphBinarySerializer
// interceptors run in order on every request; they are read without a lock
// from the request goroutines.
interceptors []RequestInterceptor
// wg counts the in-flight goroutines started by submit so close can wait.
wg sync.WaitGroup
}
// Connection pool defaults, aligned with the Java driver where noted.
// newConnection uses each one when the matching connectionSettings field is
// left at zero.
const (
	// defaultMaxConnsPerHost mirrors Java's ConnectionPool.MAX_POOL_SIZE.
	defaultMaxConnsPerHost = 128
	// defaultMaxIdleConnsPerHost keeps a few connections warm between requests.
	defaultMaxIdleConnsPerHost = 8
	// defaultIdleConnTimeout mirrors Java's CONNECTION_IDLE_TIMEOUT_MILLIS.
	defaultIdleConnTimeout = 3 * time.Minute
	// defaultConnectionTimeout mirrors Java's CONNECTION_SETUP_TIMEOUT_MILLIS.
	defaultConnectionTimeout = 15 * time.Second
	// defaultKeepAliveInterval is the TCP keep-alive probe interval.
	defaultKeepAliveInterval = 30 * time.Second
)
// newConnection creates a connection that sends Gremlin requests to url over a
// dedicated HTTP transport.
//
// Any numeric or duration setting left at zero is replaced by the matching
// default* constant, and a nil connSettings is treated as all-zero settings.
// The http.Client deliberately has no overall Timeout so long responses can be
// streamed. connectionTimeout therefore bounds connection setup instead: both
// the TCP dial and, for TLS endpoints, the handshake. A custom http.Transport
// has no handshake timeout unless TLSHandshakeTimeout is set, so without it a
// stalled handshake would block the request goroutine indefinitely.
func newConnection(handler *logHandler, url string, connSettings *connectionSettings) *connection {
	if connSettings == nil {
		connSettings = &connectionSettings{}
	}

	// Apply defaults for zero values.
	connectionTimeout := connSettings.connectionTimeout
	if connectionTimeout == 0 {
		connectionTimeout = defaultConnectionTimeout
	}
	maxConnsPerHost := connSettings.maxConnsPerHost
	if maxConnsPerHost == 0 {
		maxConnsPerHost = defaultMaxConnsPerHost
	}
	maxIdleConnsPerHost := connSettings.maxIdleConnsPerHost
	if maxIdleConnsPerHost == 0 {
		maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
	}
	idleConnTimeout := connSettings.idleConnTimeout
	if idleConnTimeout == 0 {
		idleConnTimeout = defaultIdleConnTimeout
	}
	// Only zero selects the default: a negative interval is passed through,
	// which net.Dialer treats as "keep-alive probes disabled".
	keepAliveInterval := connSettings.keepAliveInterval
	if keepAliveInterval == 0 {
		keepAliveInterval = defaultKeepAliveInterval
	}

	transport := &http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   connectionTimeout,
			KeepAlive: keepAliveInterval,
		}).DialContext,
		TLSClientConfig:     connSettings.tlsConfig,
		TLSHandshakeTimeout: connectionTimeout,
		MaxConnsPerHost:     maxConnsPerHost,
		MaxIdleConnsPerHost: maxIdleConnsPerHost,
		IdleConnTimeout:     idleConnTimeout,
		DisableCompression:  !connSettings.enableCompression,
	}

	return &connection{
		url:          url,
		httpClient:   &http.Client{Transport: transport}, // No Timeout - allows streaming
		connSettings: connSettings,
		logHandler:   handler,
		serializer:   newGraphBinarySerializer(handler),
	}
}
// AddInterceptor appends interceptor to the chain that runs, in insertion
// order, on every subsequent request before it is sent.
//
// A nil interceptor is ignored. Calling it would panic on the background
// request goroutine, where nothing recovers, and take down the process.
//
// The chain is read without locking by in-flight requests, so interceptors
// should be registered before requests are submitted.
func (c *connection) AddInterceptor(interceptor RequestInterceptor) {
	if interceptor == nil {
		return
	}
	c.interceptors = append(c.interceptors, interceptor)
}
// submit starts executing req in the background and immediately hands back
// the ResultSet its results will be streamed into. The returned error is
// always nil; failures are reported through the ResultSet instead. The
// background goroutine is tracked by c.wg so close can wait for it.
func (c *connection) submit(req *RequestMessage) (ResultSet, error) {
	results := newChannelResultSet()
	c.wg.Add(1)
	go func(target ResultSet) {
		defer c.wg.Done()
		c.executeAndStream(req, target)
	}(results)
	return results, nil
}
// executeAndStream sends req to the Gremlin server and streams the decoded
// response into rs. It runs on the goroutine started by submit. Every failure
// is logged and recorded on rs with setError, and rs is always closed before
// returning.
func (c *connection) executeAndStream(req *RequestMessage, rs ResultSet) {
	defer rs.Close()

	httpGoReq, err := c.buildHTTPRequest(req)
	if err != nil {
		c.logHandler.logf(Error, failedToSendRequest, err.Error())
		rs.setError(err)
		return
	}

	resp, err := c.httpClient.Do(httpGoReq)
	if err != nil {
		c.logHandler.logf(Error, failedToSendRequest, err.Error())
		rs.setError(err)
		return
	}
	defer func() {
		// Drain any unread bytes so the connection can be reused gracefully.
		// Without this, Go's HTTP client sends a TCP RST instead of FIN,
		// causing "Connection reset by peer" errors on the server.
		_, _ = io.Copy(io.Discard, resp.Body)
		if err := resp.Body.Close(); err != nil {
			c.logHandler.logf(Debug, failedToCloseResponseBody, err.Error())
		}
	}()

	// An error status without a GraphBinary body carries a text/JSON message.
	// Decoding it as GraphBinary would only produce cryptic errors.
	contentType := resp.Header.Get(HeaderContentType)
	if resp.StatusCode >= 400 && !strings.Contains(contentType, graphBinaryMimeType) {
		rs.setError(c.readErrorResponse(resp))
		return
	}

	reader, zlibReader, err := c.getReader(resp)
	if err != nil {
		c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
		rs.setError(err)
		return
	}
	if zlibReader != nil {
		defer func() {
			if err := zlibReader.Close(); err != nil {
				c.logHandler.logf(Debug, failedToCloseDecompReader, err.Error())
			}
		}()
	}

	c.streamToResultSet(reader, rs)
}

// buildHTTPRequest converts req into the *http.Request to send.
//
// Default headers are set first. Then each interceptor runs in order, with the
// raw *RequestMessage in Body so it can inspect or replace it. After the
// interceptors, Body may be one of:
//   - *RequestMessage: serialized here with c.serializer;
//   - []byte or io.Reader: used as the request body;
//   - *http.Request: sent as-is.
//
// Any other type is rejected.
func (c *connection) buildHTTPRequest(req *RequestMessage) (*http.Request, error) {
	httpReq, err := NewHttpRequest(http.MethodPost, c.url)
	if err != nil {
		return nil, err
	}
	// Default headers go on before the interceptors so they can see them.
	c.setHttpRequestHeaders(httpReq)
	httpReq.Body = req
	for _, interceptor := range c.interceptors {
		if err := interceptor(httpReq); err != nil {
			return nil, err
		}
	}

	// Serialize here only if no interceptor already did.
	if msg, ok := httpReq.Body.(*RequestMessage); ok {
		if c.serializer == nil {
			return nil, fmt.Errorf("request body was not serialized; either provide a serializer or add an interceptor that serializes the request")
		}
		data, err := c.serializer.SerializeMessage(msg)
		if err != nil {
			return nil, err
		}
		httpReq.Body = data
	}

	var body io.Reader
	switch b := httpReq.Body.(type) {
	case []byte:
		body = bytes.NewReader(b)
	case io.Reader:
		body = b
	case *http.Request:
		// A typed nil would make httpClient.Do panic on the request goroutine.
		if b == nil {
			return nil, fmt.Errorf("interceptor replaced the request body with a nil *http.Request")
		}
		return b, nil
	default:
		return nil, fmt.Errorf("unsupported body type after interceptors: %T", b)
	}

	httpGoReq, err := http.NewRequest(httpReq.Method, httpReq.URL.String(), body)
	if err != nil {
		return nil, err
	}
	httpGoReq.Header = httpReq.Headers
	return httpGoReq, nil
}

// readErrorResponse builds and logs the error for an HTTP error response whose
// body is not GraphBinary. It prefers a JSON "message" field and falls back to
// the raw body text.
//
// At most maxErrorBodyBytes are read, so a misbehaving server cannot make the
// driver buffer an unbounded payload. The caller's deferred drain discards
// whatever remains.
func (c *connection) readErrorResponse(resp *http.Response) error {
	const maxErrorBodyBytes = 1 << 20 // 1 MiB is ample for an error message

	bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
	if err != nil {
		c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
		return fmt.Errorf("Gremlin Server returned HTTP %d and failed to read body: %w",
			resp.StatusCode, err)
	}
	errorBody := string(bodyBytes)
	errorMsg := tryExtractJSONError(errorBody)
	if errorMsg == "" {
		errorMsg = fmt.Sprintf("Gremlin Server returned HTTP %d: %s", resp.StatusCode, errorBody)
	}
	c.logHandler.logf(Error, failedToReceiveResponse, errorMsg)
	return fmt.Errorf("%s", errorMsg)
}
// setHttpRequestHeaders fills in the default headers on req before the
// interceptors run. GraphBinary is declared for both the request body and the
// accepted response. The User-Agent and deflate Accept-Encoding headers are
// added only when enabled in the connection settings.
func (c *connection) setHttpRequestHeaders(req *HttpRequest) {
	settings := c.connSettings

	// Both directions of the exchange are GraphBinary.
	req.Headers.Set(HeaderContentType, graphBinaryMimeType)
	req.Headers.Set(HeaderAccept, graphBinaryMimeType)

	if settings.enableUserAgentOnConnect {
		req.Headers.Set(HeaderUserAgent, userAgent)
	}
	if settings.enableCompression {
		// getReader inflates responses that come back deflate-encoded.
		req.Headers.Set(HeaderAcceptEncoding, "deflate")
	}
}
// getReader returns the reader the response body should be decoded from.
//
// When the server marks the body as deflate-encoded, the body is wrapped in a
// zlib reader. That reader is also returned as the io.Closer the caller must
// close. Otherwise the raw body is returned with a nil closer.
//
// Content-coding names are case-insensitive (RFC 9110, Section 8.4.1), so the
// header value is trimmed and compared with EqualFold. A case-sensitive check
// would hand compressed bytes to the GraphBinary decoder.
func (c *connection) getReader(resp *http.Response) (io.Reader, io.Closer, error) {
	encoding := strings.TrimSpace(resp.Header.Get("Content-Encoding"))
	if !strings.EqualFold(encoding, "deflate") {
		return resp.Body, nil, nil
	}
	zr, err := zlib.NewReader(resp.Body)
	if err != nil {
		return nil, nil, err
	}
	return zr, zr, nil
}
// streamToResultSet decodes a GraphBinary response from reader and pushes each
// value onto rs.Channel(). Bulked values are wrapped in a *Traverser carrying
// their int64 bulk count.
//
// Decoding stops at the end-of-stream marker. Its trailing status becomes an
// error on rs when the code is neither 200 nor 0. An io.EOF from the
// deserializer ends the stream quietly; any other failure is logged and
// recorded on rs.
func (c *connection) streamToResultSet(reader io.Reader, rs ResultSet) {
	// fail logs err and records it on the result set.
	fail := func(err error) {
		c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
		rs.setError(err)
	}

	d := NewGraphBinaryDeserializer(reader)
	if err := d.ReadHeader(); err != nil {
		if err != io.EOF {
			fail(err)
		}
		return
	}

	for {
		obj, err := d.ReadFullyQualified()
		if err == io.EOF {
			return
		}
		if err != nil {
			fail(err)
			return
		}

		if marker, isMarker := obj.(Marker); isMarker && marker == EndOfStream() {
			code, msg, _, statusErr := d.ReadStatus()
			switch {
			case statusErr != nil:
				fail(statusErr)
			case code != 200 && code != 0:
				rs.setError(newError(err0502ResponseReadLoopError, msg, code))
			}
			return
		}

		var result *Result
		if d.IsBulked() {
			bulkObj, bulkErr := d.ReadFullyQualified()
			if bulkErr != nil {
				fail(bulkErr)
				return
			}
			bulk, isInt64 := bulkObj.(int64)
			if !isInt64 {
				c.logHandler.logf(Error, failedToReceiveResponse, "expected int64 bulk count")
				rs.setError(fmt.Errorf("expected int64 bulk count, got %T", bulkObj))
				return
			}
			result = &Result{&Traverser{Bulk: bulk, Value: obj}}
		} else {
			result = &Result{obj}
		}
		rs.Channel() <- result
	}
}
// tryExtractJSONError returns the string "message" field of a JSON object
// body. It returns "" when body is not a JSON object or has no string-valued
// "message". The server sometimes responds with such an object even when it
// cannot produce a GraphBinary response.
func tryExtractJSONError(body string) string {
	var fields map[string]interface{}
	if json.Unmarshal([]byte(body), &fields) != nil {
		return ""
	}
	// A missing key yields nil, which fails the assertion just like a
	// non-string value does, so both produce "".
	message, _ := fields["message"].(string)
	return message
}
// close blocks until every request started by submit has finished streaming.
// It then releases the transport's idle keep-alive connections.
//
// NOTE(review): sync.WaitGroup requires an Add from a zero counter to happen
// before Wait. Callers presumably stop calling submit before close; confirm.
func (c *connection) close() {
	defer c.httpClient.CloseIdleConnections()
	c.wg.Wait()
}