| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
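
"""Benchmark the `adam_update` operator with dense and row_sparse inputs.

Times `--repeat` in-place adam updates of a (dim_in, dim_out) weight,
optionally with a row_sparse gradient and/or row_sparse optimizer states.
"""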
| |
import argparse
import time

import mxnet as mx
from mxnet.ndarray.sparse import adam_update
import numpy as np
| |
# fix both RNG seeds so the sampled gradient rows are reproducible
mx.random.seed(0)
np.random.seed(0)
| |
| parser = argparse.ArgumentParser(description='Benchmark adam updater') |
| parser.add_argument('--dim-in', type=int, default=240000, help='weight.shape[0]') |
| parser.add_argument('--dim-out', type=int, default=512, help='weight.shape[1]') |
| parser.add_argument('--nnr', type=int, default=5000, help='grad.indices.shape[0]') |
parser.add_argument('--repeat', type=int, default=1000, help='number of timed update iterations')
parser.add_argument('--dense-grad', action='store_true',
                    help='if set, both gradient and weight are dense')
parser.add_argument('--dense-state', action='store_true',
                    help='if set, optimizer states are dense, i.e. the standard (non-lazy) update')
parser.add_argument('--cpu', action='store_true',
                    help='if set, run on CPU instead of GPU')
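# With the defaults above, the weight is 240000 x 512 float32 (~470 MB) and a
# sparse gradient has 5000 non-zero rows, i.e. about 2% row density.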
| |
args = parser.parse_args()
| dim_in = args.dim_in |
| dim_out = args.dim_out |
| nnr = args.nnr |
| ctx = mx.cpu() if args.cpu else mx.gpu() |
| |
# dense all-ones template from which weight, grad and the optimizer states are built
ones = mx.nd.ones((dim_in, dim_out), ctx=ctx)
| |
if not args.dense_grad:
    weight = ones.tostype('row_sparse')
    # take the first nnr ids of a shuffled permutation; np.unique sorts
    # them, as row_sparse storage requires sorted indices
    indices = np.arange(dim_in)
    np.random.shuffle(indices)
    indices = np.unique(indices[:nnr])
    indices = mx.nd.array(indices, ctx=ctx)
    # keep only the sampled rows of weight, yielding a row_sparse gradient
    grad = mx.nd.sparse.retain(weight, indices)
| else: |
| weight = ones.copy() |
| grad = ones.copy() |
| |
# Optimizer states: row_sparse states select the lazy update, which only
# touches rows present in grad.indices; dense states force the standard
# update over all rows.
if args.dense_state:
    mean = ones.copy()
else:
    mean = ones.tostype('row_sparse')

var = mean.copy()  # same storage type as mean
| |
# warmup: let lazy allocation and kernel setup happen outside the timed loop
for i in range(10):
    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
| |
# measure speed: MXNet executes asynchronously, so block on the result
# before reading the clock
a = time.time()
for i in range(args.repeat):
    adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
                beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
b = time.time()
print(b - a)  # total seconds for args.repeat updates
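
# Example invocations (script filename assumed):
#   python updater.py                                     # row_sparse grad and states on GPU
#   python updater.py --dense-state                       # row_sparse grad, dense states (standard update)
#   python updater.py --cpu --dense-grad --dense-state    # fully dense update on CPU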