blob: b1b7ccfa2ad9c73fe90e979309ff7a289f443952 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package jdbc
import (
"context"
"database/sql"
"flag"
"fmt"
"log"
"testing"
"time"
"github.com/apache/beam/sdks/v2/go/pkg/beam"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/samza"
_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/spark"
"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
"github.com/apache/beam/sdks/v2/go/test/integration"
"github.com/docker/go-connections/nat"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"gopkg.in/retry.v1"
)
// expansionAddr holds the address of the expansion service labelled
// "schemaio". TestMain fills it in; while it is empty, checkFlags skips the
// tests.
var expansionAddr string

// maxRetryCount is the number of attempts setupTestContainer makes to start
// the Postgres container before failing the test.
const maxRetryCount = 5
// checkFlags skips the calling test unless an expansion service address for
// "schemaio" has been recorded in expansionAddr.
func checkFlags(t *testing.T) {
	if expansionAddr != "" {
		return
	}
	t.Skip("No Schema IO expansion address provided.")
}
// setupTestContainer starts a Postgres container for the JDBC tests, creates
// the "roles" table in database dbname, and returns the running container
// together with the host port mapped to the container's 5432/tcp.
//
// Starting the container is attempted up to maxRetryCount times with
// exponential backoff (1s initial delay, factor 2). Any setup failure fails the
// test via t.Fatalf. If a container was already created when that happens, it
// is terminated first so it is not leaked. On success the caller owns the
// returned container and is responsible for terminating it.
func setupTestContainer(t *testing.T, ctx context.Context, dbname, username, password string) (testcontainers.Container, int) {
	t.Helper()

	const port = "5432/tcp"
	env := map[string]string{
		"POSTGRES_PASSWORD": password,
		"POSTGRES_USER":     username,
		"POSTGRES_DB":       dbname,
	}
	// dbURL returns the connection URL for Postgres reachable on the given host
	// port. It serves both the readiness probe and the table setup below.
	dbURL := func(p nat.Port) string {
		return fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable", username, password, p.Port(), dbname)
	}
	req := testcontainers.GenericContainerRequest{
		ContainerRequest: testcontainers.ContainerRequest{
			Image:        "postgres",
			ExposedPorts: []string{port},
			Env:          env,
			// Wait until Postgres accepts SQL connections, giving up after 5s.
			WaitingFor: wait.ForSQL(nat.Port(port), "postgres", dbURL).Timeout(time.Second * 5),
		},
		Started: true,
	}

	strategy := retry.LimitCount(maxRetryCount,
		retry.Exponential{
			Initial: time.Second,
			Factor:  2,
		},
	)
	var container testcontainers.Container
	var err error
	for r := retry.Start(strategy, nil); r.Next(); {
		container, err = testcontainers.GenericContainer(ctx, req)
		if err == nil {
			break
		}
		// GenericContainer may hand back a created-but-unusable container
		// together with the error; remove it so failed attempts do not pile up.
		if container != nil {
			_ = container.Terminate(ctx) // Best effort: the start error is what gets reported.
			container = nil
		}
	}
	// All attempts have been used up without success.
	if err != nil {
		t.Fatalf("failed to start container with %v retries: %v", maxRetryCount, err)
	}

	// fail terminates the container and then fails the test. The caller only
	// defers Terminate after this function returns, so without this a setup
	// error below would leave the container running.
	fail := func(format string, args ...interface{}) {
		t.Helper()
		_ = container.Terminate(ctx) // Best effort: the setup error is what gets reported.
		t.Fatalf(format, args...)
	}

	mappedPort, err := container.MappedPort(ctx, nat.Port(port))
	if err != nil {
		fail("failed to get container external port: %v", err)
	}
	db, err := sql.Open("postgres", dbURL(mappedPort))
	if err != nil {
		fail("failed to establish database connection: %v", err)
	}
	defer db.Close()
	if _, err := db.ExecContext(ctx, "CREATE TABLE roles(role_id bigint PRIMARY KEY);"); err != nil {
		fail("can't create table, check command and access level: %v", err)
	}
	return container, mappedPort.Int()
}
// TestJDBCIO_BasicReadWrite tests basic read and write transform from JDBC.
func TestJDBCIO_BasicReadWrite(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
ctx := context.Background()
dbname := "postjdbc"
username := "newuser"
password := "password"
cont, port := setupTestContainer(t, ctx, dbname, username, password)
defer cont.Terminate(ctx)
tableName := "roles"
host := "localhost"
jdbcUrl := fmt.Sprintf("jdbc:postgresql://%s:%d/%s", host, port, dbname)
write := WritePipeline(expansionAddr, tableName, "org.postgresql.Driver", jdbcUrl, username, password)
ptest.RunAndValidate(t, write)
read := ReadPipeline(expansionAddr, tableName, "org.postgresql.Driver", jdbcUrl, username, password)
ptest.RunAndValidate(t, read)
}
// TestJDBCIO_PostgresReadWrite tests basic read and write transform from JDBC with postgres.
func TestJDBCIO_PostgresReadWrite(t *testing.T) {
integration.CheckFilters(t)
checkFlags(t)
dbname := "postjdbc"
username := "newuser"
password := "password"
ctx := context.Background()
cont, port := setupTestContainer(t, ctx, dbname, username, password)
defer cont.Terminate(ctx)
tableName := "roles"
host := "localhost"
jdbcUrl := fmt.Sprintf("jdbc:postgresql://%s:%d/%s", host, port, dbname)
write := WriteToPostgres(expansionAddr, tableName, jdbcUrl, username, password)
ptest.RunAndValidate(t, write)
read := ReadFromPostgres(expansionAddr, tableName, jdbcUrl, username, password)
ptest.RunAndValidate(t, read)
}
// TestMain initializes Beam and starts the integration expansion services.
// When a "schemaio" service is available, TestMain records its address in
// expansionAddr before running the tests. Otherwise it logs the lookup error,
// and the tests skip themselves through checkFlags.
func TestMain(m *testing.M) {
	flag.Parse()
	beam.Init()

	services := integration.NewExpansionServices()
	// Shut the expansion services down after the tests have finished.
	defer services.Shutdown()

	if addr, err := services.GetAddr("schemaio"); err == nil {
		expansionAddr = addr
	} else {
		log.Printf("skipping missing expansion service: %v", err)
	}

	// NOTE(review): MainRet's result is discarded here. Presumably it returns
	// the exit code rather than calling os.Exit, so the deferred Shutdown still
	// runs. Since Go 1.15 the test harness exits with m.Run's result when
	// TestMain returns. Confirm against the ptest docs.
	ptest.MainRet(m)
}