package org.apache.nlpcraft.examples.sql.db
import java.sql.Types
import com.typesafe.scalalogging.LazyLogging
import org.jgrapht.alg.shortestpath.DijkstraShortestPath
import org.jgrapht.graph.{DefaultEdge, SimpleGraph}
import scala.collection.JavaConverters._
import scala.collection.{Seq, mutable}
import scala.compat.java8.OptionConverters._
* SQL builder that takes parsed SQL schema and creates SQL query based on configured parameters.
* Note that this is a *very simple* example implementation and it lacks several important capabilities:
* - OR logic conditions
* - negation (SQL <> condition),
* - LIKE and other functions.
* However, these capabilities can be added relatively easily to this extensible implementation.
* @param schema Parsed DB schema to initialize with.
case class SqlBuilder(schema: NCSqlSchema) extends LazyLogging {
// Row limit used in the LIMIT clause when no explicit limit element is configured.
private final val DFLT_LIMIT = 1000
// Graph edge between two table names; used as the edge type of the join graph in 'schemaPaths'.
private case class Edge(from: String, to: String) extends DefaultEdge
// Lookup key identifying a column by its table name and column name.
private case class Key(table: String, column: String)
// Schema tables ordered by table name.
private val schemaTbls = schema.getTables.asScala.toSeq.sortBy(_.getTable)
// Schema tables indexed by table name.
private val schemaTblsByNames = ⇒ p.getTable → p).toMap
// Schema columns indexed by their (table, column) key.
private val schemaCols = schemaTbls.flatMap(p ⇒ ⇒ Key(col.getTable, col.getColumn) → col)).toMap
// Join definitions declared by the schema.
private val schemaJoins = schema.getJoins.asScala
// Shortest-path finder over a graph whose vertices are schema table names and whose
// edges are schema joins; used to discover tables needed to connect a pair of tables.
private val schemaPaths = {
val g = new SimpleGraph[String, Edge](classOf[Edge])
// Every schema table becomes a vertex.
schemaTbls.foreach(t ⇒ g.addVertex(t.getTable))
// Every schema join becomes an edge between its 'from' and 'to' tables.
schemaJoins.foreach(j ⇒ g.addEdge(j.getFromTable, j.getToTable, Edge(j.getFromTable, j.getToTable)))
new DijkstraShortestPath(g)
// Mutable builder state: filled by the 'with...' methods and read by 'build()'.

// Tables requested for the query.
private var tbls = Seq.empty[NCSqlTable]
// Columns requested for the select list.
private var cols = Seq.empty[NCSqlColumn]
// Conditions combined with AND in the WHERE clause.
private var conds = Seq.empty[SqlCondition]
// Explicit sort elements for the ORDER BY clause.
private var sorts = Seq.empty[NCSqlSort]
// Optional date range that is not bound to a specific column.
private var freeDateRangeOpt = Option.empty[NCSqlDateRange]
// Optional limit element; 'DFLT_LIMIT' is used when it is not set.
private var limit = Option.empty[NCSqlLimit]
* Makes SQL text fragment for given join type.
* @param clause Join type.
// Returns the SQL text for a join type. The branches visible here only reject
// join types by throwing AssertionError.
private def sql(clause: NCSqlJoinType): String =
clause match {
// OUTER joins are explicitly unsupported.
case OUTER ⇒ throw new AssertionError(s"Unsupported join type: $clause")
// Any other unhandled join type is treated as a programming error.
case _ ⇒ throw new AssertionError(s"Unexpected join type: $clause")
* Makes SQL join fragment for given tables. It uses tables parameters and information about relations between
* tables based on schema information.
* @param tbls Tables.
// Builds the join fragment for the given tables using the schema join definitions.
// Throws AssertionError for an empty table list and RuntimeException when the number
// of produced join fragments is not exactly one less than the number of tables.
private def sql(tbls: Seq[NCSqlTable]): String = {
val names =
names.size match {
case 0throw new AssertionError(s"Unexpected empty tables")
// A single table needs no join.
case 1 ⇒ names.head
case _ ⇒
// Names of tables that were already emitted into the join chain.
val used = mutable.HashSet.empty[String]
val refs = names.
// Keep only schema joins that start at a requested table and end at a requested table.
flatMap(t ⇒ schemaJoins.filter(j ⇒ j.getFromTable == t && names.contains(j.getToTable))).
// Order joins by the earliest position of either of their tables in the requested list.
sortBy(j ⇒ Math.min(names.indexOf(j.getFromTable), names.indexOf(j.getToTable))).
map { case (join, idx)
val fromCols = join.getFromColumns.asScala
val toCols = join.getToColumns.asScala
// Join columns are paired one-to-one, so both sides must have the same size.
require(fromCols.size == toCols.size)
val fromTab = join.getFromTable
val toTab = join.getToTable
// Column equalities for the ON clause, combined with AND.
val onCondition =
map { case (fromCol, toCol) ⇒ s"$fromTab.$fromCol = $toTab.$toCol" }.mkString(" AND ")
// The first join emits both tables: "<from> <join> <to> ON ...".
if (idx == 0) {
used += fromTab
used += toTab
s"$fromTab ${sql(join.getType)} $toTab ON $onCondition"
else {
// 'add' returns true only when the table was not in the set yet.
val fromAdded = used.add(fromTab)
val toAdded = used.add(toTab)
// The 'from' table is new: it is appended using 'reverseType'.
// NOTE(review): only the pass-through branch of 'reverseType' is visible here.
if (fromAdded) {
val reverseType =
join.getType match {
case _ ⇒ join.getType
s"${sql(reverseType)} $fromTab ON $onCondition"
// The 'to' table is new: it is appended with the declared join type.
else if (toAdded)
s"${sql(join.getType)} $toTab ON $onCondition"
// Connecting N tables requires exactly N - 1 join fragments.
if (refs.length != names.length - 1)
throw new RuntimeException(s"Tables cannot be joined: ${names.mkString(", ")}")
refs.mkString(" ")
* Makes SQL text fragment for given column.
* @param clause Column.
/**
 * Renders a column reference in the `table.column` form.
 *
 * @param clause Column to render.
 * @return Column name qualified with its table name.
 */
private def sql(clause: NCSqlColumn): String = {
    val table = clause.getTable
    val column = clause.getColumn

    s"$table.$column"
}
* Makes SQL text fragment for given sort element.
* @param clause Sort element.
/**
 * Renders a sort element as the qualified column followed by its direction.
 *
 * @param clause Sort element to render.
 * @return Fragment of the form `table.column ASC` or `table.column DESC`.
 */
private def sql(clause: NCSqlSort): String = {
    val column = sql(clause.getColumn)
    val direction = if (clause.isAscending) "ASC" else "DESC"

    s"$column $direction"
}
* Makes SQL text fragment for given condition.
* @param clause Condition.
// Renders an IN condition as "<column> IN (...)" where the list consists of "?"
// placeholders joined by commas (presumably one per condition value - TODO confirm).
private def sql(clause: SqlInCondition): String = s"${sql(clause.column)} IN (${ ⇒ "?").mkString(",")})"
* Makes SQL text fragment for given condition.
* @param clause Condition.
/**
 * Renders a simple condition as `table.column <operation> ?`, leaving the value
 * as a `?` placeholder.
 *
 * @param clause Simple condition to render.
 * @return Parameterized SQL condition fragment.
 */
private def sql(clause: SqlSimpleCondition): String = {
    val column = sql(clause.column)
    val operation = clause.operation

    s"$column $operation ?"
}
* Makes SQL text fragment for given condition.
* @param clause Condition.
// Dispatches a generic condition to the renderer of its concrete type.
// Throws AssertionError for any condition type other than the two handled below.
private def sql(clause: SqlCondition): String =
clause match {
case x: SqlSimpleCondition ⇒ sql(x)
case x: SqlInCondition ⇒ sql(x)
// Unknown condition types indicate a programming error.
case _ ⇒ throw new AssertionError(s"Unexpected condition: $clause")
* Checks whether the given column element has the DATE type.
* @param col Column element.
/**
 * Tells whether the column's data type is `java.sql.Types.DATE`.
 *
 * @param col Column element.
 * @return `true` only for the DATE type.
 */
private def isDate(col: NCSqlColumn): Boolean =
    col.getDataType match {
        case Types.DATE ⇒ true
        case _ ⇒ false
    }
* Checks whether the given column element has the VARCHAR type.
* @param col Column element.
/**
 * Tells whether the column's data type is `java.sql.Types.VARCHAR`.
 *
 * @param col Column element.
 * @return `true` only for the VARCHAR type.
 */
private def isString(col: NCSqlColumn): Boolean =
    col.getDataType match {
        case Types.VARCHAR ⇒ true
        case _ ⇒ false
    }
* Extends given columns list, if necessary, by some additional columns, based on schema model
* configuration and free date column. This is useful when fewer than three columns were detected.
* @param initCols Initially detected columns.
* @param extTabs Extra tables list.
* @param freeDateColOpt Free date column. Optional.
// Produces the column list for the select clause: the initially detected columns plus
// the optional free date column, padded with further schema columns when the list is short.
private def extendColumns(
initCols: Seq[NCSqlColumn],
extTabs: Seq[NCSqlTable],
freeDateColOpt: Option[NCSqlColumn]
): Seq[NCSqlColumn] = {
// De-duplicated union of the detected columns and the free date column (if any).
var res = (initCols ++ freeDateColOpt.toSeq).distinct
// With fewer than three columns, extra columns resolved through 'schemaCols' are appended.
if (res.size < 3)
res = (res ++
flatMap(t ⇒ ⇒ schemaCols(Key(t.getTable, col))))).distinct
* Tries to find sort elements if they are not detected explicitly.
* It attempts to use sort information from limit element if it is defined, or from model configuration
* (default sort columns).
* @param sorts Sort elements.
* @param initTbls Initially detected tables.
* @param extCols Extended columns list.
// Returns the sort elements to use: explicit ones are returned de-duplicated; when none
// are given, sort elements are derived (the visible code consults the limit element).
private def extendSort(sorts: Seq[NCSqlSort], initTbls: Seq[NCSqlTable], extCols: Seq[NCSqlColumn]): Seq[NCSqlSort] =
sorts.size match {
case 0
// The limit's column yields a sort element only when it is part of the select list.
limit.flatMap(l ⇒ if (extCols.contains(l.getColumn)) Some(Seq(limit2Sort(l))) else None).
// 'Sort' can contain columns from select list only (some databases restriction)
flatMap(sort ⇒ if (extCols.contains(sort.getColumn)) Some(sort) else None)
// Explicit sort elements win; duplicates are dropped.
case _ ⇒ sorts.distinct
* Extends given table list, if necessary, by some additional tables,
* based on information about relations between model tables.
* @param initTbls Initially detected tables.
// Extends the initial table list; for two or more tables it collects the tables lying on
// the shortest join paths between every pair. Throws RuntimeException when the list is
// empty or when some table is not covered by the collected paths.
private def extendTables(initTbls: Seq[NCSqlTable]): Seq[NCSqlTable] = {
val ext =
(initTbls ++ schemaTbls.filter(initTbls.contains).flatMap(
ext.size match {
case 0throw new RuntimeException("Tables cannot be empty.")
// A single table needs no connecting tables.
case 1 ⇒ ext
case _ ⇒
// The simple algorithm, which takes into account only FKs between tables.
val extra =
// Every unordered pair of tables is checked for a connecting path.
ext.combinations(2).flatMap(pair ⇒
schemaPaths.getPath(pair.head.getTable, pair.last.getTable) match {
// A null path (no connection found) contributes nothing.
case nullSeq.empty
case list ⇒ list.getEdgeList.asScala.flatMap(e ⇒ Seq(e.from,
// Every table must appear among the tables collected from the paths.
if (ext.exists(t ⇒ !extra.contains(t)))
throw new RuntimeException(
s"Select clause cannot be prepared with given tables set: " +
s"${", ")}"
* Converts limit element to sort.
* @param l Limit element.
/**
 * Builds a sort element that orders by the limit's column in the limit's direction.
 *
 * @param l Limit element.
 * @return Sort element with the same column and ascending flag.
 */
private def limit2Sort(l: NCSqlLimit): NCSqlSort = {
    val column = l.getColumn
    val ascending = l.isAscending

    NCSqlSortImpl(column, ascending)
}
* Sorts given extra columns for select list, using the following criteria:
* - initially detected columns are more important.
* - next, columns from initially detected tables.
* - other extra columns.
* @param extraCols Extra columns for sorting.
* @param initCols Initially detected columns.
* @param initTbls Initially detected tables.
// Orders select-list columns by a rank computed per column: the visible branches
// distinguish initially detected columns, columns of initially detected tables,
// and everything else.
private def sort(extraCols: Seq[NCSqlColumn], initCols: Seq[NCSqlColumn], initTbls: Seq[NCSqlTable]): Seq[NCSqlColumn] =
extraCols.sortBy(col ⇒
// Rank 1: the column was detected initially.
if (initCols.contains(col))
// Rank 2: the column belongs to an initially detected table.
else if (initTbls.contains(schemaTblsByNames(col.getTable)))
* Tries to find date sql column.
* @param initTbls Initially detected tables.
// Picks the default date column of the best-ranked schema table. Tables are ranked by
// the (weight1, weight2) pair below; a warning is logged when no table defines a
// default date column.
private def findDateColumn(initTbls: Seq[NCSqlTable]): Option[NCSqlColumn] =
schemaTbls.sortBy(t ⇒ {
// The simple algorithm, which tries to find most suitable date type column for free date condition.
// Higher priority for tables which were detected initially.
val weight1 = if (initTbls.contains(t)) 0 else 1
// Higher priority for tables which don't have references from another.
// Hopefully, these tables would have more specific records.
val weight2 = if (!schemaJoins.exists(_.getToTable == t.getTable)) 0 else 1
// Tuples sort lexicographically: 'weight1' dominates, 'weight2' breaks ties.
(weight1, weight2)
// Only the first available default date column in ranked order is used.
}).flatMap(_.getDefaultDate.asScala).toStream.headOption match {
case Some(col)Some(col)
case None
logger.warn("Free date condition ignored without throwing exceptions.")
* Extends conditions list by the given free date range (if defined) and converts conditions to
* SQL text and parameters list.
* @param conds Conditions.
* @param initTbls Initially detected tables.
* @param freeDateColOpt Free date column. Optional.
// Returns a pair of (SQL condition fragments, parameter values). When a free date column
// is given, two range conditions (">=" from, "<=" to) and their parameters are added;
// AssertionError is thrown if the date range itself is missing in that case.
private def extendConditions(
conds: Seq[SqlCondition],
initTbls: Seq[NCSqlTable],
freeDateColOpt: Option[NCSqlColumn]
): (Seq[String], Seq[Any]) = {
// Extra conditions and parameters contributed by the free date range.
val (freeDateConds, freeDateParams): (Seq[SqlCondition], Seq[Object]) =
freeDateColOpt match {
case Some(col)
// A free date column without a configured range is a programming error.
val range = freeDateRangeOpt.getOrElse(throw new AssertionError("Missed date range"))
import org.apache.nlpcraft.examples.sql.db.{SqlSimpleCondition ⇒ C}
(Seq(C(col, ">=", range.getFrom), C(col, "<=", range.getTo)), Seq(range.getFrom, range.getTo))
case None(Seq.empty, Seq.empty)
( ++,
// Parameter values follow the order of the conditions; free date parameters come last.
conds.flatMap(p ⇒
p match {
case x: SqlSimpleConditionSeq(x.value)
case x: SqlInCondition ⇒ x.values
case _ ⇒ throw new AssertionError(s"Unexpected condition: $p")
) ++ freeDateParams
* Sets given argument and returns builder's instance.
* @param tables Tables.
/**
 * Appends the given tables to the builder's table list.
 *
 * @param tables Tables to add.
 * @return This builder, to allow call chaining.
 */
def withTables(tables: NCSqlTable*): SqlBuilder = {
    tbls = tbls ++ tables

    this
}
* Sets given argument and returns builder's instance.
* @param cols Columns.
/**
 * Appends the given columns to the builder's column list.
 *
 * @param cols Columns to add.
 * @return This builder, to allow call chaining.
 */
def withColumns(cols: NCSqlColumn*): SqlBuilder = {
    // 'this.' is required: the parameter shadows the field of the same name.
    this.cols = this.cols ++ cols

    this
}
* Sets given argument and returns builder's instance.
* @param conds Conditions.
/**
 * Appends the given conditions to the builder's condition list.
 *
 * @param conds Conditions to add.
 * @return This builder, to allow call chaining.
 */
def withAndConditions(conds: SqlCondition*): SqlBuilder = {
    // 'this.' is required: the parameter shadows the field of the same name.
    this.conds = this.conds ++ conds

    this
}
* Sets given argument and returns builder's instance.
* @param freeDateRange Free date.
/**
 * Stores the free date range, replacing any previously set one.
 * A `null` argument clears the range.
 *
 * @param freeDateRange Free date range.
 * @return This builder, to allow call chaining.
 */
def withFreeDateRange(freeDateRange: NCSqlDateRange): SqlBuilder = {
    freeDateRangeOpt = Option(freeDateRange)

    this
}
* Sets given argument and returns builder's instance.
* @param sorts Sort elements.
/**
 * Appends the given sort elements to the builder's sort list.
 *
 * @param sorts Sort elements to add.
 * @return This builder, to allow call chaining.
 */
def withSorts(sorts: NCSqlSort*): SqlBuilder = {
    // 'this.' is required: the parameter shadows the field of the same name.
    this.sorts = this.sorts ++ sorts

    this
}
* Sets given argument and returns builder's instance.
* @param limit Limit element.
/**
 * Stores the limit element, replacing any previously set one.
 * A `null` argument clears the limit.
 *
 * @param limit Limit element.
 * @return This builder, to allow call chaining.
 */
def withLimit(limit: NCSqlLimit): SqlBuilder = {
    // 'this.' is required: the parameter shadows the field of the same name.
    this.limit = Option(limit)

    this
}
* Main build method. Builds and returns newly constructed SQL query object.
// Builds the SQL query from the builder state: normalizes tables, columns, conditions and
// sorts, resolves the free date column, then assembles the SQL text and its parameters.
// Throws RuntimeException when a free date range is combined with a DATE column condition.
def build(): SqlQuery = {
// Collects data.
// A free date range cannot be combined with an explicit condition on a DATE column.
if (freeDateRangeOpt.isDefined &&
this.conds.exists(cond ⇒ isDate(schemaCols(Key(cond.column.getTable, cond.column.getColumn))))
throw new RuntimeException("Too many date conditions.")
// Mutable working copies of the configured elements.
var tblsNorm = mutable.ArrayBuffer.empty[NCSqlTable] ++ this.tbls
var colsNorm = mutable.ArrayBuffer.empty[NCSqlColumn] ++ this.cols
var condsNorm = mutable.ArrayBuffer.empty[SqlCondition] ++ this.conds
var sortsNorm = mutable.ArrayBuffer.empty[NCSqlSort] ++ this.sorts
// Tables referenced by columns, conditions and sorts are added to the table list;
// condition and sort columns are also added to the column list.
colsNorm.foreach(col ⇒ tblsNorm += schemaTblsByNames(col.getTable))
condsNorm.foreach(cond ⇒ { tblsNorm += schemaTblsByNames(cond.column.getTable); colsNorm += cond.column })
sortsNorm.foreach(sort ⇒ { tblsNorm += schemaTblsByNames(sort.getColumn.getTable); colsNorm += sort.getColumn })
// With a free date range, the first default date column among the collected tables is tried first.
var freeDateColOpt =
if (freeDateRangeOpt.isDefined) tblsNorm.flatMap(_.getDefaultDate.asScala).toStream.headOption else None
tblsNorm = tblsNorm.distinct
val extTbls = extendTables(tblsNorm)
// No default date column among the collected tables: fall back to a schema-wide search.
if (freeDateColOpt.isEmpty && freeDateRangeOpt.isDefined) {
freeDateColOpt = findDateColumn(tblsNorm)
freeDateColOpt.toSeq.foreach(col ⇒ { tblsNorm += schemaTblsByNames(col.getTable); colsNorm += col })
// Duplicates accumulated by the steps above are dropped.
tblsNorm = tblsNorm.distinct
colsNorm = colsNorm.distinct
condsNorm = condsNorm.distinct
sortsNorm = sortsNorm.distinct
val extCols = extendColumns(colsNorm, extTbls, freeDateColOpt)
val (extConds, extParams) = extendConditions(condsNorm, tblsNorm, freeDateColOpt)
// Select-list columns in priority order, already rendered as SQL text.
val extSortCols = sort(extCols, colsNorm, tblsNorm).map(sql)
// DISTINCT is emitted only when every selected column is VARCHAR.
val distinct = extCols.forall(col ⇒ isString(schemaCols(Key(col.getTable, col.getColumn))))
val extSorts = extendSort(sortsNorm, tblsNorm, extCols)
// SQL text: after 'stripMargin' it is split on spaces, trimmed, emptied of blank
// pieces and re-joined with single spaces.
sql =
| ${if (distinct) "DISTINCT" else ""}
| ${extSortCols.mkString(", ")}
| FROM ${sql(extTbls)}
| ${if (extConds.isEmpty) "" else s"WHERE ${extConds.mkString(" AND ")}"}
| ${if (extSorts.isEmpty) "" else s"ORDER BY ${", ")}"}
| LIMIT ${limit.flatMap(p ⇒ Some(p.getLimit)).getOrElse(DFLT_LIMIT)}
|""".stripMargin.split(" ").map(_.trim).filter(_.nonEmpty).mkString(" "),
parameters = extParams