# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
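"""Benchmark the mxnet.ndarray.sparse.adam_update operator.

By default the script runs a lazy update on GPU with a row_sparse weight,
gradient and optimizer states; --dense-grad, --dense-state and --cpu switch
the corresponding pieces to dense storage or CPU execution.
"""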
import time
import mxnet as mx
from mxnet.ndarray.sparse import adam_update
import numpy as np
import argparse
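
# fix RNG seeds so the randomly sampled gradient rows are reproducible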
mx.random.seed(0)
np.random.seed(0)
parser = argparse.ArgumentParser(description='Benchmark adam updater')
parser.add_argument('--dim-in', type=int, default=240000, help='weight.shape[0]')
parser.add_argument('--dim-out', type=int, default=512, help='weight.shape[1]')
parser.add_argument('--nnr', type=int, default=5000, help='grad.indices.shape[0]')
parser.add_argument('--repeat', type=int, default=1000, help='num repeat')
parser.add_argument('--dense-grad', action='store_true',
                    help='if set, both gradient and weight are dense')
parser.add_argument('--dense-state', action='store_true',
                    help='if set, optimizer states are dense, i.e. the standard (non-lazy) update is used')
parser.add_argument('--cpu', action='store_true',
                    help='if set, run on CPU instead of the default GPU')
args = parser.parse_args()
dim_in = args.dim_in
dim_out = args.dim_out
nnr = args.nnr
ctx = mx.cpu() if args.cpu else mx.gpu()
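
# dense all-ones base array from which weight, gradient and states are built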
ones = mx.nd.ones((dim_in, dim_out), ctx=ctx)
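
# default: row_sparse weight and a gradient with nnr non-zero rows, built by
# retaining randomly chosen rows of the weight (np.unique also sorts the indices)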
if not args.dense_grad:
weight = ones.tostype('row_sparse')
indices = np.arange(dim_in)
np.random.shuffle(indices)
indices = np.unique(indices[:nnr])
indices = mx.nd.array(indices, ctx=ctx)
grad = mx.nd.sparse.retain(weight, indices)
else:
weight = ones.copy()
grad = ones.copy()
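
# optimizer states (mean, var): dense for the standard update,
# row_sparse for the lazy update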
if args.dense_state:
mean = ones.copy()
else:
mean = ones.tostype('row_sparse')
var = mean.copy()
# warmup: run a few updates first so one-time setup costs are excluded from timing
for _ in range(10):
adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
# measure speed: time args.repeat updates, synchronizing once at the end
start = time.time()
for _ in range(args.repeat):
adam_update(weight, grad, mean, var, out=weight, lr=1, wd=0, beta1=0.9,
beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
end = time.time()
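# elapsed wall-clock seconds for args.repeat adam_update calls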
print(end - start)