| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to you under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package avatica |
| |
| import ( |
| "context" |
| "database/sql/driver" |
| "errors" |
| "math" |
| "sync" |
| "time" |
| |
| "github.com/apache/calcite-avatica-go/v5/message" |
| ) |
| |
| type stmt struct { |
| statementID uint32 |
| conn *conn |
| parameters []*message.AvaticaParameter |
| handle *message.StatementHandle |
| batchUpdates []*message.UpdateBatch |
| sync.Mutex |
| } |
| |
| // Close closes a statement |
| func (s *stmt) Close() error { |
| |
| if s.conn.connectionId == "" { |
| return driver.ErrBadConn |
| } |
| |
| if s.conn.config.batching { |
| _, err := s.conn.httpClient.post(context.Background(), &message.ExecuteBatchRequest{ |
| ConnectionId: s.conn.connectionId, |
| StatementId: s.statementID, |
| Updates: s.batchUpdates, |
| }) |
| if err != nil { |
| return s.conn.avaticaErrorToResponseErrorOrError(err) |
| } |
| } |
| |
| _, err := s.conn.httpClient.post(context.Background(), &message.CloseStatementRequest{ |
| ConnectionId: s.conn.connectionId, |
| StatementId: s.statementID, |
| }) |
| |
| if err != nil { |
| return s.conn.avaticaErrorToResponseErrorOrError(err) |
| } |
| |
| return nil |
| } |
| |
| // NumInput returns the number of placeholder parameters. |
| // |
| // If NumInput returns >= 0, the sql package will sanity check |
| // argument counts from callers and return errors to the caller |
| // before the statement's Exec or Query methods are called. |
| // |
| // NumInput may also return -1, if the driver doesn't know |
| // its number of placeholders. In that case, the sql package |
| // will not sanity check Exec or Query argument counts. |
| func (s *stmt) NumInput() int { |
| return len(s.parameters) |
| } |
| |
| // Exec executes a query that doesn't return rows, such |
| // as an INSERT or UPDATE. |
| func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { |
| list := driverValueToNamedValue(args) |
| return s.exec(context.Background(), list) |
| } |
| |
| func (s *stmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) { |
| |
| if s.conn.connectionId == "" { |
| return nil, driver.ErrBadConn |
| } |
| |
| values := s.parametersToTypedValues(args) |
| |
| if s.conn.config.batching { |
| s.Lock() |
| defer s.Unlock() |
| |
| s.batchUpdates = append(s.batchUpdates, &message.UpdateBatch{ |
| ParameterValues: values, |
| }) |
| return &result{ |
| affectedRows: -1, |
| }, nil |
| } |
| |
| msg := &message.ExecuteRequest{ |
| StatementHandle: s.handle, |
| ParameterValues: values, |
| FirstFrameMaxSize: s.conn.config.frameMaxSize, |
| HasParameterValues: true, |
| } |
| |
| if s.conn.config.frameMaxSize <= -1 { |
| msg.FirstFrameMaxSize = math.MaxInt32 |
| } else { |
| msg.FirstFrameMaxSize = s.conn.config.frameMaxSize |
| } |
| |
| res, err := s.conn.httpClient.post(ctx, msg) |
| |
| if err != nil { |
| return nil, s.conn.avaticaErrorToResponseErrorOrError(err) |
| } |
| |
| results := res.(*message.ExecuteResponse).Results |
| |
| if len(results) <= 0 { |
| return nil, errors.New("empty ResultSet in ExecuteResponse") |
| } |
| |
| // Currently there is only 1 ResultSet per response |
| changed := int64(results[0].UpdateCount) |
| |
| return &result{ |
| affectedRows: changed, |
| }, nil |
| } |
| |
| // Query executes a query that may return rows, such as a |
| // SELECT. |
| func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { |
| list := driverValueToNamedValue(args) |
| return s.query(context.Background(), list) |
| } |
| |
| func (s *stmt) query(ctx context.Context, args []namedValue) (driver.Rows, error) { |
| if s.conn.connectionId == "" { |
| return nil, driver.ErrBadConn |
| } |
| |
| msg := &message.ExecuteRequest{ |
| StatementHandle: s.handle, |
| ParameterValues: s.parametersToTypedValues(args), |
| FirstFrameMaxSize: s.conn.config.frameMaxSize, |
| HasParameterValues: true, |
| } |
| |
| if s.conn.config.frameMaxSize <= -1 { |
| msg.FirstFrameMaxSize = math.MaxInt32 |
| } else { |
| msg.FirstFrameMaxSize = s.conn.config.frameMaxSize |
| } |
| |
| res, err := s.conn.httpClient.post(ctx, msg) |
| |
| if err != nil { |
| return nil, s.conn.avaticaErrorToResponseErrorOrError(err) |
| } |
| |
| resultSet := res.(*message.ExecuteResponse).Results |
| |
| return newRows(s.conn, s.statementID, false, resultSet), nil |
| } |
| |
| func (s *stmt) parametersToTypedValues(vals []namedValue) []*message.TypedValue { |
| |
| var result []*message.TypedValue |
| |
| for i, val := range vals { |
| typed := message.TypedValue{} |
| if val.Value == nil { |
| typed.Null = true |
| typed.Type = message.Rep_NULL |
| } else { |
| |
| switch v := val.Value.(type) { |
| case int64: |
| typed.Type = message.Rep_LONG |
| typed.NumberValue = v |
| case float64: |
| typed.Type = message.Rep_DOUBLE |
| typed.DoubleValue = v |
| case bool: |
| typed.Type = message.Rep_BOOLEAN |
| typed.BoolValue = v |
| case []byte: |
| typed.Type = message.Rep_BYTE_STRING |
| typed.BytesValue = v |
| case string: |
| |
| if s.parameters[i].TypeName == "DECIMAL" { |
| typed.Type = message.Rep_BIG_DECIMAL |
| } else { |
| typed.Type = message.Rep_STRING |
| } |
| typed.StringValue = v |
| |
| case time.Time: |
| avaticaParameter := s.parameters[i] |
| |
| switch avaticaParameter.TypeName { |
| case "TIME", "UNSIGNED_TIME": |
| typed.Type = message.Rep_JAVA_SQL_TIME |
| |
| // Because a location can have multiple time zones due to daylight savings, |
| // we need to be explicit and get the offset |
| zone, offset := v.Zone() |
| |
| // Calculate milliseconds since 00:00:00.000 |
| base := time.Date(v.Year(), v.Month(), v.Day(), 0, 0, 0, 0, time.FixedZone(zone, offset)) |
| typed.NumberValue = v.Sub(base).Nanoseconds() / int64(time.Millisecond) |
| |
| case "DATE", "UNSIGNED_DATE": |
| typed.Type = message.Rep_JAVA_SQL_DATE |
| |
| // Because a location can have multiple time zones due to daylight savings, |
| // we need to be explicit and get the offset |
| zone, offset := v.Zone() |
| |
| // Calculate number of days since 1970/1/1 |
| base := time.Date(1970, 1, 1, 0, 0, 0, 0, time.FixedZone(zone, offset)) |
| typed.NumberValue = int64(v.Sub(base) / (24 * time.Hour)) |
| |
| case "TIMESTAMP", "UNSIGNED_TIMESTAMP": |
| typed.Type = message.Rep_JAVA_SQL_TIMESTAMP |
| |
| // Because a location can have multiple time zones due to daylight savings, |
| // we need to be explicit and get the offset |
| zone, offset := v.Zone() |
| |
| // Calculate number of milliseconds since 1970-01-01 00:00:00.000 |
| base := time.Date(1970, 1, 1, 0, 0, 0, 0, time.FixedZone(zone, offset)) |
| typed.NumberValue = v.Sub(base).Nanoseconds() / int64(time.Millisecond) |
| } |
| } |
| } |
| |
| result = append(result, &typed) |
| } |
| |
| return result |
| } |