blob: a05da50628e25e91448c7e6368d9d1882a4b82e4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package at
import (
	"context"
	"database/sql/driver"
	"fmt"
	"strconv"
	"strings"

	"github.com/arana-db/parser/ast"

	"seata.apache.org/seata-go/pkg/datasource/sql/datasource"
	"seata.apache.org/seata-go/pkg/datasource/sql/exec"
	"seata.apache.org/seata-go/pkg/datasource/sql/types"
	"seata.apache.org/seata-go/pkg/datasource/sql/util"
	"seata.apache.org/seata-go/pkg/util/log"
)
// sqlPlaceholder is the positional bind marker of a prepared statement; it is
// also the marker getInsertRows substitutes for ast.ParamMarkerExpr nodes.
const sqlPlaceholder = "?"
// insertExecutor executes an INSERT statement and records its undo data:
// before/after images are appended to the transaction's round images and the
// after image's lock key is added to the transaction's lock keys.
type insertExecutor struct {
	baseExecutor

	// parserCtx is the parsed statement; its InsertStmt drives primary key
	// extraction.
	parserCtx *types.ParseContext

	// execContext carries the connection, query text, bound arguments and
	// transaction context of the current execution.
	execContext *types.ExecContext

	// incrementStep is the auto-increment step used to derive the keys of a
	// batch insert; when it is <= 0 the step is queried from the server.
	incrementStep int

	// businesSQLResult is the result of the business INSERT, from which
	// LastInsertId and RowsAffected are read when keys are auto-generated.
	businesSQLResult types.ExecResult
}
// NewInsertExecutor creates the executor for an INSERT statement described by
// parserCtx, running against execContent and wrapped by the given hooks.
func NewInsertExecutor(parserCtx *types.ParseContext, execContent *types.ExecContext, hooks []exec.SQLHook) executor {
	ins := &insertExecutor{
		baseExecutor: baseExecutor{hooks: hooks},
		parserCtx:    parserCtx,
		execContext:  execContent,
	}
	return ins
}
// ExecContext runs the business INSERT through f and records its undo data.
// An (empty) before image is built first, the statement is executed, and an
// after image is built by reading the inserted rows back. Both images are
// appended to the transaction's round images only when every step succeeded.
// The executor's before/after hooks wrap the whole operation.
func (i *insertExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) {
	i.beforeHooks(ctx, i.execContext)
	defer func() {
		i.afterHooks(ctx, i.execContext)
	}()

	before, err := i.beforeImage(ctx)
	if err != nil {
		return nil, err
	}

	result, err := f(ctx, i.execContext.Query, i.execContext.NamedValues)
	if err != nil {
		return nil, err
	}
	// Remember the business result; getPkValuesByAuto reads LastInsertId and
	// RowsAffected from it while the after image is built.
	if i.businesSQLResult == nil {
		i.businesSQLResult = result
	}

	after, err := i.afterImage(ctx)
	if err != nil {
		return nil, err
	}

	i.execContext.TxCtx.RoundImages.AppendBeofreImage(before)
	i.execContext.TxCtx.RoundImages.AppendAfterImage(after)
	return result, nil
}
// beforeImage builds the before image of the INSERT: an empty record image
// bound to the target table's metadata (an INSERT has no pre-existing rows to
// snapshot).
//
// Errors from resolving the table name or loading its metadata are returned.
func (i *insertExecutor) beforeImage(ctx context.Context) (*types.RecordImage, error) {
	// Previously this error was discarded, which silently turned an invalid
	// statement into a metadata lookup for table "".
	tableName, err := i.parserCtx.GetTableName()
	if err != nil {
		return nil, err
	}
	metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
	if err != nil {
		return nil, err
	}
	return types.NewEmptyRecordImage(metaData, types.SQLTypeInsert), nil
}
// afterImage builds the after image of the INSERT by re-selecting the inserted
// rows by primary key (see buildAfterImageSQL) on the executor's connection,
// and registers the resulting lock key in the transaction context.
//
// Returns (nil, nil) when there is no parsed INSERT statement. The connection
// must implement driver.QueryerContext or driver.Queryer.
func (i *insertExecutor) afterImage(ctx context.Context) (*types.RecordImage, error) {
	if !i.isAstStmtValid() {
		return nil, nil
	}

	tableName, _ := i.parserCtx.GetTableName()
	metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
	if err != nil {
		return nil, err
	}

	selectSQL, selectArgs, err := i.buildAfterImageSQL(ctx)
	if err != nil {
		return nil, err
	}

	// Prefer the context-aware queryer; fall back to the plain one.
	queryerCtx, hasQueryerCtx := i.execContext.Conn.(driver.QueryerContext)
	var queryer driver.Queryer
	hasQueryer := false
	if !hasQueryerCtx {
		queryer, hasQueryer = i.execContext.Conn.(driver.Queryer)
	}
	if !hasQueryerCtx && !hasQueryer {
		log.Errorf("target conn should been driver.QueryerContext or driver.Queryer")
		return nil, fmt.Errorf("invalid conn")
	}

	rows, err := util.CtxDriverQuery(ctx, queryerCtx, queryer, selectSQL, selectArgs)
	defer func() {
		if rows != nil {
			rows.Close()
		}
	}()
	if err != nil {
		log.Errorf("ctx driver query: %+v", err)
		return nil, err
	}

	image, err := i.buildRecordImages(rows, metaData, types.SQLTypeInsert)
	if err != nil {
		return nil, err
	}

	i.execContext.TxCtx.LockKeys[i.buildLockKey(image, *metaData)] = struct{}{}
	return image, nil
}
// buildAfterImageSQL builds the query that reads back the inserted rows,
// "SELECT * FROM <table> WHERE <pk condition>", together with its bound
// primary key arguments. Primary key values come from getPkValues; for every
// inserted row and every primary key column one single-column RowImage is
// produced and handed to buildPKParams.
//
// Returns an error when the table has no primary key, when the primary key
// type metadata does not match the primary key column list, or when some
// primary key column has fewer values than the number of inserted rows.
func (i *insertExecutor) buildAfterImageSQL(ctx context.Context) (string, []driver.NamedValue, error) {
	tableName, _ := i.parserCtx.GetTableName()
	meta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
	if err != nil {
		return "", nil, err
	}

	pkValuesMap, err := i.getPkValues(ctx, i.execContext, i.parserCtx, *meta)
	if err != nil {
		return "", nil, err
	}

	pkColumnNameList := meta.GetPrimaryKeyOnlyName()
	if len(pkColumnNameList) == 0 {
		return "", nil, fmt.Errorf("Pk columnName size is zero")
	}

	dataTypeMap, err := meta.GetPrimaryKeyTypeStrMap()
	if err != nil {
		return "", nil, err
	}
	if len(dataTypeMap) != len(pkColumnNameList) {
		return "", nil, fmt.Errorf("PK columnName size don't equal PK DataType size")
	}

	// The first pk column's value count is the number of inserted rows; every
	// other pk column must supply at least as many values or indexing below
	// would panic.
	rowSize := len(pkValuesMap[pkColumnNameList[0]])
	for _, name := range pkColumnNameList[1:] {
		if got := len(pkValuesMap[name]); got < rowSize {
			return "", nil, fmt.Errorf("pk column %s has %d values, expected %d", name, got, rowSize)
		}
	}

	pkRowImages := make([]types.RowImage, 0, rowSize*len(pkColumnNameList))
	for rowIdx := 0; rowIdx < rowSize; rowIdx++ {
		for _, name := range pkColumnNameList {
			pkRowImages = append(pkRowImages, types.RowImage{
				Columns: []types.ColumnImage{{
					KeyType:    types.IndexTypePrimaryKey,
					ColumnName: name,
					ColumnType: types.MySQLStrToJavaType(dataTypeMap[name]),
					Value:      pkValuesMap[name][rowIdx],
				}},
			})
		}
	}

	var sb strings.Builder
	sb.WriteString("SELECT * FROM " + tableName)
	whereSQL := i.buildWhereConditionByPKs(pkColumnNameList, rowSize, "mysql", maxInSize)
	sb.WriteString(" WHERE " + whereSQL + " ")
	return sb.String(), i.buildPKParams(pkRowImages, pkColumnNameList), nil
}
// getPkValues returns the primary key values of the inserted rows, keyed by
// primary key column name.
//
// Single-column primary key: when the INSERT lists columns but omits the pk,
// the key was generated by the database and is read via getPkValuesByAuto;
// otherwise it is read from the statement via getPkValuesByColumn.
//
// Composite primary key: values are read from the statement, and if any pk
// column is missing from it, the auto-generated column's values are fetched
// once and merged in.
func (i *insertExecutor) getPkValues(ctx context.Context, execCtx *types.ExecContext, parseCtx *types.ParseContext, meta types.TableMeta) (map[string][]interface{}, error) {
	pkColumnNameList := meta.GetPrimaryKeyOnlyName()

	if len(pkColumnNameList) == 1 {
		if !i.containsPK(meta, parseCtx) && containsColumns(parseCtx) {
			// The insert names its columns but not the pk: it was auto-generated.
			return i.getPkValuesByAuto(ctx, execCtx)
		}
		return i.getPkValuesByColumn(ctx, execCtx)
	}

	// Composite pk: either every pk column has a value in the statement, or the
	// auto-increment pk column was left out/NULL and the others were provided.
	pkValuesMap, err := i.getPkValuesByColumn(ctx, execCtx)
	if err != nil {
		return nil, err
	}
	for _, columnName := range pkColumnNameList {
		if _, ok := pkValuesMap[columnName]; ok {
			continue
		}
		autoValues, err := i.getPkValuesByAuto(ctx, execCtx)
		if err != nil {
			return nil, err
		}
		// getPkValuesByColumn may return a nil map; writing to it would panic.
		if pkValuesMap == nil {
			pkValuesMap = make(map[string][]interface{})
		}
		pkValuesMapMerge(&pkValuesMap, autoValues)
		// The auto-generated values do not depend on which column was missing,
		// so one fetch is enough.
		break
	}
	return pkValuesMap, nil
}
// containsPK reports whether the INSERT's explicit column list names every
// primary key column of meta (matched case-insensitively). It is false when the
// table has no primary key, or when there is no parsed INSERT or column list.
func (i *insertExecutor) containsPK(meta types.TableMeta, parseCtx *types.ParseContext) bool {
	pkNames := meta.GetPrimaryKeyOnlyName()
	if len(pkNames) == 0 {
		return false
	}
	if parseCtx == nil || parseCtx.InsertStmt == nil || len(parseCtx.InsertStmt.Columns) == 0 {
		return false
	}

	matched := 0
	for _, col := range parseCtx.InsertStmt.Columns {
		for _, pk := range pkNames {
			if strings.EqualFold(pk, col.Name.O) || strings.EqualFold(pk, col.Name.L) {
				matched++
			}
		}
	}
	return matched == len(pkNames)
}
// containPK reports whether columnName, after MySQL escape characters are
// stripped, case-insensitively equals one of meta's primary key column names.
func (i *insertExecutor) containPK(columnName string, meta types.TableMeta) bool {
	target := DelEscape(columnName, types.DBTypeMySQL)
	for _, pk := range meta.GetPrimaryKeyOnlyName() {
		if strings.EqualFold(pk, target) {
			return true
		}
	}
	return false
}
// getPkIndex maps each primary key column name to its position within a
// VALUES row of InsertStmt.
//
// With an explicit column list (INSERT INTO t (a, b) VALUES ...), positions
// follow that list and keys use the column spelling from the statement.
// Without one (INSERT INTO t VALUES ...), positions follow meta.ColumnNames and
// keys are the unescaped table column names.
//
// An empty (non-nil) map is returned when InsertStmt is nil or the table
// metadata carries no columns.
func (i *insertExecutor) getPkIndex(InsertStmt *ast.InsertStmt, meta types.TableMeta) map[string]int {
	pkIndexMap := make(map[string]int)
	if InsertStmt == nil || meta.ColumnNames == nil || len(meta.Columns) == 0 {
		return pkIndexMap
	}

	if len(InsertStmt.Columns) > 0 {
		for paramIdx, column := range InsertStmt.Columns {
			sqlColumnName := column.Name.O
			if i.containPK(sqlColumnName, meta) {
				pkIndexMap[sqlColumnName] = paramIdx
			}
		}
		return pkIndexMap
	}

	// No column list: every VALUES row is positional against the table's
	// columns. The previous implementation returned early here, so this case
	// always ended in "pkIndex is not found". It also walked the (unordered)
	// meta.Columns map.
	// NOTE(review): assumes meta.ColumnNames is in table ordinal order; confirm
	// against the table-meta loader.
	for colIdx, columnName := range meta.ColumnNames {
		if i.containPK(columnName, meta) {
			pkIndexMap[DelEscape(columnName, types.DBTypeMySQL)] = colIdx
		}
	}
	return pkIndexMap
}
// parsePkValuesFromStatement extracts the primary key values written by an
// INSERT statement. The result maps each primary key column name to its values,
// one per VALUES row, in row order.
//
// When nameValues is non-empty the statement is treated as prepared: a "?" in
// a primary key position is resolved to the bound argument with the same
// ordinal among all placeholders of the statement. Other row values are taken
// as produced by getInsertRows. Without bound arguments, only literal values
// (ast.ValueExpr) in primary key positions are collected.
//
// Returns (nil, nil) when insertStmt is nil, or for a prepared statement that
// yields no rows. Returns an error when no primary key position is found, the
// statement has no VALUES lists, a pk position is out of range (plain
// statements), or a placeholder has no bound argument.
func (i *insertExecutor) parsePkValuesFromStatement(insertStmt *ast.InsertStmt, meta types.TableMeta, nameValues []driver.NamedValue) (map[string][]interface{}, error) {
	if insertStmt == nil {
		return nil, nil
	}

	pkIndexMap := i.getPkIndex(insertStmt, meta)
	if len(pkIndexMap) == 0 {
		return nil, fmt.Errorf("pkIndex is not found")
	}
	pkIndexArray := make([]int, 0, len(pkIndexMap))
	for _, idx := range pkIndexMap {
		pkIndexArray = append(pkIndexArray, idx)
	}

	if len(insertStmt.Lists) == 0 {
		return nil, fmt.Errorf("parCtx is nil, perhaps InsertStmt is empty")
	}

	pkValuesMap := make(map[string][]interface{})

	if len(nameValues) == 0 {
		// Plain statement: only literal pk values can be read directly.
		for _, list := range insertStmt.Lists {
			for pkName, pkIndex := range pkIndexMap {
				if pkIndex >= len(list) {
					return nil, fmt.Errorf("pkIndex out of range")
				}
				if node, ok := list[pkIndex].(ast.ValueExpr); ok {
					pkValuesMap[pkName] = append(pkValuesMap[pkName], node.GetValue())
				}
			}
		}
		return pkValuesMap, nil
	}

	// Prepared statement.
	insertRows, err := getInsertRows(insertStmt, pkIndexArray)
	if err != nil {
		return nil, err
	}
	if len(insertRows) == 0 {
		return nil, nil
	}

	isPlaceholder := func(v interface{}) bool {
		s, ok := v.(string)
		return ok && strings.EqualFold(s, sqlPlaceholder)
	}
	countPlaceholders := func(values []interface{}) int {
		n := 0
		for _, v := range values {
			if isPlaceholder(v) {
				n++
			}
		}
		return n
	}

	// placeholdersBefore is the number of "?" markers in all preceding rows,
	// i.e. the index into nameValues of the current row's first placeholder.
	// Counting placeholders directly (rather than subtracting non-placeholder
	// values) is correct for numeric/NULL literals, which are not strings.
	placeholdersBefore := 0
	for _, row := range insertRows {
		for pkName, pkIndex := range pkIndexMap {
			if pkIndex >= len(row) {
				continue
			}
			pkValue := row[pkIndex]
			if isPlaceholder(pkValue) {
				argIdx := placeholdersBefore + countPlaceholders(row[:pkIndex])
				if argIdx >= len(nameValues) {
					return nil, fmt.Errorf("placeholder index %d out of range, only %d arguments bound", argIdx, len(nameValues))
				}
				pkValue = nameValues[argIdx].Value
			}
			// Always store the appended slice. The old code only stored it on
			// first sight of the key, which dropped every row after the first.
			pkValuesMap[pkName] = append(pkValuesMap[pkName], pkValue)
		}
		placeholdersBefore += countPlaceholders(row)
	}
	return pkValuesMap, nil
}
// getPkValuesByColumn reads the primary key values from the INSERT statement
// (see parsePkValuesFromStatement), keyed by primary key column name.
//
// Some pk values mean the database generated the key. One case is a
// single-row insert whose pk value is a SQL function call. The other is any
// insert whose first pk value is NULL. In either case the keys are fetched via
// getPkValuesByAuto and replace the statement's values for the same column.
//
// Returns (nil, nil) when there is no parsed INSERT statement.
func (i *insertExecutor) getPkValuesByColumn(ctx context.Context, execCtx *types.ExecContext) (map[string][]interface{}, error) {
	if !i.isAstStmtValid() {
		return nil, nil
	}

	tableName, _ := i.parserCtx.GetTableName()
	meta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
	if err != nil {
		return nil, err
	}

	pkValuesMap, err := i.parsePkValuesFromStatement(i.parserCtx.InsertStmt, *meta, execCtx.NamedValues)
	if err != nil {
		return nil, err
	}

	// Decide whether any pk column was left to the database. The old nested
	// if/else-if never reached the NULL check for single-row inserts, and it
	// missed the value form ast.FuncCallExpr{} that getInsertRows emits.
	needAutoPk := false
	for _, values := range pkValuesMap {
		if len(values) == 0 {
			continue
		}
		first := values[0]
		_, isFuncPtr := first.(*ast.FuncCallExpr)
		_, isFuncVal := first.(ast.FuncCallExpr)
		if (len(values) == 1 && (isFuncPtr || isFuncVal)) || first == nil {
			needAutoPk = true
			break
		}
	}
	if !needAutoPk {
		return pkValuesMap, nil
	}

	autoPkValues, err := i.getPkValuesByAuto(ctx, execCtx)
	if err != nil {
		return nil, err
	}
	// Replace (not append to) the NULL/function placeholders with the
	// generated keys.
	for name, values := range autoPkValues {
		pkValuesMap[name] = values
	}
	return pkValuesMap, nil
}
// getPkValuesByAuto derives the values of the table's auto-increment primary
// key column from the business statement's LastInsertId and RowsAffected.
//
// If more than one row was affected and the auto-increment column is the only
// pk column, one key per row is generated via autoGeneratePks. If
// LastInsertId is positive otherwise, it is returned as the single key.
// Returns (nil, nil) when there is no parsed INSERT statement or LastInsertId
// is not positive.
//
// Errors: the table has no primary key, no pk column is auto-increment, the
// business result is unavailable, or reading the driver result fails.
func (i *insertExecutor) getPkValuesByAuto(ctx context.Context, execCtx *types.ExecContext) (map[string][]interface{}, error) {
	if !i.isAstStmtValid() {
		return nil, nil
	}

	tableName, _ := i.parserCtx.GetTableName()
	metaData, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, i.execContext.DBName, tableName)
	if err != nil {
		return nil, err
	}

	pkMetaMap := metaData.GetPrimaryKeyMap()
	if len(pkMetaMap) == 0 {
		return nil, fmt.Errorf("pk map is empty")
	}

	var autoColumnName string
	for _, columnMeta := range pkMetaMap {
		if columnMeta.Autoincrement {
			autoColumnName = columnMeta.ColumnName
			break
		}
	}
	if len(autoColumnName) == 0 {
		return nil, fmt.Errorf("auto increment column not exist")
	}

	// Guard against a missing business result instead of panicking on a nil
	// interface.
	if i.businesSQLResult == nil {
		return nil, fmt.Errorf("business sql result is nil, cannot read last insert id")
	}
	result := i.businesSQLResult.GetResult()
	if result == nil {
		return nil, fmt.Errorf("business sql result has no driver result, cannot read last insert id")
	}

	updateCount, err := result.RowsAffected()
	if err != nil {
		return nil, err
	}
	lastInsertId, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}

	// Batch insert: derive every key from LAST_INSERT_ID and the server's
	// `auto_increment_increment` step.
	if lastInsertId > 0 && updateCount > 1 && canAutoIncrement(pkMetaMap) {
		return i.autoGeneratePks(execCtx, autoColumnName, lastInsertId, updateCount)
	}

	if lastInsertId > 0 {
		return map[string][]interface{}{autoColumnName: {lastInsertId}}, nil
	}
	return nil, nil
}
// canAutoIncrement reports whether the primary key consists of exactly one
// column and that column is auto-increment.
func canAutoIncrement(pkMetaMap map[string]types.ColumnMeta) bool {
	if len(pkMetaMap) == 1 {
		// The map has a single entry; report its flag.
		for _, onlyPk := range pkMetaMap {
			return onlyPk.Autoincrement
		}
	}
	return false
}
// isAstStmtValid reports whether the executor holds a parsed INSERT statement.
func (i *insertExecutor) isAstStmtValid() bool {
	if i.parserCtx == nil {
		return false
	}
	return i.parserCtx.InsertStmt != nil
}
// autoGeneratePks reconstructs the auto-increment keys of a batch insert as
// lastInsetId, lastInsetId+step, ..., with updateCount values in total, and
// returns them under autoColumnName.
//
// The step is i.incrementStep when positive. Otherwise it is read from
// `SHOW VARIABLES LIKE 'auto_increment_increment'` and cached in
// i.incrementStep. Returns an error when the query fails, returns no
// columns/rows, or yields a step that is missing or not a positive integer.
func (i *insertExecutor) autoGeneratePks(execCtx *types.ExecContext, autoColumnName string, lastInsetId, updateCount int64) (map[string][]interface{}, error) {
	var step int64
	if i.incrementStep > 0 {
		step = int64(i.incrementStep)
	} else {
		// get step by query sql
		stmt, err := execCtx.Conn.Prepare("SHOW VARIABLES LIKE 'auto_increment_increment'")
		if err != nil {
			log.Errorf("build prepare stmt: %+v", err)
			return nil, err
		}
		defer func() { _ = stmt.Close() }()

		rows, err := stmt.Query(nil)
		if err != nil {
			log.Errorf("stmt query: %+v", err)
			return nil, err
		}
		defer func() { _ = rows.Close() }()

		columns := rows.Columns()
		if len(columns) == 0 {
			return nil, fmt.Errorf("query is empty")
		}
		// driver.Rows.Next fills a caller-allocated slice of len(Columns());
		// passing a nil slice (as before) panics.
		record := make([]driver.Value, len(columns))
		if err := rows.Next(record); err != nil {
			return nil, err
		}

		// SHOW VARIABLES returns (Variable_name, Value); the step is in the last
		// column and drivers commonly deliver it as text.
		switch v := record[len(record)-1].(type) {
		case int64:
			step = v
		case []byte:
			if parsed, perr := strconv.ParseInt(string(v), 10, 64); perr == nil {
				step = parsed
			}
		case string:
			if parsed, perr := strconv.ParseInt(v, 10, 64); perr == nil {
				step = parsed
			}
		}
	}

	if step <= 0 {
		return nil, fmt.Errorf("get increment step error")
	}
	// Cache the step so later batches on this executor skip the query.
	i.incrementStep = int(step)

	pkValues := make([]interface{}, 0, updateCount)
	for j := int64(0); j < updateCount; j++ {
		pkValues = append(pkValues, lastInsetId+j*step)
	}
	return map[string][]interface{}{autoColumnName: pkValues}, nil
}
// pkValuesMapMerge copies every entry of src into *dest. It replaces the
// values already stored under the same primary key column and leaves keys
// present only in *dest untouched. A nil *dest is allocated on demand; a nil
// dest pointer or an empty src is a no-op.
//
// Replacement is required because callers use this to swap placeholder pk
// values (NULL or a function call in the INSERT) for the keys the database
// actually generated. Appending src's slice as one element would nest a
// []interface{} inside the value list and corrupt the after-image query
// arguments.
func pkValuesMapMerge(dest *map[string][]interface{}, src map[string][]interface{}) {
	if dest == nil || len(src) == 0 {
		return
	}
	if *dest == nil {
		*dest = make(map[string][]interface{}, len(src))
	}
	for k, v := range src {
		(*dest)[k] = v
	}
}
// containsColumns reports whether the parsed INSERT has an explicit, non-empty
// column list (INSERT INTO t (a, b) ...).
func containsColumns(parseCtx *types.ParseContext) bool {
	if parseCtx == nil || parseCtx.InsertStmt == nil {
		return false
	}
	cols := parseCtx.InsertStmt.Columns
	return len(cols) > 0
}
// getInsertRows flattens each VALUES list of insertStmt into a row of plain
// values:
//   - a parameter marker becomes sqlPlaceholder ("?"),
//   - a literal becomes its Go value,
//   - a variable becomes its name,
//   - a function call becomes an empty ast.FuncCallExpr value,
//   - anything else becomes ast.DefaultExpr{}, unless it sits at one of
//     pkIndexArray's positions, in which case an error is returned.
//
// Returns (nil, nil) when insertStmt is nil or has no VALUES lists.
func getInsertRows(insertStmt *ast.InsertStmt, pkIndexArray []int) ([][]interface{}, error) {
	if insertStmt == nil || len(insertStmt.Lists) == 0 {
		return nil, nil
	}

	var rows [][]interface{}
	for _, exprs := range insertStmt.Lists {
		var row []interface{}
		for colIdx, expr := range exprs {
			// Case order matters: a parameter marker is also a ValueExpr.
			switch e := expr.(type) {
			case ast.ParamMarkerExpr:
				row = append(row, sqlPlaceholder)
			case ast.ValueExpr:
				row = append(row, e.GetValue())
			case *ast.VariableExpr:
				row = append(row, e.Name)
			case *ast.FuncCallExpr:
				row = append(row, ast.FuncCallExpr{})
			default:
				for _, pkIdx := range pkIndexArray {
					if pkIdx == colIdx {
						return nil, fmt.Errorf("Unknown SQLExpr:%v", expr)
					}
				}
				row = append(row, ast.DefaultExpr{})
			}
		}
		rows = append(rows, row)
	}
	return rows, nil
}