| // Copyright 2020 The Casbin Authors. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| import type { Adapter, Model } from 'casbin'; |
| import type { CasbinRule } from './casbin-rule'; |
| import type * as pg from 'pg'; |
| import type * as mysql from 'mysql'; |
| import type * as mysql2 from 'mysql2/promise'; |
| import type * as sqlite3 from 'sqlite3'; |
| import type * as mssql from 'mssql'; |
| // import type * as oracledb from 'oracledb'; |
| |
| import { Helper } from 'casbin'; |
| import * as Knex from 'knex'; |
| |
/**
 * Knex configuration used by the adapter.
 *
 * It is the regular Knex config type with `client` narrowed to one of the
 * driver names declared as keys of `Instance`.
 */
export type Config = Knex.Knex.Config & { client: keyof Instance };
| |
/**
 * Maps each supported driver name to the type of client object the adapter
 * expects to receive for that driver.
 */
export type Instance = {
  /** `pg` client. */
  pg: pg.Client;
  /** `mysql` connection. */
  mysql: mysql.Connection;
  /** Promise resolving to a `mysql2/promise` connection. */
  mysql2: Promise<mysql2.Connection>;
  /** `sqlite3` database handle. */
  sqlite3: sqlite3.Database;
  /** `mssql` connection pool. */
  mssql: mssql.ConnectionPool;
};
| |
/**
 * Casbin policy adapter backed by a SQL table and a caller-supplied driver
 * client.
 *
 * Knex is instantiated without any connection settings and is used purely as
 * a SQL builder; every generated statement is executed through the raw driver
 * client (`pg`, `mysql`, `mysql2`, `sqlite3` or `mssql`) stored in `client`.
 *
 * Instances are created through {@link BasicAdapter.newAdapter}, which also
 * connects the client and creates the policy table when it does not exist.
 */
export class BasicAdapter<T extends keyof Instance> implements Adapter {
  /** Names of the rule-value columns, in positional order. */
  private static readonly VALUE_COLUMNS = [
    'v0',
    'v1',
    'v2',
    'v3',
    'v4',
    'v5',
  ] as const;

  private knex: Knex.Knex;
  private config: Config;
  private drive: T;
  private client: Instance[T];
  private tableName: string;

  /**
   * @param drive Driver name; selects both the knex dialect and how `client`
   *   is called.
   * @param client Driver client used to execute every statement.
   * @param tableName Policy table name, optionally written as `schema.table`.
   */
  private constructor(
    drive: T,
    client: Instance[T],
    tableName: string = 'casbin_rule',
  ) {
    this.config = {
      client: drive,
      useNullAsDefault: drive === 'sqlite3',
      // Silence knex warnings: knex only builds SQL here and never connects.
      log: { warn: () => {} },
    };
    this.knex = Knex.knex(this.config);
    this.drive = drive;
    this.client = client;
    this.tableName = tableName;
  }

  /**
   * Creates an adapter, connects the client and ensures the policy table
   * exists.
   *
   * @param drive Driver name matching the kind of `client` passed in.
   * @param client Driver client (for `mysql2`, a promise of a connection).
   * @param tableName Policy table name; defaults to `casbin_rule`.
   * @returns The ready-to-use adapter.
   */
  static async newAdapter<T extends keyof Instance>(
    drive: T,
    client: Instance[T],
    tableName: string = 'casbin_rule',
  ): Promise<BasicAdapter<T>> {
    const a = new BasicAdapter(drive, client, tableName);
    await a.connect();
    await a.createTable();

    return a;
  }

  /**
   * Reads every row of the policy table and loads it into `model`.
   */
  async loadPolicy(model: Model): Promise<void> {
    const rows = await this.query(
      this.knex.select().from(this.tableName).toQuery(),
    );

    for (const row of rows) {
      this.loadPolicyLine(row, model);
    }
  }

  /**
   * Replaces the table contents with the `p` and `g` rules held by `model`.
   *
   * The delete and the inserts are issued as separate statements with no
   * transaction around them.
   *
   * @returns Always `true` once every insert has completed.
   */
  async savePolicy(model: Model): Promise<boolean> {
    await this.query(this.knex.del().from(this.tableName).toQuery());

    const pending: Array<Promise<CasbinRule[]>> = [];
    for (const sec of ['p', 'g']) {
      const astMap = model.model.get(sec);
      // A section can be missing from the model (for instance when it defines
      // no roles); iterating `undefined` would throw, so skip it.
      if (!astMap) {
        continue;
      }

      for (const [ptype, ast] of astMap) {
        for (const rule of ast.policy) {
          pending.push(this.insertRule(ptype, rule));
        }
      }
    }

    await Promise.all(pending);

    return true;
  }

