| import mxnet as mx | |
| import json | |
| def check_metric(metric, *args, **kwargs): | |
| metric = mx.metric.create(metric, *args, **kwargs) | |
| str_metric = json.dumps(metric.get_config()) | |
| metric2 = mx.metric.create(str_metric) | |
| assert metric.get_config() == metric2.get_config() | |
| def test_metrics(): | |
| check_metric('acc', axis=0) | |
| check_metric('f1') | |
| check_metric('perplexity', -1) | |
| composite = mx.metric.create(['acc', 'f1']) | |
| check_metric(composite) | |
| if __name__ == '__main__': | |
| import nose | |
| nose.runmodule() |