| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Should be run with valgrind to get memory consumption |
| for sparse format storage and dot operators. This script can be |
| used for memory benchmarking on CPU only""" |
| import ctypes |
| import sys |
| import argparse |
| import mxnet as mx |
| from mxnet.test_utils import rand_ndarray |
| from mxnet.base import check_call, _LIB |
| |
| |
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, which is the original behavior.

    Returns
    -------
    argparse.Namespace
        Parsed options. Matrix dimensions are ints and ``density`` is a
        float. ``rhs_density`` is left as the raw string (``None`` when
        omitted) so the caller can tell whether it was supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lhs-row-dim",
                        type=int,
                        required=True,
                        help="Provide batch_size")
    parser.add_argument("--lhs-col-dim",
                        type=int,
                        required=True,
                        help="Provide feature_dim")
    parser.add_argument("--rhs-col-dim",
                        type=int,
                        required=True,
                        help="Provide output_dim")
    parser.add_argument("--density",
                        type=float,
                        required=True,
                        help="Density for lhs")
    parser.add_argument("--num-omp-threads", type=int,
                        default=1, help="number of omp threads to set in MXNet")
    # These two options used to be both required and defaulted, which made
    # the defaults unreachable; they are now optional so the defaults apply.
    parser.add_argument("--lhs-stype", default="csr",
                        choices=["csr", "default", "row_sparse"],
                        help="stype for lhs")
    parser.add_argument("--rhs-stype", default="default",
                        choices=["default", "row_sparse"],
                        help="rhs stype")
    parser.add_argument("--only-storage",
                        action="store_true",
                        help="only storage")
    # Deliberately untyped: the caller checks it for truthiness to decide
    # whether to fall back to --density.
    parser.add_argument("--rhs-density",
                        help="rhs_density")
    return parser.parse_args(argv)
| |
| |
def main():
    """Script entry point.

    Reads the command-line options, sets the number of OpenMP threads MXNet
    uses, and runs ``bench_dot`` once with ``trans_lhs`` disabled.
    """
    opts = parse_args()
    n_rows, n_inner, n_out = (int(raw) for raw in (opts.lhs_row_dim,
                                                   opts.lhs_col_dim,
                                                   opts.rhs_col_dim))
    lhs_density = float(opts.density)
    lhs_kind = opts.lhs_stype
    rhs_kind = opts.rhs_stype
    # Without an explicit --rhs-density, rhs is generated with lhs's density.
    rhs_density = float(opts.rhs_density) if opts.rhs_density else lhs_density
    # A CSR lhs goes through the sparse-specific dot; other stypes use the generic one.
    if lhs_kind == "csr":
        dot_op = mx.nd.sparse.dot
    else:
        dot_op = mx.nd.dot
    check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(opts.num_omp_threads)))
    bench_dot(n_rows, n_inner, n_out, lhs_density,
              rhs_density, dot_op, False, lhs_kind, rhs_kind,
              opts.only_storage)
| |
def bench_dot(lhs_row_dim, lhs_col_dim, rhs_col_dim, density,
              rhs_density, dot_func, trans_lhs, lhs_stype,
              rhs_stype, only_storage, distribution="uniform"):
    """Benchmark both storage and dot.

    Allocates a random lhs matrix and, unless ``only_storage`` is set, a
    random rhs matrix and their product, then waits for all MXNet work to
    finish. Intended to be run under valgrind to observe memory usage.

    Parameters
    ----------
    lhs_row_dim, lhs_col_dim : int
        Shape of the lhs matrix.
    rhs_col_dim : int
        Number of columns of the rhs matrix.
    density : float
        Density passed to ``rand_ndarray`` for lhs.
    rhs_density : float
        Density passed to ``rand_ndarray`` for rhs.
    dot_func : callable
        Called as ``dot_func(lhs, rhs, trans_lhs)``.
    trans_lhs : bool
        Forwarded as the third positional argument of ``dot_func``
        (``transpose_a`` for ``mx.nd.dot`` / ``mx.nd.sparse.dot``). When
        true, rhs is generated with ``lhs_row_dim`` rows so the product is
        well-formed.
    lhs_stype, rhs_stype : str
        Storage types passed to ``rand_ndarray``.
    only_storage : bool
        If true, only the lhs matrix is created.
    distribution : str
        Value distribution passed to ``rand_ndarray``.
    """
    lhs_nd = rand_ndarray((lhs_row_dim, lhs_col_dim), lhs_stype, density,
                          distribution=distribution)
    if not only_storage:
        # dot(lhs, rhs) contracts over lhs columns; dot(lhs.T, rhs) contracts
        # over lhs rows, so rhs's row count must follow trans_lhs.
        rhs_row_dim = lhs_row_dim if trans_lhs else lhs_col_dim
        rhs_nd = rand_ndarray((rhs_row_dim, rhs_col_dim), rhs_stype,
                              density=rhs_density, distribution=distribution)
        # Hold a reference to the result so its buffer stays alive until
        # waitall below; the memory measurement should include it.
        out = dot_func(lhs_nd, rhs_nd, trans_lhs)  # pylint: disable=unused-variable
    # Block until all queued MXNet work has completed.
    mx.nd.waitall()
| |
| |
# Run the benchmark only when executed as a script, propagating main()'s
# return value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)