blob: f1bf13d93d3765e8120e9dc4e739575db19aedc6 [file] [log] [blame]
#!/usr/bin/env python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import sys
sys.path.insert(0, "../../python/")
import mxnet as mx
kv = mx.kv.create('dist_async')
my_rank = kv.rank
nworker = kv.num_workers
def test_gluon_trainer_type():
def check_trainer_kv_update(weight_stype, update_on_kv):
x = mx.gluon.Parameter('x', shape=(10,1), lr_mult=1.0, stype=weight_stype)
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
try:
trainer = mx.gluon.Trainer([x], 'sgd', {'learning_rate': 0.1},
kvstore=kv, update_on_kvstore=update_on_kv)
trainer._init_kvstore()
assert trainer._kv_initialized
assert trainer._update_on_kvstore is True
except ValueError:
assert update_on_kv is False
check_trainer_kv_update('default', False)
check_trainer_kv_update('default', True)
check_trainer_kv_update('default', None)
check_trainer_kv_update('row_sparse', False)
check_trainer_kv_update('row_sparse', True)
check_trainer_kv_update('row_sparse', None)
print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type')
if __name__ == "__main__":
test_gluon_trainer_type()