  /**
   * Inserts a single rule.
   *
   * @param sec Section name; part of the adapter interface, not used here.
   * @param ptype Policy type stored in the `ptype` column.
   * @param rule Rule values, stored positionally in `v0`..`v5`.
   */
  async addPolicy(sec: string, ptype: string, rule: string[]): Promise<void> {
    await this.insertRule(ptype, rule);
  }

  /**
   * Inserts several rules of the same policy type, one statement per rule.
   *
   * @param sec Section name; part of the adapter interface, not used here.
   * @param ptype Policy type stored in the `ptype` column.
   * @param rules Rules to insert.
   */
  async addPolicies(
    sec: string,
    ptype: string,
    rules: string[][],
  ): Promise<void> {
    await Promise.all(rules.map((rule) => this.insertRule(ptype, rule)));
  }

  /**
   * Deletes the rows whose `ptype` and provided values equal `rule`.
   *
   * @param sec Section name; part of the adapter interface, not used here.
   * @param ptype Policy type to match.
   * @param rule Rule values to match positionally.
   */
  async removePolicy(
    sec: string,
    ptype: string,
    rule: string[],
  ): Promise<void> {
    await this.deleteWhere(this.savePolicyLine(ptype, rule));
  }

  /**
   * Deletes several rules of the same policy type, one statement per rule.
   *
   * @param sec Section name; part of the adapter interface, not used here.
   * @param ptype Policy type to match.
   * @param rules Rules to delete.
   */
  async removePolicies(
    sec: string,
    ptype: string,
    rules: string[][],
  ): Promise<void> {
    await Promise.all(
      rules.map((rule) => this.deleteWhere(this.savePolicyLine(ptype, rule))),
    );
  }

  /**
   * Deletes every row of `ptype` whose columns starting at `fieldIndex` match
   * `fieldValues`.
   *
   * An empty string in `fieldValues` is treated as "match any value" and adds
   * no condition for its column, mirroring how Casbin filters rules in memory.
   *
   * @param sec Section name; part of the adapter interface, not used here.
   * @param ptype Policy type to match.
   * @param fieldIndex Index of the first value column covered by the filter.
   * @param fieldValues Values for consecutive columns from `fieldIndex` on.
   */
  async removeFilteredPolicy(
    sec: string,
    ptype: string,
    fieldIndex: number,
    ...fieldValues: string[]
  ): Promise<void> {
    const line: Omit<CasbinRule, 'id'> = { ptype };

    BasicAdapter.VALUE_COLUMNS.forEach((column, position) => {
      const offset = position - fieldIndex;
      if (offset < 0 || offset >= fieldValues.length) {
        return;
      }

      const value = fieldValues[offset];
      // Empty string is a wildcard: leave the column unconstrained.
      if (value !== undefined && value !== '') {
        line[column] = value;
      }
    });

    await this.deleteWhere(line);
  }

  /**
   * Closes the underlying driver client and waits for it to finish.
   */
  async close(): Promise<void> {
    switch (this.drive) {
      case 'pg': {
        await (<BasicAdapter<'pg'>>this).client.end();

        break;
      }
      case 'mysql': {
        // mysql's end() is callback based; awaiting its return value would
        // neither wait for the shutdown nor surface an error.
        await new Promise<void>((resolve, reject) => {
          (<BasicAdapter<'mysql'>>this).client.end((err) => {
            if (err) {
              reject(err);
              return;
            }

            resolve();
          });
        });

        break;
      }
      case 'mysql2': {
        await (await (<BasicAdapter<'mysql2'>>this).client).end();

        break;
      }
      case 'sqlite3': {
        await new Promise<void>((resolve, reject) => {
          (<BasicAdapter<'sqlite3'>>this).client.close((err) => {
            if (err) {
              reject(err);
              return;
            }

            resolve();
          });
        });

        break;
      }
      case 'mssql': {
        await (<BasicAdapter<'mssql'>>this).client.close();

        break;
      }
    }
  }

  /**
   * Converts a table row into a policy text line and loads it into `model`.
   *
   * Only trailing empty columns are dropped. Empty values in the middle of a
   * rule are kept so the remaining fields stay in their original positions.
   */
  private loadPolicyLine(line: CasbinRule, model: Model): void {
    const values = BasicAdapter.VALUE_COLUMNS.map((column) => line[column]);

    let end = values.length;
    while (end > 0 && !values[end - 1]) {
      end--;
    }

    const text =
      line.ptype +
      ', ' +
      values
        .slice(0, end)
        .map((value) => value ?? '')
        .join(', ');
    Helper.loadPolicyLine(text, model);
  }

