# blob: 98740b05ee3278410282d0e9bfb86236ef5b161a [file] [log] [blame]
import mxnet as mx
import json
def check_metric(metric, *args, **kwargs):
    """Verify that a metric's config survives a JSON round trip.

    The metric is built with ``mx.metric.create``, its config is dumped to a
    JSON string, and a second metric is created from that string. The two
    metrics must then report equal configs.

    Args:
        metric: Whatever ``mx.metric.create`` accepts as its first argument
            (callers in this file pass a name string or a metric object).
        *args: Extra positional arguments forwarded to ``mx.metric.create``.
        **kwargs: Extra keyword arguments forwarded to ``mx.metric.create``.

    Raises:
        AssertionError: If the rebuilt metric's config differs from the
            original's.
    """
    original = mx.metric.create(metric, *args, **kwargs)

    # Serialize the config, then rebuild a metric from the JSON text alone.
    serialized_config = json.dumps(original.get_config())
    rebuilt = mx.metric.create(serialized_config)

    assert original.get_config() == rebuilt.get_config()
def test_metrics():
    """Run the config round-trip check over several metrics.

    Covers three metrics created by name (with a keyword argument, with no
    extra arguments, and with one positional argument) and then a metric
    created from a list of two names.
    """
    # (name, positional args, keyword args) for each name-based case.
    named_cases = (
        ('acc', (), {'axis': 0}),
        ('f1', (), {}),
        ('perplexity', (-1,), {}),
    )
    for name, positional, keyword in named_cases:
        check_metric(name, *positional, **keyword)

    # A metric built from a list of names is passed in as an object.
    check_metric(mx.metric.create(['acc', 'f1']))
if __name__ == '__main__':
    # Script entry point: import nose only when executed directly, then hand
    # this module over to its runner.
    import nose

    nose.runmodule()