blob: fecdb9452b2b054f6fd7c5d33b2f4b4323d4fca5 [file] [log] [blame]
// +build go1.8
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package avatica
import (
"database/sql"
"testing"
"time"
"context"
"math"
"reflect"
)
// mustExecContext runs query with args through dbt.db.ExecContext under ctx
// and hands back the resulting sql.Result. Any error is reported through
// dbt.fail with the "exec" label and the offending query.
func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) {
	var err error
	if res, err = dbt.db.ExecContext(ctx, query, args...); err != nil {
		dbt.fail("exec", query, err)
	}
	return
}
// mustQueryContext runs query with args through dbt.db.QueryContext under ctx
// and hands back the resulting *sql.Rows. Any error is reported through
// dbt.fail with the "query" label and the offending query.
func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows) {
	var err error
	if rows, err = dbt.db.QueryContext(ctx, query, args...); err != nil {
		dbt.fail("query", query, err)
	}
	return
}
func getContext() context.Context {
ctx, _ := context.WithTimeout(context.Background(), 4*time.Minute)
return ctx
}
// TestContext exercises the context-aware database/sql entry points against a
// freshly created two-row table: ExecContext and QueryContext (through the
// must* helpers), BeginTx, Tx.PrepareContext, Stmt.ExecContext,
// DB.PrepareContext and Stmt.QueryRowContext.
//
// Failures that would leave a nil handle behind (transaction, statement or
// result) abort the test immediately instead of being recorded and then
// dereferenced.
func TestContext(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		// Create and seed table
		dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false")
		dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')")
		dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')")

		rows := dbt.mustQueryContext(getContext(), "SELECT COUNT(*) FROM "+dbt.tableName)
		defer rows.Close()
		for rows.Next() {
			var count int
			if err := rows.Scan(&count); err != nil {
				dbt.Fatal(err)
			}
			if count != 2 {
				dbt.Fatalf("There should be 2 rows, got %d", count)
			}
		}
		// Next returning false may hide an iteration error; surface it.
		if err := rows.Err(); err != nil {
			dbt.Fatal(err)
		}

		// Test transactions and prepared statements

		// A read-only transaction is expected to be rejected.
		readOnlyTx, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadUncommitted, ReadOnly: true})
		if err == nil {
			t.Error("Expected an error while creating a read only transaction, but no error was returned")
			// Best-effort cleanup of the unexpected transaction; the failure
			// has already been reported above.
			_ = readOnlyTx.Rollback()
		}

		tx, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadCommitted})
		if err != nil {
			t.Fatalf("Unexpected error while creating transaction: %s", err)
		}
		// Releases the transaction if the test aborts before Commit. After a
		// Commit this only returns sql.ErrTxDone, which is safe to ignore.
		defer tx.Rollback()

		stmt, err := tx.PrepareContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES(?,?)")
		if err != nil {
			t.Fatalf("Unexpected error while preparing statement: %s", err)
		}
		defer stmt.Close()

		res, err := stmt.ExecContext(getContext(), 3, "C")
		if err != nil {
			t.Fatalf("Unexpected error while executing statement: %s", err)
		}
		affected, err := res.RowsAffected()
		if err != nil {
			t.Errorf("Error getting affected rows: %s", err)
		}
		if affected != 1 {
			t.Errorf("Expected 1 affected row, got %d", affected)
		}
		if err = tx.Commit(); err != nil {
			t.Errorf("Error committing transaction: %s", err)
		}

		stmt2, err := dbt.db.PrepareContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = ?")
		if err != nil {
			t.Fatalf("Error preparing statement: %s", err)
		}
		defer stmt2.Close()

		var (
			queryID  int64
			queryVal string
		)
		// QueryRowContext never returns an error itself; any query failure is
		// delivered by Scan.
		if err = stmt2.QueryRowContext(getContext(), 3).Scan(&queryID, &queryVal); err != nil {
			t.Errorf("Error scanning results into variable: %s", err)
		}
		if queryID != 3 {
			t.Errorf("Expected scanned id to be %d, got %d", 3, queryID)
		}
		if queryVal != "C" {
			t.Errorf("Expected scanned string to be %s, got %s", "C", queryVal)
		}
	})
}
// TestPing checks that Ping on a database opened with the package-level dsn
// returns no error.
func TestPing(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		if pingErr := dbt.db.Ping(); pingErr != nil {
			t.Errorf("Expected ping to succeed, got error: %s", pingErr)
		}
	})
}
func TestInvalidPing(t *testing.T) {
runTests(t, "http://invalid-server:8765", func(dbt *DBTest) {
err := dbt.db.Ping()
if err == nil {
t.Error("Expected ping to fail, but did not get any error")
}
})
}
// TestMultipleResultSets selects a single row from a freshly seeded table,
// verifies its contents, and then checks that Rows.NextResultSet reports no
// further result set.
//
// A failed query aborts the test immediately: continuing would call Close and
// Next on a nil *sql.Rows.
func TestMultipleResultSets(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		// Create and seed table
		dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false")
		dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')")
		dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')")

		rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = 1")
		if err != nil {
			t.Fatalf("Unexpected error while executing query: %s", err)
		}
		defer rows.Close()

		for rows.Next() {
			var (
				id  int64
				val string
			)
			if err := rows.Scan(&id, &val); err != nil {
				t.Errorf("Error while scanning row into variables: %s", err)
			}
			if id != 1 {
				t.Errorf("Expected id to be %d, got %d", 1, id)
			}
			if val != "A" {
				t.Errorf("Expected value to be %s, got %s", "A", val)
			}
		}
		// Next returning false may hide an iteration error; surface it.
		if err := rows.Err(); err != nil {
			t.Errorf("Error while iterating over rows: %s", err)
		}

		if rows.NextResultSet() {
			t.Error("There should be no more result sets, but got another result set")
		}
	})
}
// TestColumnTypes creates a table with one column per supported SQL type and
// checks the metadata reported by Rows.Columns and Rows.ColumnTypes: database
// type name, decimal size, length, column name, nullability and scan type.
//
// No rows are inserted; only the result-set metadata of "SELECT *" is
// inspected. The test aborts if the query fails or if the number of reported
// column types differs from the expectation table, since either would make
// the per-column comparison dereference nil or index out of range.
func TestColumnTypes(t *testing.T) {
	runTests(t, dsn, func(dbt *DBTest) {
		// Create table
		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
int INTEGER PRIMARY KEY,
uint UNSIGNED_INT,
bint BIGINT,
ulong UNSIGNED_LONG,
tint TINYINT,
utint UNSIGNED_TINYINT,
sint SMALLINT,
usint UNSIGNED_SMALLINT,
flt FLOAT,
uflt UNSIGNED_FLOAT,
dbl DOUBLE,
udbl UNSIGNED_DOUBLE,
dec DECIMAL(10, 5),
dec2 DECIMAL,
bool BOOLEAN,
tm TIME,
dt DATE,
tmstmp TIMESTAMP,
utm UNSIGNED_TIME,
udt UNSIGNED_DATE,
utmstmp UNSIGNED_TIMESTAMP,
var VARCHAR(10),
ch CHAR(3),
bin BINARY(20),
varbin VARBINARY
) TRANSACTIONAL=false`)

		// Select
		rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName)
		if err != nil {
			t.Fatalf("Unexpected error while selecting from table: %s", err)
		}
		defer rows.Close()

		columnNames, err := rows.Columns()
		if err != nil {
			t.Errorf("Error getting column names: %s", err)
		}
		expectedColumnNames := []string{"INT", "UINT", "BINT", "ULONG", "TINT", "UTINT", "SINT", "USINT", "FLT", "UFLT", "DBL", "UDBL", "DEC", "DEC2", "BOOL", "TM", "DT", "TMSTMP", "UTM", "UDT", "UTMSTMP", "VAR", "CH", "BIN", "VARBIN"}
		if !reflect.DeepEqual(columnNames, expectedColumnNames) {
			t.Error("Column names does not match expected column names")
		}

		// expectedColumn is the metadata one column is expected to report.
		// Fields left out of a literal below are expected to be the zero
		// value (0 / false), e.g. no decimal size or no length.
		type expectedColumn struct {
			name             string
			databaseTypeName string
			precision        int64 // DecimalSize: precision
			scale            int64 // DecimalSize: scale
			decimalOK        bool  // DecimalSize: ok
			length           int64 // Length: length
			lengthOK         bool  // Length: ok
			nullable         bool  // Nullable: nullable
			nullableOK       bool  // Nullable: ok
			scanType         reflect.Type
		}

		var (
			intType    = reflect.TypeOf(int64(0))
			floatType  = reflect.TypeOf(float64(0))
			stringType = reflect.TypeOf("")
			boolType   = reflect.TypeOf(bool(false))
			timeType   = reflect.TypeOf(time.Time{})
			bytesType  = reflect.TypeOf([]byte{})
		)

		// One entry per column, in table declaration order.
		expectedColumnTypes := []expectedColumn{
			{name: "INT", databaseTypeName: "INTEGER", nullable: false, nullableOK: true, scanType: intType},
			{name: "UINT", databaseTypeName: "UNSIGNED_INT", nullable: true, nullableOK: true, scanType: intType},
			{name: "BINT", databaseTypeName: "BIGINT", nullable: true, nullableOK: true, scanType: intType},
			{name: "ULONG", databaseTypeName: "UNSIGNED_LONG", nullable: true, nullableOK: true, scanType: intType},
			{name: "TINT", databaseTypeName: "TINYINT", nullable: true, nullableOK: true, scanType: intType},
			{name: "UTINT", databaseTypeName: "UNSIGNED_TINYINT", nullable: true, nullableOK: true, scanType: intType},
			{name: "SINT", databaseTypeName: "SMALLINT", nullable: true, nullableOK: true, scanType: intType},
			{name: "USINT", databaseTypeName: "UNSIGNED_SMALLINT", nullable: true, nullableOK: true, scanType: intType},
			{name: "FLT", databaseTypeName: "FLOAT", nullable: true, nullableOK: true, scanType: floatType},
			{name: "UFLT", databaseTypeName: "UNSIGNED_FLOAT", nullable: true, nullableOK: true, scanType: floatType},
			{name: "DBL", databaseTypeName: "DOUBLE", nullable: true, nullableOK: true, scanType: floatType},
			{name: "UDBL", databaseTypeName: "UNSIGNED_DOUBLE", nullable: true, nullableOK: true, scanType: floatType},
			{name: "DEC", databaseTypeName: "DECIMAL", precision: 10, scale: 5, decimalOK: true, nullable: true, nullableOK: true, scanType: stringType},
			{name: "DEC2", databaseTypeName: "DECIMAL", precision: math.MaxInt64, scale: math.MaxInt64, decimalOK: true, nullable: true, nullableOK: true, scanType: stringType},
			{name: "BOOL", databaseTypeName: "BOOLEAN", nullable: true, nullableOK: true, scanType: boolType},
			{name: "TM", databaseTypeName: "TIME", nullable: true, nullableOK: true, scanType: timeType},
			{name: "DT", databaseTypeName: "DATE", nullable: true, nullableOK: true, scanType: timeType},
			{name: "TMSTMP", databaseTypeName: "TIMESTAMP", nullable: true, nullableOK: true, scanType: timeType},
			{name: "UTM", databaseTypeName: "UNSIGNED_TIME", nullable: true, nullableOK: true, scanType: timeType},
			{name: "UDT", databaseTypeName: "UNSIGNED_DATE", nullable: true, nullableOK: true, scanType: timeType},
			{name: "UTMSTMP", databaseTypeName: "UNSIGNED_TIMESTAMP", nullable: true, nullableOK: true, scanType: timeType},
			{name: "VAR", databaseTypeName: "VARCHAR", length: 10, lengthOK: true, nullable: true, nullableOK: true, scanType: stringType},
			{name: "CH", databaseTypeName: "CHAR", length: 3, lengthOK: true, nullable: true, nullableOK: true, scanType: stringType},
			{name: "BIN", databaseTypeName: "BINARY", length: 20, lengthOK: true, nullable: true, nullableOK: true, scanType: bytesType},
			{name: "VARBIN", databaseTypeName: "VARBINARY", length: math.MaxInt64, lengthOK: true, nullable: true, nullableOK: true, scanType: bytesType},
		}

		columnTypes, err := rows.ColumnTypes()
		if err != nil {
			t.Fatalf("Error getting column types: %s", err)
		}
		// Guard the indexed lookup below and catch missing columns, which
		// the loop alone would silently skip.
		if len(columnTypes) != len(expectedColumnTypes) {
			t.Fatalf("Expected %d column types, got %d", len(expectedColumnTypes), len(columnTypes))
		}

		for index, columnType := range columnTypes {
			expected := expectedColumnTypes[index]

			if got := columnType.DatabaseTypeName(); got != expected.databaseTypeName {
				t.Errorf("Expected database type name for index %d to be %s, got %s", index, expected.databaseTypeName, got)
			}

			precision, scale, decimalOK := columnType.DecimalSize()
			if precision != expected.precision {
				t.Errorf("Expected decimal precision for index %d to be %d, got %d", index, expected.precision, precision)
			}
			if scale != expected.scale {
				t.Errorf("Expected decimal scale for index %d to be %d, got %d", index, expected.scale, scale)
			}
			if decimalOK != expected.decimalOK {
				t.Errorf("Expected decimal ok for index %d to be %t, got %t", index, expected.decimalOK, decimalOK)
			}

			length, lengthOK := columnType.Length()
			if length != expected.length {
				t.Errorf("Expected length for index %d to be %d, got %d", index, expected.length, length)
			}
			if lengthOK != expected.lengthOK {
				t.Errorf("Expected length ok for index %d to be %t, got %t", index, expected.lengthOK, lengthOK)
			}

			if got := columnType.Name(); got != expected.name {
				t.Errorf("Expected column name for index %d to be %s, got %s", index, expected.name, got)
			}

			nullable, nullableOK := columnType.Nullable()
			if nullable != expected.nullable {
				t.Errorf("Expected nullable for index %d to be %t, got %t", index, expected.nullable, nullable)
			}
			if nullableOK != expected.nullableOK {
				t.Errorf("Expected nullable ok for index %d to be %t, got %t", index, expected.nullableOK, nullableOK)
			}

			if got := columnType.ScanType(); got != expected.scanType {
				t.Errorf("Expected scan type for index %d to be %s, got %s", index, expected.scanType, got)
			}
		}
	})
}