  /**
   * Builds the row object for a rule: `ptype` plus up to six positional
   * values. Values beyond the sixth are ignored.
   */
  private savePolicyLine(
    ptype: string,
    rule: string[],
  ): Omit<CasbinRule, 'id'> {
    const line: Omit<CasbinRule, 'id'> = { ptype };

    BasicAdapter.VALUE_COLUMNS.forEach((column, position) => {
      if (position < rule.length) {
        line[column] = rule[position];
      }
    });

    return line;
  }

  /** Inserts one rule as a single row of the policy table. */
  private insertRule(ptype: string, rule: string[]): Promise<CasbinRule[]> {
    const line = this.savePolicyLine(ptype, rule);

    return this.query(this.knex.insert(line).into(this.tableName).toQuery());
  }

  /** Deletes every row whose columns equal the fields set on `line`. */
  private deleteWhere(line: Omit<CasbinRule, 'id'>): Promise<CasbinRule[]> {
    return this.query(
      this.knex.del().where(line).from(this.tableName).toQuery(),
    );
  }

  /**
   * Creates the policy table unless the existence query returns a row.
   *
   * A table name of the form `schema.table` is split so the existence check
   * runs against that schema.
   */
  private async createTable(): Promise<void> {
    const parts = this.tableName.split('.');
    const schema = parts.length === 2 ? parts[0] : undefined;
    const table = parts.length === 2 ? parts[1] : parts[0];

    // use the schema if provided
    const schemaProxy = schema
      ? this.knex.schema.withSchema(schema)
      : this.knex.schema;

    const tableExists = await this.query(
      schemaProxy.hasTable(table).toString(),
    );

    if (tableExists.length > 0) return;

    const createTableSQL = this.knex.schema
      .createTable(this.tableName, (t) => {
        t.increments();
        t.string('ptype').notNullable();
        for (const column of BasicAdapter.VALUE_COLUMNS) {
          t.string(column);
        }
      })
      .toQuery();

    await this.query(createTableSQL);
  }

  /**
   * Opens the connection using whatever mechanism the selected driver needs.
   */
  private async connect(): Promise<void> {
    switch (this.drive) {
      case 'pg': {
        await (<BasicAdapter<'pg'>>this).client.connect();

        break;
      }
      case 'mysql': {
        await new Promise<void>((resolve, reject) => {
          (<BasicAdapter<'mysql'>>this).client.connect((err) => {
            if (err) {
              reject(err);
              return;
            }

            resolve();
          });
        });

        break;
      }
      case 'mysql2': {
        // The client is a promise of a connection; wait for it to settle.
        await (<BasicAdapter<'mysql2'>>this).client;

        break;
      }
      case 'sqlite3': {
        // sqlite3 will connect automatically

        break;
      }
      case 'mssql': {
        await (<BasicAdapter<'mssql'>>this).client.connect();

        break;
      }
    }
  }

  /**
   * Executes a SQL string on the driver client and returns the result rows.
   *
   * NOTE(review): every caller passes SQL produced by knex's
   * `toQuery()`/`toString()`, i.e. with values already inlined into the
   * statement text. Passing bindings to the driver separately would be more
   * robust; confirm whether that is acceptable for untrusted rule values.
   *
   * @param sql Complete SQL statement.
   * @returns The rows reported by the driver, or an empty array when the
   *   driver reports no row array (for example for non-SELECT statements).
   */
  private async query(sql: string): Promise<CasbinRule[]> {
    let result: CasbinRule[] | undefined;

    switch (this.drive) {
      case 'pg': {
        result = (
          await (<BasicAdapter<'pg'>>this).client.query<CasbinRule>(sql)
        ).rows;

        break;
      }
      case 'mysql': {
        result = await new Promise((resolve, reject) => {
          (<BasicAdapter<'mysql'>>this).client.query(sql, (err, rows) => {
            if (err) return reject(err);
            resolve(rows);
          });
        });

        break;
      }
      case 'mysql2': {
        result = (
          await (await (<BasicAdapter<'mysql2'>>this).client).query(sql)
        )[0] as CasbinRule[];

        break;
      }
      case 'sqlite3': {
        result = await new Promise<CasbinRule[] | undefined>(
          (resolve, reject) => {
            (<BasicAdapter<'sqlite3'>>this).client.all(sql, (err, rows) => {
              if (err) return reject(err);

              resolve(rows as CasbinRule[]);
            });
          },
        );

        break;
      }
      case 'mssql': {
        result = (await (<BasicAdapter<'mssql'>>this).client.query(sql))
          .recordset as unknown as CasbinRule[] | undefined;

        break;
      }
    }

    // Some drivers hand back a status object instead of rows for statements
    // that return no result set; never expose that as a row array.
    return Array.isArray(result) ? result : [];
  }
}