blob: afdffb98cf51135ebfcde7bc8b5a632e4bfc696c [file] [log] [blame]
/*
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"bytes"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/apache/incubator-devlake/errors"
"github.com/apache/incubator-devlake/impl/dalgorm"
"github.com/apache/incubator-devlake/plugins/core"
"github.com/apache/incubator-devlake/plugins/core/dal"
"github.com/lib/pq"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type Table struct {
name string
}
func (t *Table) TableName() string {
return t.name
}
func LoadData(c core.SubTaskContext) errors.Error {
var db dal.Dal
config := c.GetData().(*StarRocksConfig)
if config.SourceDsn != "" && config.SourceType != "" {
var o *gorm.DB
var err error
if config.SourceType == "mysql" {
o, err = gorm.Open(mysql.Open(config.SourceDsn))
if err != nil {
return errors.Convert(err)
}
} else if config.SourceType == "postgres" {
o, err = gorm.Open(postgres.Open(config.SourceDsn))
if err != nil {
return errors.Convert(err)
}
} else {
return errors.NotFound.New(fmt.Sprintf("unsupported source type %s", config.SourceType))
}
db = dalgorm.NewDalgorm(o)
sqlDB, err := o.DB()
if err != nil {
return errors.Convert(err)
}
defer sqlDB.Close()
} else {
db = c.GetDal()
}
var starrocksTables []string
if config.DomainLayer != "" {
starrocksTables = getTablesByDomainLayer(config.DomainLayer)
if starrocksTables == nil {
return errors.NotFound.New(fmt.Sprintf("no table found by domain layer: %s", config.DomainLayer))
}
} else {
tables := config.Tables
allTables, err := db.AllTables()
if err != nil {
return err
}
if len(tables) == 0 {
starrocksTables = allTables
} else {
for _, table := range allTables {
for _, r := range tables {
var ok bool
ok, err = errors.Convert01(regexp.Match(r, []byte(table)))
if err != nil {
return err
}
if ok {
starrocksTables = append(starrocksTables, table)
}
}
}
}
}
starrocks, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Host, config.Port, config.Database))
if err != nil {
return errors.Convert(err)
}
defer starrocks.Close()
for _, table := range starrocksTables {
starrocksTable := strings.TrimLeft(table, "_")
starrocksTmpTable := fmt.Sprintf("%s_tmp", starrocksTable)
var columnMap map[string]string
var orderBy string
var skip bool
columnMap, orderBy, skip, err = createTmpTable(starrocks, db, starrocksTable, starrocksTmpTable, table, c, config)
if skip {
c.GetLogger().Info(fmt.Sprintf("table %s is up to date, so skip it", table))
continue
}
if err != nil {
c.GetLogger().Error(err, "create table %s in starrocks error", table)
return errors.Convert(err)
}
if db.Dialect() == "postgres" {
err = db.Exec("begin transaction isolation level repeatable read")
if err != nil {
return errors.Convert(err)
}
} else if db.Dialect() == "mysql" {
err = db.Exec("set session transaction isolation level repeatable read")
if err != nil {
return errors.Convert(err)
}
err = errors.Convert(db.Exec("start transaction"))
if err != nil {
return errors.Convert(err)
}
} else {
return errors.NotFound.New(fmt.Sprintf("unsupported dialect %s", db.Dialect()))
}
err = errors.Convert(loadData(starrocks, c, starrocksTable, starrocksTmpTable, table, columnMap, db, config, orderBy))
if err != nil {
return errors.Convert(err)
}
err = errors.Convert(db.Exec("commit"))
if err != nil {
return errors.Convert(err)
}
}
return nil
}
func createTmpTable(starrocks *sql.DB, db dal.Dal, starrocksTable string, starrocksTmpTable string, table string, c core.SubTaskContext, config *StarRocksConfig) (map[string]string, string, bool, errors.Error) {
columnMetas, err := db.GetColumns(&Table{name: table}, nil)
updateColumn := config.UpdateColumn
columnMap := make(map[string]string)
if err != nil {
if strings.Contains(err.Error(), "cached plan must not change result type") {
c.GetLogger().Warn(err, "skip err: cached plan must not change result type")
columnMetas, err = db.GetColumns(&Table{name: table}, nil)
if err != nil {
return nil, "", false, errors.Convert(err)
}
} else {
return nil, "", false, errors.Convert(err)
}
}
var pks []string
var orders []string
var columns []string
var separator string
if db.Dialect() == "postgres" {
separator = "\""
} else if db.Dialect() == "mysql" {
separator = "`"
} else {
return nil, "", false, errors.NotFound.New(fmt.Sprintf("unsupported dialect %s", db.Dialect()))
}
firstcm := ""
firstcmName := ""
var rowsInStarRocks *sql.Rows
var rowsInPostgres *sql.Rows
defer func() {
if rowsInStarRocks != nil {
rowsInStarRocks.Close()
}
if rowsInPostgres != nil {
rowsInPostgres.Close()
}
}()
for _, cm := range columnMetas {
name := cm.Name()
if name == updateColumn {
// check update column to detect skip or not
rowsInPostgres, err = db.Cursor(
dal.From(table),
dal.Select(updateColumn),
dal.Limit(1),
dal.Orderby(fmt.Sprintf("%s desc", updateColumn)),
)
if err != nil {
return nil, "", false, err
}
var updatedFrom time.Time
if rowsInPostgres.Next() {
err = errors.Convert(rowsInPostgres.Scan(&updatedFrom))
if err != nil {
return nil, "", false, err
}
}
var starrocksErr error
rowsInStarRocks, starrocksErr = starrocks.Query(fmt.Sprintf("select %s from %s order by %s desc limit 1", updateColumn, starrocksTable, updateColumn))
if starrocksErr != nil {
if !strings.Contains(starrocksErr.Error(), "Unknown table") {
return nil, "", false, errors.Convert(starrocksErr)
}
} else {
var updatedTo time.Time
if rowsInStarRocks.Next() {
err = errors.Convert(rowsInStarRocks.Scan(&updatedTo))
if err != nil {
return nil, "", false, err
}
}
if updatedFrom.Equal(updatedTo) {
return nil, "", true, nil
}
}
}
columnDatatype, ok := cm.ColumnType()
if !ok {
return columnMap, "", false, errors.Default.New(fmt.Sprintf("Get [%s] ColumeType Failed", name))
}
dataType := getStarRocksDataType(columnDatatype)
columnMap[name] = dataType
column := fmt.Sprintf("`%s` %s", name, dataType)
columns = append(columns, column)
isPrimaryKey, ok := cm.PrimaryKey()
if isPrimaryKey && ok {
pks = append(pks, fmt.Sprintf("`%s`", name))
orders = append(orders, fmt.Sprintf("%s%s%s", separator, name, separator))
}
if firstcm == "" {
firstcm = fmt.Sprintf("`%s`", name)
firstcmName = fmt.Sprintf("%s%s%s", separator, name, separator)
}
}
if len(pks) == 0 {
pks = append(pks, firstcm)
}
orderBy := strings.Join(orders, ", ")
if config.OrderBy != nil {
if v, ok := config.OrderBy[table]; ok {
orderBy = v
}
}
if orderBy == "" {
orderBy = firstcmName
}
extra := fmt.Sprintf(`engine=olap distributed by hash(%s) properties("replication_num" = "1")`, strings.Join(pks, ", "))
if config.Extra != nil {
if v, ok := config.Extra[table]; ok {
extra = v
}
}
tableSql := fmt.Sprintf("drop table if exists %s; create table if not exists `%s` ( %s ) %s", starrocksTmpTable, starrocksTmpTable, strings.Join(columns, ","), extra)
c.GetLogger().Debug(tableSql)
_, err = errors.Convert01(starrocks.Exec(tableSql))
return columnMap, orderBy, false, err
}
func loadData(starrocks *sql.DB, c core.SubTaskContext, starrocksTable, starrocksTmpTable, table string, columnMap map[string]string, db dal.Dal, config *StarRocksConfig, orderBy string) error {
offset := 0
var err error
for {
var data []map[string]interface{}
// select data from db
err = func() error {
var rows *sql.Rows
rows, err = db.RawCursor(fmt.Sprintf("select * from %s order by %s limit %d offset %d", table, orderBy, config.BatchSize, offset))
if err != nil {
return err
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
return err
}
for rows.Next() {
row := make(map[string]interface{})
columns := make([]interface{}, len(cols))
columnPointers := make([]interface{}, len(cols))
for i := range columns {
dataType := columnMap[cols[i]]
if strings.HasPrefix(dataType, "array") {
var arr []string
columns[i] = &arr
columnPointers[i] = pq.Array(&arr)
} else {
columnPointers[i] = &columns[i]
}
}
err = rows.Scan(columnPointers...)
if err != nil {
return err
}
for i, colName := range cols {
row[colName] = columns[i]
}
data = append(data, row)
}
return nil
}()
if err != nil {
return err
}
if len(data) == 0 {
c.GetLogger().Warn(nil, "no data found in table %s already, limit: %d, offset: %d, so break", table, config.BatchSize, offset)
break
}
// insert data to tmp table
loadURL := fmt.Sprintf("http://%s:%d/api/%s/%s/_stream_load", config.BeHost, config.BePort, config.Database, starrocksTmpTable)
headers := map[string]string{
"format": "json",
"strip_outer_array": "true",
"Expect": "100-continue",
"ignore_json_size": "true",
"Connection": "close",
}
jsonData, err := json.Marshal(data)
if err != nil {
return err
}
client := http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req, err := http.NewRequest(http.MethodPut, loadURL, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
req.SetBasicAuth(config.User, config.Password)
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := client.Do(req)
if err != nil {
return err
}
if resp.StatusCode == 307 {
var location *url.URL
location, err = resp.Location()
if err != nil {
return err
}
req, err = http.NewRequest(http.MethodPut, location.String(), bytes.NewBuffer(jsonData))
if err != nil {
return err
}
req.SetBasicAuth(config.User, config.Password)
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err = client.Do(req)
}
if err != nil {
return err
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
var result map[string]interface{}
err = json.Unmarshal(b, &result)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
c.GetLogger().Error(nil, "[%s]: %s", resp.StatusCode, string(b))
}
if result["Status"] != "Success" {
c.GetLogger().Error(nil, "load %s failed: %s", table, string(b))
} else {
c.GetLogger().Debug("load %s success: %s, limit: %d, offset: %d", table, b, config.BatchSize, offset)
}
offset += len(data)
}
// drop old table
_, err = starrocks.Exec(fmt.Sprintf("drop table if exists %s", starrocksTable))
if err != nil {
return err
}
// rename tmp table to old table
_, err = starrocks.Exec(fmt.Sprintf("alter table %s rename %s", starrocksTmpTable, starrocksTable))
if err != nil {
return err
}
// check data count
rows, err := db.RawCursor(fmt.Sprintf("select count(*) from %s", table))
if err != nil {
return err
}
defer rows.Close()
var sourceCount int
for rows.Next() {
err = rows.Scan(&sourceCount)
if err != nil {
return err
}
}
rowsStarRocks, err := starrocks.Query(fmt.Sprintf("select count(*) from %s", starrocksTable))
if err != nil {
return err
}
defer rowsStarRocks.Close()
var starrocksCount int
for rowsStarRocks.Next() {
err = rowsStarRocks.Scan(&starrocksCount)
if err != nil {
return err
}
}
if sourceCount != starrocksCount {
c.GetLogger().Warn(nil, "source count %d not equal to starrocks count %d", sourceCount, starrocksCount)
}
c.GetLogger().Info("load %s to starrocks success", table)
return nil
}
var LoadDataTaskMeta = core.SubTaskMeta{
Name: "LoadData",
EntryPoint: LoadData,
EnabledByDefault: true,
Description: "Load data to StarRocks",
}