| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package shardingsphere |
| |
| import ( |
| "database/sql" |
| "fmt" |
| |
| _ "github.com/go-sql-driver/mysql" |
| ) |
| |
| const ( |
| // DistSQLCreateDatabase create database if not exists. |
| DistSQLCreateDatabase = `CREATE DATABASE IF NOT EXISTS %s;` |
| // DistSQLUseDatabase use database. |
| DistSQLUseDatabase = `USE %s;` |
| // DistSQLRegisterStorageUnit register database to shardingsphere by storage unit name and database info. |
| DistSQLRegisterStorageUnit = `REGISTER STORAGE UNIT IF NOT EXISTS %s (HOST="%s",PORT=%d,DB="%s",USER="%s",PASSWORD="%s");` |
| // DistSQLShowRulesUsed show all rules used by storage unit name. |
| DistSQLShowRulesUsed = `SHOW RULES USED STORAGE UNIT %s;` |
| // DistSQLUnRegisterStorageUnit unregister database from shardingsphere by storage unit name. |
| DistSQLUnRegisterStorageUnit = `UNREGISTER STORAGE UNIT %s;` |
| // DistSQLDropRule drop rule by rule type and rule name. |
| DistSQLDropRule = `DROP %s RULE %s;` |
| // DistSQLDropTable drop table by table name. |
| DistSQLDropTable = `DROP TABLE %s;` |
| ) |
| |
| var ruleTypeMap = map[string]string{} |
| |
| type Rule struct { |
| Type string |
| Name string |
| } |
| |
| type server struct { |
| db *sql.DB |
| } |
| |
| type IServer interface { |
| CreateDatabase(dbName string) error |
| RegisterStorageUnit(logicDBName, dsName, dsHost string, dsPort uint, dsDBName, dsUser, dsPassword string) error |
| UnRegisterStorageUnit(logicDBName, dsName string) error |
| Close() error |
| } |
| |
| var _ IServer = (*server)(nil) |
| |
| func NewServer(driver, host string, port uint, user, password string) (IServer, error) { |
| if driver != "mysql" && driver != "postgres" { |
| return nil, fmt.Errorf("unsupported database driver: %s", driver) |
| } |
| |
| if host == "" || port == 0 || user == "" || password == "" { |
| return nil, fmt.Errorf("invalid database config, host=%s, port=%d, user=%s, password=%s", host, port, user, password) |
| } |
| |
| dataSourceName := fmt.Sprintf("%s:%s@tcp(%s:%d)/", user, password, host, port) |
| |
| db, err := sql.Open(driver, dataSourceName) |
| if err != nil { |
| return nil, fmt.Errorf("open database=%s error: %w", dataSourceName, err) |
| } |
| |
| // check database connection |
| if err = db.Ping(); err != nil { |
| return nil, fmt.Errorf("ping database=%s error: %w", dataSourceName, err) |
| } |
| |
| return &server{db: db}, nil |
| } |
| |
| func (s *server) Close() error { |
| return s.db.Close() |
| } |
| |
| func (s *server) CreateDatabase(dbName string) error { |
| distSQL := fmt.Sprintf(DistSQLCreateDatabase, dbName) |
| |
| _, err := s.db.Exec(distSQL) |
| if err != nil { |
| return fmt.Errorf("create database error: %w", err) |
| } |
| |
| return nil |
| } |
| |
| func (s *server) RegisterStorageUnit(logicDBName, dsName, dsHost string, dsPort uint, dsDBName, dsUser, dsPassword string) error { |
| _, err := s.db.Exec(fmt.Sprintf(DistSQLUseDatabase, logicDBName)) |
| if err != nil { |
| return fmt.Errorf("use database error: %w", err) |
| } |
| |
| distSQL := fmt.Sprintf(DistSQLRegisterStorageUnit, dsName, dsHost, dsPort, dsDBName, dsUser, dsPassword) |
| |
| _, err = s.db.Exec(distSQL) |
| if err != nil { |
| return fmt.Errorf("register database error: %w", err) |
| } |
| |
| return nil |
| } |
| |
| // getRulesUsed returns all rules used by storage unit name. |
| func (s *server) getRulesUsed(dsName string) (rules []*Rule, err error) { |
| rules = make([]*Rule, 0) |
| distSQL := fmt.Sprintf(DistSQLShowRulesUsed, dsName) |
| |
| rows, err := s.db.Query(distSQL) |
| if err != nil { |
| return nil, fmt.Errorf("get rules used error: %w", err) |
| } |
| defer rows.Close() |
| |
| for rows.Next() { |
| var ruleT, ruleN string |
| if err := rows.Scan(&ruleT, &ruleN); err != nil { |
| return nil, fmt.Errorf("scan rules used error: %w", err) |
| } |
| rules = append(rules, &Rule{Type: ruleT, Name: ruleN}) |
| } |
| |
| if err := rows.Err(); err != nil { |
| return nil, fmt.Errorf("rows error: %w", err) |
| } |
| return rules, nil |
| } |
| |
| func (s *server) UnRegisterStorageUnit(logicDBName, dsName string) error { |
| _, err := s.db.Exec(fmt.Sprintf(DistSQLUseDatabase, logicDBName)) |
| if err != nil { |
| return fmt.Errorf("use database error: %w", err) |
| } |
| |
| rules, err := s.getRulesUsed(dsName) |
| if err != nil { |
| return fmt.Errorf("get rules used error: %w", err) |
| } |
| |
| // clean all rules used by storage unit |
| for _, rule := range rules { |
| if err := s.dropRule(rule.Type, rule.Name); err != nil { |
| return fmt.Errorf("drop rule error: %w", err) |
| } |
| } |
| |
| distSQL := fmt.Sprintf(DistSQLUnRegisterStorageUnit, dsName) |
| |
| _, err = s.db.Exec(distSQL) |
| if err != nil { |
| return fmt.Errorf("unregister database error: %w", err) |
| } |
| |
| return nil |
| } |
| |
| func (s *server) dropRule(ruleType, ruleName string) error { |
| // convert rule type |
| ruleType = ruleTypeMap[ruleType] |
| distSQL := fmt.Sprintf(DistSQLDropRule, ruleType, ruleName) |
| _, err := s.db.Exec(distSQL) |
| if err != nil { |
| return fmt.Errorf("drop rule fail, err: %s", err) |
| } |
| return nil |
| } |
| |
| func init() { |
| // init rule type map |
| // implement more rule type if needed |
| ruleTypeMap = map[string]string{ |
| "sharding": "SHARDING TABLE", |
| } |
| } |