blob: a48c934a7a61d55abab4b6b1e386c6dcb68f822b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nlpcraft
import org.apache.nlpcraft.model.tools.embedded.NCEmbeddedProbe
import org.apache.nlpcraft.model.tools.test.{NCTestClient, NCTestClientBuilder}
import org.apache.nlpcraft.probe.mgrs.model.NCModelManager
import org.junit.jupiter.api.TestInstance.Lifecycle
import org.junit.jupiter.api._
/**
  * Base class for JUnit tests that need an embedded probe and, optionally, a test client.
  *
  * The environment is described by the `NCTestEnvironment` annotation, which is looked up on
  * the test class (before all tests) and on each test method (before each test). When the
  * annotation is found the embedded probe is started with the annotated model and, if
  * `startClient()` is set, a test client is built and opened. The matching "after" callback
  * closes the client and stops the probe again.
  *
  * The annotation may be placed on the class or on a method, but not both: a second start
  * attempt while a probe or client is still active fails with `IllegalStateException`.
  *
  * The class uses the `PER_CLASS` test instance lifecycle, so the `@BeforeAll`/`@AfterAll`
  * callbacks are instance methods sharing state with the per-test callbacks.
  *
  * NOTE(review): the JUnit Jupiter user guide says lifecycle methods must not be private, while
  * the callbacks below are declared `private` - confirm they are picked up by the JUnit version in use.
  */
@TestInstance(Lifecycle.PER_CLASS)
abstract class NCTestContext {
    // Annotation type that describes the test environment.
    private final val MDL_CLASS = classOf[NCTestEnvironment]

    // Currently opened test client, `null` when no client is active.
    private var cli: NCTestClient = _

    // Set once `NCEmbeddedProbe.start(...)` reported success and until the probe is stopped.
    private var probeStarted = false

    /** Starts the environment if the current test method carries the annotation. */
    @BeforeEach
    @throws[Exception]
    private def beforeEach(info: TestInfo): Unit = start0(() => getMethodAnnotation(info))

    /** Starts the environment if the test class carries the annotation. */
    @BeforeAll
    @throws[Exception]
    private def beforeAll(info: TestInfo): Unit = start0(() => getClassAnnotation(info))

    /** Stops the environment if the current test method carries the annotation. */
    @AfterEach
    @throws[Exception]
    private def afterEach(info: TestInfo): Unit =
        if (getMethodAnnotation(info).isDefined)
            stop0()

    /** Stops the environment if the test class carries the annotation. */
    @AfterAll
    @throws[Exception]
    private def afterAll(info: TestInfo): Unit =
        if (getClassAnnotation(info).isDefined)
            stop0()

    /**
      * Gets the environment annotation of the test class.
      *
      * @param info JUnit test information.
      * @return Annotation, or `None` if there is no test class or it is not annotated.
      */
    private def getClassAnnotation(info: TestInfo): Option[NCTestEnvironment] =
        if (info.getTestClass.isPresent) Option(info.getTestClass.get().getAnnotation(MDL_CLASS)) else None

    /**
      * Gets the environment annotation of the test method.
      *
      * @param info JUnit test information.
      * @return Annotation, or `None` if there is no test method or it is not annotated.
      */
    private def getMethodAnnotation(info: TestInfo): Option[NCTestEnvironment] =
        if (info.getTestMethod.isPresent) Option(info.getTestMethod.get().getAnnotation(MDL_CLASS)) else None

    /**
      * Starts the embedded probe and, if requested by the annotation, the test client.
      * Does nothing when `extract` yields no annotation.
      *
      * @param extract Supplier of the environment annotation for the current scope.
      * @throws IllegalStateException If a probe or client is already active, or if the probe
      *     started without any model to open the client for.
      */
    @throws[Exception]
    private def start0(extract: () => Option[NCTestEnvironment]): Unit =
        extract() match {
            case Some(ann) =>
                if (probeStarted || cli != null)
                    throw new IllegalStateException(
                        "Model already initialized. " +
                        s"Note that '@${MDL_CLASS.getSimpleName}' can be set for class or method, " +
                        "but not both of them."
                    )

                preProbeStart()

                // `probeStarted` is known to be `false` here (guarded above), so it is only
                // flipped once the probe reports a successful start.
                if (NCEmbeddedProbe.start(null, ann.model().getName)) {
                    probeStarted = true

                    if (ann.startClient()) {
                        // Resolve the model ID before building the client so that a failure
                        // here does not leave a built-but-unopened client behind.
                        val mdlId = NCModelManager.getAllModels().headOption match {
                            case Some(mdl) => mdl.model.getId
                            case None =>
                                throw new IllegalStateException(
                                    "Embedded probe started but no models are available."
                                )
                        }

                        cli = new NCTestClientBuilder().newBuilder.build
                        cli.open(mdlId)
                    }
                }

            case None => // No-op.
        }

    /**
      * Closes the test client (if any) and stops the embedded probe (if started).
      * The probe is stopped even when closing the client fails.
      */
    @throws[Exception]
    private def stop0(): Unit =
        try
            if (cli != null)
                cli.close()
        finally {
            // Always drop the client reference and stop the probe, otherwise a failed
            // `close()` would leave the probe running and block every later start.
            cli = null

            if (probeStarted) {
                NCEmbeddedProbe.stop()

                probeStarted = false

                afterProbeStop()
            }
        }

    /** Hook invoked right before the embedded probe is started. Does nothing by default. */
    protected def preProbeStart(): Unit = { }

    /** Hook invoked right after the embedded probe has been stopped. Does nothing by default. */
    protected def afterProbeStop(): Unit = { }

    /**
      * Gets the test client opened for the current environment.
      *
      * @return Active test client.
      * @throws IllegalStateException If no client is currently started.
      */
    final protected def getClient: NCTestClient = {
        if (cli == null)
            throw new IllegalStateException("Client is not started.")

        cli
    }
}