| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package sql |
| |
| import ( |
| "context" |
| "database/sql" |
| "database/sql/driver" |
| "errors" |
| "fmt" |
| "io" |
| "reflect" |
| "strings" |
| |
| "github.com/go-sql-driver/mysql" |
| |
| "seata.apache.org/seata-go/pkg/datasource/sql/datasource" |
| mysql2 "seata.apache.org/seata-go/pkg/datasource/sql/datasource/mysql" |
| "seata.apache.org/seata-go/pkg/datasource/sql/types" |
| "seata.apache.org/seata-go/pkg/datasource/sql/util" |
| "seata.apache.org/seata-go/pkg/protocol/branch" |
| "seata.apache.org/seata-go/pkg/util/log" |
| ) |
| |
// Names under which the Seata MySQL drivers are registered with database/sql
// (see initDriver); use one of them as the driverName argument of sql.Open.
const (
	// SeataATMySQLDriver is the registered driver name for MySQL in Seata AT mode.
	SeataATMySQLDriver = "seata-at-mysql"

	// SeataXAMySQLDriver is the registered driver name for MySQL in Seata XA mode.
	SeataXAMySQLDriver = "seata-xa-mysql"
)
| |
| func initDriver() { |
| sql.Register(SeataATMySQLDriver, &seataATDriver{ |
| seataDriver: &seataDriver{ |
| branchType: branch.BranchTypeAT, |
| transType: types.ATMode, |
| target: mysql.MySQLDriver{}, |
| }, |
| }) |
| |
| sql.Register(SeataXAMySQLDriver, &seataXADriver{ |
| seataDriver: &seataDriver{ |
| branchType: branch.BranchTypeXA, |
| transType: types.XAMode, |
| target: mysql.MySQLDriver{}, |
| }, |
| }) |
| } |
| |
| type seataATDriver struct { |
| *seataDriver |
| } |
| |
| func (d *seataATDriver) OpenConnector(name string) (c driver.Connector, err error) { |
| connector, err := d.seataDriver.OpenConnector(name) |
| if err != nil { |
| return nil, err |
| } |
| |
| _connector, _ := connector.(*seataConnector) |
| _connector.transType = types.ATMode |
| cfg, _ := mysql.ParseDSN(name) |
| _connector.cfg = cfg |
| |
| return &seataATConnector{ |
| seataConnector: _connector, |
| }, nil |
| } |
| |
| type seataXADriver struct { |
| *seataDriver |
| } |
| |
| func (d *seataXADriver) OpenConnector(name string) (c driver.Connector, err error) { |
| connector, err := d.seataDriver.OpenConnector(name) |
| if err != nil { |
| return nil, err |
| } |
| |
| _connector, _ := connector.(*seataConnector) |
| _connector.transType = types.XAMode |
| cfg, _ := mysql.ParseDSN(name) |
| _connector.cfg = cfg |
| |
| return &seataXAConnector{ |
| seataConnector: _connector, |
| }, nil |
| } |
| |
| type seataDriver struct { |
| branchType branch.BranchType |
| transType types.TransactionMode |
| target driver.Driver |
| } |
| |
| // Open never be called, because seataDriver implemented dri.DriverContext interface. |
| // reference package: datasource/sql [https://cs.opensource.google/go/go/+/master:src/database/sql/sql.go;l=813] |
| // and maybe the sql.BD will be call Driver() method, but it obtain the Driver is fron Connector that is proxed by seataConnector. |
| func (d *seataDriver) Open(name string) (driver.Conn, error) { |
| return nil, errors.New(("operation unsupport.")) |
| } |
| |
| func (d *seataDriver) OpenConnector(name string) (c driver.Connector, err error) { |
| c = &dsnConnector{dsn: name, driver: d.target} |
| if driverCtx, ok := d.target.(driver.DriverContext); ok { |
| c, err = driverCtx.OpenConnector(name) |
| if err != nil { |
| log.Errorf("open connector: %w", err) |
| return nil, err |
| } |
| } |
| |
| dbType := types.ParseDBType(d.getTargetDriverName()) |
| if dbType == types.DBTypeUnknown { |
| return nil, fmt.Errorf("unsupport conn type %s", d.getTargetDriverName()) |
| } |
| |
| proxy, err := d.getOpenConnectorProxy(c, dbType, sql.OpenDB(c), name) |
| if err != nil { |
| log.Errorf("register resource: %w", err) |
| return nil, err |
| } |
| |
| return proxy, nil |
| } |
| |
| func (d *seataDriver) getOpenConnectorProxy(connector driver.Connector, dbType types.DBType, |
| db *sql.DB, dataSourceName string) (driver.Connector, error) { |
| cfg, _ := mysql.ParseDSN(dataSourceName) |
| options := []dbOption{ |
| withResourceID(parseResourceID(dataSourceName)), |
| withTarget(db), |
| withBranchType(d.branchType), |
| withDBType(dbType), |
| withDBName(cfg.DBName), |
| withConnector(connector), |
| } |
| res, err := newResource(options...) |
| if err != nil { |
| log.Errorf("create new resource: %w", err) |
| return nil, err |
| } |
| datasource.RegisterTableCache(types.DBTypeMySQL, mysql2.NewTableMetaInstance(db, cfg)) |
| if err = datasource.GetDataSourceManager(d.branchType).RegisterResource(res); err != nil { |
| log.Errorf("regisiter resource: %w", err) |
| return nil, err |
| } |
| return &seataConnector{ |
| res: res, |
| target: connector, |
| cfg: cfg, |
| }, nil |
| } |
| |
| func (d *seataDriver) getTargetDriverName() string { |
| return "mysql" |
| } |
| |
// dsnConnector is a minimal driver.Connector for target drivers that do not
// implement driver.DriverContext. Every Connect call opens a fresh connection
// through driver.Open with the stored DSN; the context is not consulted.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Connect opens a new connection via the wrapped driver's Open method.
func (c *dsnConnector) Connect(context.Context) (driver.Conn, error) {
	drv := c.driver
	return drv.Open(c.dsn)
}

