// blob: 329e0424f380f8d5ef02e89cc5cb8c9f967f7eb1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nlpcraft.examples.sql
import java.util
import com.github.difflib.text.DiffRowGenerator
import com.github.vertical_blank.sqlformatter.SqlFormatter
import com.google.gson.Gson
import com.google.gson.reflect.TypeToken
import com.jakewharton.fliptables.FlipTable
import org.apache.nlpcraft.examples.sql.db.SqlServer
import org.apache.nlpcraft.model.tools.test.{NCTestClient, NCTestClientBuilder}
import org.apache.nlpcraft.probe.embedded.NCEmbeddedProbe
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
import scala.collection.JavaConverters._
import scala.compat.java8.OptionConverters._
/**
* SQL model test.
*
* @see SqlModel
*/
class SqlTest {
// JSON parser used to decode the model's result payload.
private val GSON = new Gson
// Target type for decoding a result payload into a generic string-keyed map.
private val TYPE_RESP = new TypeToken[util.Map[String, Object]]() {}.getType
// Control whitespace that `normalize` replaces with plain spaces.
private val NORM = Seq("\n", "\r", "\t")
// Word-level inline diff generator. The tag functions ignore their argument and always
// return the same marker: "~" for the old (expected) side and "**" for the new (real) side.
private val DIFF = DiffRowGenerator.create.
    showInlineDiffs(true).
    inlineDiffByWord(true).
    oldTag((_: java.lang.Boolean) => "~").
    newTag((_: java.lang.Boolean) => "**").
    build
// Test client; assigned in `setUp` and closed in `tearDown`, hence a `var`.
private var client: NCTestClient = _
/**
 * Single test case.
 *
 * @param texts Input sentences; each one is expected to produce the same SQL.
 * @param sql Expected SQL (compared after whitespace normalization).
 */
case class Case(texts: Seq[String], sql: String)
/**
 * Prepares the environment before each test: starts the SQL server, starts the embedded
 * probe with [[SqlModel]] and opens a test client for the model `sql.model.id`.
 */
@BeforeEach
def setUp(): Unit = {
    SqlServer.start()
    NCEmbeddedProbe.start(classOf[SqlModel])

    val builder = new NCTestClientBuilder().newBuilder.setResponseLog(false)

    // The field is assigned before `open` so that `tearDown` can still close the client
    // if opening fails.
    client = builder.build
    client.open("sql.model.id")
}
/**
 * Releases everything acquired in `setUp` after each test: closes the test client (if one
 * was created), then stops the embedded probe and the SQL server.
 *
 * Each step runs in a `finally` of the previous one, so a failure while closing the client
 * or stopping the probe does not leave the remaining components running for the next test.
 * If several steps throw, the exception of the last failing step is the one propagated.
 */
@AfterEach
def tearDown(): Unit =
    try {
        // `client` stays null when `setUp` failed before the client was built.
        if (client != null)
            client.close()
    }
    finally {
        try
            NCEmbeddedProbe.stop()
        finally
            SqlServer.stop()
    }
/**
 * Brings an SQL string to a canonical single-line form so that two queries can be compared
 * regardless of their layout.
 *
 * @param s SQL text.
 * @return Trimmed, non-empty tokens of `s` joined with single spaces.
 */
private def normalize(s: String): String = {
    // Replace every character sequence listed in NORM with a plain space.
    val spaced = NORM.foldLeft(s)((acc, ws) => acc.replace(ws, " "))

    // Runs of spaces yield empty tokens after splitting; dropping them collapses the runs.
    val tokens = spaced.split(" ").iterator.map(_.trim).filter(_.nonEmpty)

    tokens.mkString(" ")
}
/**
 * Formats SQL with `SqlFormatter` and splits the result into lines.
 *
 * @param s SQL text.
 * @return Lines of the formatted SQL as a Java list (the form the diff generator accepts).
 */
private def toPretty(s: String): util.List[String] = {
    val lines: Seq[String] = SqlFormatter.format(s).split("\n").toSeq

    lines.asJava
}
/**
 * Asks the model every text of every case, in order, and compares the returned SQL with
 * the SQL expected by the case. Both sides are passed through `normalize` first.
 *
 * Failures do not stop the run: they are collected, printed to `System.err` and then
 * reported together with one exception. Prints "Passed" when nothing failed.
 *
 * @param multiLineOut If `true`, an SQL mismatch is rendered as a two-column table of
 *      pretty-printed SQL; otherwise as two lines (expected, then real).
 * @param cases Test cases.
 * @throws Exception If at least one text failed. The message carries the number of passed
 *      and failed texts.
 */
private def check(multiLineOut: Boolean, cases: Case*): Unit = {
    // One entry per asked text, so the totals below count texts, not cases.
    val asks = cases.flatMap(c => c.texts.map(t => t -> normalize(c.sql)))

    // Kept as an ordered sequence (not a map) so that a repeated text cannot hide a failure.
    val errs = asks.flatMap { case (txt, expSqlNorm) =>
        askAndCompare(txt, expSqlNorm, multiLineOut).map(err => txt -> err)
    }

    if (errs.nonEmpty) {
        errs.foreach { case (txt, err) => System.err.println(s"Text: $txt\nError: $err\n") }

        throw new Exception(
            s"Test finished with errors [passed=${asks.size - errs.size}, failed=${errs.size}]"
        )
    }
    else
        println("Passed")
}

/**
 * Asks a single text and checks the SQL found in the result.
 *
 * @param txt Text to ask.
 * @param expSqlNorm Expected SQL, already normalized.
 * @param multiLineOut Mismatch rendering mode, see `check`.
 * @return Error description, or `None` if the returned SQL matches the expected one.
 */
private def askAndCompare(txt: String, expSqlNorm: String, multiLineOut: Boolean): Option[String] = {
    val res = client.ask(txt)

    if (res.isOk) {
        require(res.getResult.asScala.isDefined)

        val m: util.Map[String, Object] = GSON.fromJson(res.getResult.get, TYPE_RESP)

        // A successful response may still carry an "error" entry in its payload.
        Option(m.get("error")) match {
            case Some(err) => Some(err.toString)
            case None =>
                val resSqlNorm = normalize(m.asScala("sql").asInstanceOf[String])

                if (resSqlNorm != expSqlNorm)
                    Some(describeMismatch(expSqlNorm, resSqlNorm, multiLineOut))
                else
                    None
        }
    }
    else {
        require(res.getResultError.isPresent)

        Some(res.getResultError.get)
    }
}

/**
 * Renders the difference between expected and real SQL.
 *
 * @param expSqlNorm Expected SQL, normalized.
 * @param resSqlNorm Real SQL, normalized.
 * @param multiLineOut `true` for a table of pretty-printed SQL, `false` for two plain lines.
 * @return Human-readable description with inline diff markers.
 */
private def describeMismatch(expSqlNorm: String, resSqlNorm: String, multiLineOut: Boolean): String =
    if (multiLineOut) {
        val rows = DIFF.generateDiffRows(toPretty(expSqlNorm), toPretty(resSqlNorm)).asScala
        val table =
            FlipTable.of(
                Array("Expected", "Real"),
                rows.map(p => Array(p.getOldLine, p.getNewLine)).toArray
            )

        s"Unexpected SQL:\n$table"
    }
    else {
        val rows = DIFF.generateDiffRows(Seq(expSqlNorm).asJava, Seq(resSqlNorm).asJava).asScala

        // Each side is a single line, so the diff is expected to be a single row.
        require(rows.size == 1)

        val row = rows.head

        s"Unexpected SQL (expected vs real)\n${row.getOldLine}\n${row.getNewLine}\n"
    }
/**
 * Conversation test: three consecutive sentences are asked and every one of them is
 * expected to produce the same SQL.
 */
@Test
def testConversation(): Unit = {
    // Second and third sentences are the qualifying questions for first.
    // Second by date, third by value.
    // See logic implemented in org.apache.nlpcraft.examples.sql.SqlModel.onMatchedIntent
    val dialog = Seq(
        "last year Exotic Liquids orders",
        "last month",
        "Norske Meierier"
    )

    val expectedSql =
        """SELECT
          | suppliers.company_name,
          | orders.order_date,
          | orders.order_id,
          | orders.required_date,
          | suppliers.supplier_id,
          | suppliers.contact_name,
          | customers.customer_id,
          | customers.company_name,
          | customers.contact_name,
          | employees.employee_id,
          | employees.last_name,
          | employees.first_name,
          | order_details.unit_price,
          | order_details.quantity,
          | order_details.discount,
          | products.product_id,
          | products.product_name,
          | products.quantity_per_unit,
          | shippers.shipper_id,
          | shippers.company_name,
          | shippers.phone
          |FROM
          | order_details
          | INNER JOIN orders ON order_details.order_id = orders.order_id
          | INNER JOIN products ON order_details.product_id = products.product_id
          | LEFT JOIN customers ON orders.customer_id = customers.customer_id
          | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
          | LEFT JOIN employees ON orders.employee_id = employees.employee_id
          | LEFT JOIN suppliers ON products.supplier_id = suppliers.supplier_id
          |WHERE
          | suppliers.company_name IN (?)
          | AND orders.order_date >= ?
          | AND orders.order_date <= ?
          |ORDER BY
          | orders.order_id DESC,
          | suppliers.supplier_id DESC
          |LIMIT
          | 1000
          |""".stripMargin

    check(true, Case(dialog, expectedSql))
}
/**
 * Checks the SQL generated for a set of sample questions: plain listings, date-range and
 * numeric filters, explicit row limits and explicit sort columns/directions.
 *
 * All texts of one case are expected to produce the SQL of that case. Mismatches are
 * reported in multi-line (table) form.
 */
@Test
def test(): Unit = {
    check(
        true,
        // Several phrasings of the same request.
        Case(
            Seq(
                "order date, please!",
                "show me the order dates",
                "list dates of orders"
            ),
            """SELECT
              | orders.order_date,
              | orders.order_id,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | orders
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |ORDER BY
              | orders.order_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        // Date period becomes a range condition on the order date.
        Case(
            Seq(
                "give me orders for the last month"
            ),
            """SELECT
              | orders.order_date,
              | orders.order_id,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | orders
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |WHERE
              | orders.order_date >= ?
              | AND orders.order_date <= ?
              |ORDER BY
              | orders.order_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        Case(
            Seq(
                "list shippers data"
            ),
            """SELECT
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | shippers
              |ORDER BY
              | shippers.shipper_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        // Numeric condition combined with a date period.
        Case(
            Seq(
                "show me orders with freight more than 10 for the last year"
            ),
            """SELECT
              | orders.freight,
              | orders.order_date,
              | orders.order_id,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | orders
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |WHERE
              | orders.freight > ?
              | AND orders.order_date >= ?
              | AND orders.order_date <= ?
              |ORDER BY
              | orders.order_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        Case(
            Seq(
                "show territories data"
            ),
            """SELECT
              | territories.territory_id,
              | territories.territory_description,
              | territories.region_id,
              | region.region_id,
              | region.region_description
              |FROM
              | territories
              | INNER JOIN region ON territories.region_id = region.region_id
              |ORDER BY
              | territories.territory_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        Case(
            Seq(
                "display employees territories"
            ),
            """SELECT
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | territories.territory_id,
              | territories.territory_description,
              | territories.region_id,
              | employee_territories.employee_id,
              | employee_territories.territory_id,
              | region.region_id,
              | region.region_description
              |FROM
              | employee_territories
              | INNER JOIN employees ON employee_territories.employee_id = employees.employee_id
              | INNER JOIN territories ON employee_territories.territory_id = territories.territory_id
              | INNER JOIN region ON territories.region_id = region.region_id
              |ORDER BY
              | employees.employee_id DESC,
              | territories.territory_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        // Explicit number in the text is expected as the LIMIT value.
        Case(
            Seq(
                "get me 10 suppliers"
            ),
            """SELECT
              | suppliers.supplier_id,
              | suppliers.company_name,
              | suppliers.contact_name
              |FROM
              | suppliers
              |ORDER BY
              | suppliers.supplier_id DESC
              |LIMIT
              | 10
              |""".stripMargin
        ),
        Case(
            Seq(
                "last year Exotic Liquids orders"
            ),
            """SELECT
              | suppliers.company_name,
              | orders.order_date,
              | orders.order_id,
              | orders.required_date,
              | suppliers.supplier_id,
              | suppliers.contact_name,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | order_details.unit_price,
              | order_details.quantity,
              | order_details.discount,
              | products.product_id,
              | products.product_name,
              | products.quantity_per_unit,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | order_details
              | INNER JOIN orders ON order_details.order_id = orders.order_id
              | INNER JOIN products ON order_details.product_id = products.product_id
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              | LEFT JOIN suppliers ON products.supplier_id = suppliers.supplier_id
              |WHERE
              | suppliers.company_name IN (?)
              | AND orders.order_date >= ?
              | AND orders.order_date <= ?
              |ORDER BY
              | orders.order_id DESC,
              | suppliers.supplier_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        // Explicit sort column; descending direction is expected.
        Case(
            Seq(
                "give me the orders sorted by the ship date"
            ),
            """SELECT
              | orders.shipped_date,
              | orders.order_id,
              | orders.order_date,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | orders
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |ORDER BY
              | orders.shipped_date DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        // Same sort column with an explicit ascending direction.
        Case(
            Seq(
                "give me the orders sorted by the ship date in ascending order"
            ),
            """SELECT
              | orders.shipped_date,
              | orders.order_id,
              | orders.order_date,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | orders
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |ORDER BY
              | orders.shipped_date ASC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        Case(
            // Default sort (nlpcraft:sort shouldn't be found)
            Seq(
                "give me the orders sorted by date"
            ),
            """SELECT
              | orders.order_date,
              | orders.order_id,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | orders
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |ORDER BY
              | orders.order_id DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        // "top" without a number: LIMIT 10 is expected.
        Case(
            Seq(
                "What are the top orders for the last 2 weeks sorted by order quantity?"
            ),
            """SELECT
              | order_details.quantity,
              | orders.order_date,
              | order_details.unit_price,
              | order_details.discount,
              | orders.order_id,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | products.product_id,
              | products.product_name,
              | products.quantity_per_unit,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | order_details
              | INNER JOIN orders ON order_details.order_id = orders.order_id
              | INNER JOIN products ON order_details.product_id = products.product_id
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |WHERE
              | orders.order_date >= ?
              | AND orders.order_date <= ?
              |ORDER BY
              | order_details.quantity DESC
              |LIMIT
              | 10
              |""".stripMargin
        ),
        // "top 25": LIMIT 25 is expected.
        Case(
            Seq(
                "What are the top 25 orders for the last 2 weeks sorted by order quantity?"
            ),
            """SELECT
              | order_details.quantity,
              | orders.order_date,
              | order_details.unit_price,
              | order_details.discount,
              | orders.order_id,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | products.product_id,
              | products.product_name,
              | products.quantity_per_unit,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | order_details
              | INNER JOIN orders ON order_details.order_id = orders.order_id
              | INNER JOIN products ON order_details.product_id = products.product_id
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              |WHERE
              | orders.order_date >= ?
              | AND orders.order_date <= ?
              |ORDER BY
              | order_details.quantity DESC
              |LIMIT
              | 25
              |""".stripMargin
        ),
        // "best performing": descending sort by freight is expected.
        Case(
            Seq(
                "What are the best performing products for the last quarter?"
            ),
            """SELECT
              | orders.freight,
              | orders.order_date,
              | orders.order_id,
              | orders.required_date,
              | products.product_id,
              | products.product_name,
              | products.quantity_per_unit,
              | categories.category_id,
              | categories.category_name,
              | categories.description,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | order_details.unit_price,
              | order_details.quantity,
              | order_details.discount,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone,
              | suppliers.supplier_id,
              | suppliers.company_name,
              | suppliers.contact_name
              |FROM
              | order_details
              | INNER JOIN orders ON order_details.order_id = orders.order_id
              | INNER JOIN products ON order_details.product_id = products.product_id
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              | LEFT JOIN suppliers ON products.supplier_id = suppliers.supplier_id
              | LEFT JOIN categories ON products.category_id = categories.category_id
              |WHERE
              | orders.order_date >= ?
              | AND orders.order_date <= ?
              |ORDER BY
              | orders.freight DESC
              |LIMIT
              | 1000
              |""".stripMargin
        ),
        // "least performing": ascending sort by freight is expected.
        Case(
            Seq(
                "What are the least performing categories for the last quarter?"
            ),
            """SELECT
              | orders.freight,
              | orders.order_date,
              | categories.category_id,
              | categories.category_name,
              | categories.description,
              | orders.order_id,
              | orders.required_date,
              | customers.customer_id,
              | customers.company_name,
              | customers.contact_name,
              | employees.employee_id,
              | employees.last_name,
              | employees.first_name,
              | order_details.unit_price,
              | order_details.quantity,
              | order_details.discount,
              | products.product_id,
              | products.product_name,
              | products.quantity_per_unit,
              | shippers.shipper_id,
              | shippers.company_name,
              | shippers.phone
              |FROM
              | order_details
              | INNER JOIN orders ON order_details.order_id = orders.order_id
              | INNER JOIN products ON order_details.product_id = products.product_id
              | LEFT JOIN customers ON orders.customer_id = customers.customer_id
              | LEFT JOIN shippers ON orders.ship_via = shippers.shipper_id
              | LEFT JOIN employees ON orders.employee_id = employees.employee_id
              | LEFT JOIN categories ON products.category_id = categories.category_id
              |WHERE
              | orders.order_date >= ?
              | AND orders.order_date <= ?
              |ORDER BY
              | orders.freight ASC
              |LIMIT
              | 1000
              |""".stripMargin
        )
    )
}
}