blob: c17a564be3f9a325a71f0cb52fad87ef095be67c [file] [log] [blame]
/*
* Copyright 2019 WeBank
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.webank.wedatasphere.linkis.orchestrator.execution.impl
import com.webank.wedatasphere.linkis.orchestrator.Orchestrator
import com.webank.wedatasphere.linkis.orchestrator.conf.OrchestratorConfiguration
import com.webank.wedatasphere.linkis.orchestrator.execution.ExecTaskRunner
import com.webank.wedatasphere.linkis.orchestrator.plugin.UserParallelOrchestratorPlugin
import scala.collection.mutable
/**
 * A queue of [[ExecTaskRunner]]s, each wrapped together with per-user running
 * information as a [[UserTaskRunner]], from which callers take runners back out.
 */
trait UserTaskRunnerQueue {

  /**
   * Takes runners out of the queue.
   *
   * @param max upper bound on how many runners to return
   *            (NOTE(review): exact semantics are implementation-defined — confirm)
   */
  def takeTaskRunner(max: Int): Array[UserTaskRunner]

  /** Takes every runner out of the queue. */
  def takeTaskRunnerAll(): Array[UserTaskRunner]

  /**
   * Enqueues a single runner.
   *
   * @param userRunningNumber per-user running-count information consulted when enqueuing
   */
  def add(taskRunner: ExecTaskRunner, userRunningNumber: UserRunningNumber): Unit

  /**
   * Enqueues a batch of runners.
   *
   * @param userRunningNumber per-user running-count information consulted when enqueuing
   */
  def addAll(taskRunners: Array[ExecTaskRunner], userRunningNumber: UserRunningNumber): Unit
}
/**
 * [[UserTaskRunnerQueue]] backed by a `mutable.PriorityQueue` ordered by
 * `UserTaskRunner.getScore()`. The priority queue dequeues its greatest element
 * first, so runners are taken in descending score order.
 *
 * The per-user limit on running tasks comes from the first registered
 * [[UserParallelOrchestratorPlugin]], if one exists. Otherwise it falls back to
 * `OrchestratorConfiguration.ORCHESTRATOR_USER_MAX_RUNNING`.
 *
 * No synchronization is performed here; the underlying queue is mutable.
 */
class UserTaskRunnerPriorityQueue extends UserTaskRunnerQueue {

  // Higher score => dequeued earlier (mutable.PriorityQueue is a max-heap w.r.t. `ord`).
  private implicit val ord: Ordering[UserTaskRunner] = Ordering.by(_.getScore())

  private val priorityQueue = new mutable.PriorityQueue[UserTaskRunner]()

  private val DEFAULT_MAX_RUNNING = OrchestratorConfiguration.ORCHESTRATOR_USER_MAX_RUNNING.getValue

  // Looked up once at construction time; None when no such plugin is registered.
  private val userParallelOrchestratorPlugin =
    Orchestrator.getOrchestrator.getOrchestratorContext.getOrchestratorPlugins
      .collectFirst { case plugin: UserParallelOrchestratorPlugin => plugin }

  /**
   * Takes up to `max` runners in descending score order.
   *
   * Runners beyond `max` are left in the queue instead of being dequeued and
   * discarded. A `max` of zero or less yields an empty array.
   */
  override def takeTaskRunner(max: Int): Array[UserTaskRunner] = {
    // Clamp to [0, size] so negative or oversized `max` values are safe.
    val count = math.max(0, math.min(max, priorityQueue.size))
    Array.fill(count)(priorityQueue.dequeue())
  }

  /** Takes every queued runner, in descending score order, leaving the queue empty. */
  override def takeTaskRunnerAll(): Array[UserTaskRunner] = takeTaskRunner(priorityQueue.size)

  /**
   * Enqueues a single runner using the user's current running count.
   *
   * NOTE(review): unlike `addAll`, this uses `getRunningNumber` (not `addNumber`)
   * and does not drop runners whose user is already at the limit. Confirm that
   * the difference is intended.
   */
  override def add(taskRunner: ExecTaskRunner, userRunningNumber: UserRunningNumber): Unit = {
    val (user, maxRunningNumber) = resolveUserLimit(taskRunner)
    val runningNumber = userRunningNumber.getRunningNumber(user)
    priorityQueue += UserTaskRunner(user, maxRunningNumber, runningNumber, taskRunner)
  }

  /**
   * Enqueues a batch of runners.
   *
   * Each runner's user count is obtained through `userRunningNumber.addNumber(user)`.
   * Only runners whose resulting running number is still below the user's max
   * running number are enqueued.
   */
  override def addAll(taskRunners: Array[ExecTaskRunner], userRunningNumber: UserRunningNumber): Unit = {
    // Map everything first (calling addNumber for each runner in order), then filter.
    val runners = taskRunners.map { taskRunner =>
      val (user, maxRunningNumber) = resolveUserLimit(taskRunner)
      UserTaskRunner(user, maxRunningNumber, userRunningNumber.addNumber(user), taskRunner)
    }.filter(runner => runner.maxRunningNumber > runner.runningNumber)
    priorityQueue ++= runners
  }

  /**
   * Returns the executing user of `taskRunner` paired with that user's max
   * running-task limit. The limit comes from the parallel plugin when present,
   * otherwise from the configured default.
   */
  private def resolveUserLimit(taskRunner: ExecTaskRunner) = {
    val astContext = taskRunner.task.getTaskDesc.getOrigin.getASTOrchestration.getASTContext
    val user = astContext.getExecuteUser
    val labels = astContext.getLabels
    val maxRunningNumber = userParallelOrchestratorPlugin
      .map(_.getUserMaxRunningJobs(user, labels))
      .getOrElse(DEFAULT_MAX_RUNNING)
    (user, maxRunningNumber)
  }
}