import mxnet as mx | |
import json | |
def check_metric(metric, *args, **kwargs):
    """Assert that a metric's config survives a JSON round-trip.

    ``metric`` plus any extra positional/keyword arguments are forwarded to
    ``mx.metric.create``. The created metric's ``get_config()`` result is
    serialized with ``json.dumps``, and that JSON string is fed back to
    ``mx.metric.create``. The configs of the original and the rebuilt metric
    must compare equal.
    """
    source = mx.metric.create(metric, *args, **kwargs)
    # Rebuild the metric purely from its serialized config.
    restored = mx.metric.create(json.dumps(source.get_config()))
    assert source.get_config() == restored.get_config()
def test_metrics():
    """Check the JSON config round-trip for several metrics, including a composite."""
    # (positional args, keyword args) forwarded to check_metric, in order.
    cases = [
        (('acc',), {'axis': 0}),
        (('f1',), {}),
        # -1 is forwarded positionally to the perplexity metric's constructor
        # (presumably its ignore_label -- confirm against mx.metric.Perplexity).
        (('perplexity', -1), {}),
    ]
    for case_args, case_kwargs in cases:
        check_metric(*case_args, **case_kwargs)

    # A metric object built from a list of metric names can also be passed in.
    check_metric(mx.metric.create(['acc', 'f1']))
if __name__ == '__main__':
    # Run this module's tests with nose's module runner when executed directly.
    from nose import runmodule
    runmodule()