// Driver returns the wrapped driver.
func (c *dsnConnector) Driver() driver.Driver {
	return c.driver
}
| |
// parseResourceID derives a Seata resource ID from a DSN. The query string
// (everything from the first '?') is dropped, and every ',' is replaced with
// '|'. A DSN whose first character is '?' is kept whole, because only a
// non-empty prefix is treated as the base.
//
// NOTE(review): the ID keeps any user:password part of the DSN. Confirm this is
// acceptable wherever resource IDs are logged or transmitted.
func parseResourceID(dsn string) string {
	if idx := strings.IndexByte(dsn, '?'); idx > 0 {
		dsn = dsn[:idx]
	}
	return strings.ReplaceAll(dsn, ",", "|")
}
| |
| func selectDBVersion(ctx context.Context, conn driver.Conn) (string, error) { |
| var rowsi driver.Rows |
| var err error |
| |
| queryerCtx, ok := conn.(driver.QueryerContext) |
| var queryer driver.Queryer |
| if !ok { |
| queryer, ok = conn.(driver.Queryer) |
| } |
| if ok { |
| rowsi, err = util.CtxDriverQuery(ctx, queryerCtx, queryer, "SELECT VERSION()", nil) |
| defer func() { |
| if rowsi != nil { |
| rowsi.Close() |
| } |
| }() |
| if err != nil { |
| log.Errorf("ctx driver query: %+v", err) |
| return "", err |
| } |
| } else { |
| log.Errorf("target conn should been driver.QueryerContext or driver.Queryer") |
| return "", fmt.Errorf("invalid conn") |
| } |
| |
| dest := make([]driver.Value, 1) |
| var version string |
| if err = rowsi.Next(dest); err != nil { |
| if err == io.EOF { |
| return version, nil |
| } |
| return "", err |
| } |
| if len(dest) != 1 { |
| return "", errors.New("get db version is not column 1") |
| } |
| |
| switch reflect.TypeOf(dest[0]).Kind() { |
| case reflect.Slice, reflect.Array: |
| val := reflect.ValueOf(dest[0]).Bytes() |
| version = string(val) |
| case reflect.String: |
| version = reflect.ValueOf(dest[0]).String() |
| default: |
| return "", errors.New("get db version is not a string") |
| } |
| |
| return version, nil |
| } |