/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnetexamples.imclassification

import ml.dmlc.mxnet.Callback.Speedometer
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.optimizer.SGD
import org.slf4j.LoggerFactory
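/**
 * Shared training driver for the image-classification examples: builds a
 * FeedForward model for the given network symbol, wires up the kvstore,
 * optimizer and callbacks, and runs the training loop.
 */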
object ModelTrain {
  private val logger = LoggerFactory.getLogger(classOf[ModelTrain])

  // scalastyle:off parameterNum
  def fit(dataDir: String, batchSize: Int, numExamples: Int, devs: Array[Context],
          network: Symbol, dataLoader: (String, Int, KVStore) => (DataIter, DataIter),
          kvStore: String, numEpochs: Int, modelPrefix: String = null, loadEpoch: Int = -1,
          lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
          clipGradient: Float = 0f, monitorSize: Int = -1): Unit = {
    // kvstore
    var kv = KVStore.create(kvStore)

    // load model
    val modelPrefixWithRank =
      if (modelPrefix == null) null
      else modelPrefix + s"-${kv.rank}"

    val (argParams, auxParams, beginEpoch) =
      if (loadEpoch >= 0) {
        require(modelPrefixWithRank != null)
        val tmp = FeedForward.load(modelPrefix, loadEpoch)
        (tmp.getArgParams, tmp.getAuxParams, loadEpoch)
      } else {
        (null, null, 0)
      }
    // save model
    val checkpoint: EpochEndCallback =
      if (modelPrefix == null) null
      else new EpochEndCallback {
        override def invoke(epoch: Int, symbol: Symbol,
                            argParams: Map[String, NDArray],
                            auxStates: Map[String, NDArray]): Unit = {
          // save the aux states handed to the callback; the auxParams loaded above
          // may be null when training starts from scratch
          Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams, auxStates)
        }
      }
    // data
    val (train, validation) = dataLoader(dataDir, batchSize, kv)

    // train
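    // under "dist_sync" every worker processes roughly 1/numWorkers of the examples,
    // so the number of batches per epoch is scaled down accordingly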
    val epochSize =
      if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers
      else numExamples / batchSize
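    // anneal the learning rate by lrFactor roughly every lrFactorEpoch epochs;
    // lrFactor = 1 (the default) keeps the learning rate constant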
    val lrScheduler =
      if (lrFactor < 1f) {
        new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt, 1),
          factor = lrFactor)
      } else {
        null
      }
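    // SGD with momentum and a small weight decay; clipGradient controls optional gradient clipping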
    val optimizer: Optimizer = new SGD(learningRate = lr,
      lrScheduler = lrScheduler, clipGradient = clipGradient,
      momentum = 0.9f, wd = 0.00001f)
    // disable kvstore for single device
    if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType != "gpu")) {
      kv.dispose()
      kv = null
    }
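    // build the FeedForward model: Xavier-initialized weights, training resumes
    // from beginEpoch when a checkpoint was loaded above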
    val model = new FeedForward(ctx = devs,
      symbol = network,
      numEpoch = numEpochs,
      optimizer = optimizer,
      initializer = new Xavier(factorType = "in", magnitude = 2.34f),
      argParams = argParams,
      auxParams = auxParams,
      beginEpoch = beginEpoch,
      epochSize = epochSize)
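    // optionally install a Monitor that inspects intermediate results every monitorSize batches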
    if (monitorSize > 0) {
      model.setMonitor(new Monitor(monitorSize))
    }
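    // run training: report accuracy on the validation set, log throughput every
    // 50 batches and checkpoint at the end of each epoch (when modelPrefix is set)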
    model.fit(trainData = train,
      evalData = validation,
      evalMetric = new Accuracy(),
      kvStore = kv,
      batchEndCallback = new Speedometer(batchSize, 50),
      epochEndCallback = checkpoint)
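    // release the native resources held by the kvstore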
    if (kv != null) {
      kv.dispose()
    }
  }
  // scalastyle:on parameterNum
}
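// empty companion class so that classOf[ModelTrain] used by the logger above resolves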
class ModelTrain