blob: 5f825de7748366a41472e0777abb2840f8dc6103 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import sys
import os.path, re
from io import BytesIO, StringIO
import platform
# Headers that must never be emitted as system includes in the amalgamated
# output (platform-, GPU- or vendor-specific headers and similar).
blacklist = [
    'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh', 'cuda_runtime.h', 'cudnn.h',
    'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h', 'glog/logging.h', 'io/azure_filesys.h',
    'io/hdfs_filesys.h', 'io/s3_filesys.h', 'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
    'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h', 'NvInfer.h', 'nvml.h',
    'opencv2/opencv.hpp', 'sys/stat.h', 'sys/types.h', 'cuda.h', 'cuda_fp16.h', 'omp.h',
    'onnx/onnx.pb.h', 'execinfo.h', 'packet/sse-inl.h', 'emmintrin.h', 'thrust/device_vector.h',
    'cusolverDn.h', 'internal/concurrentqueue_internal_debug.h', 'relacy/relacy_std.hpp',
    'relacy_shims.h', 'ittnotify.h', 'shared_mutex', 'nvToolsExt.h', 'dmlc/build_config.h',
    'sys/isa_defs.h'
]

# Optional trailing arguments. sys.argv[6] only exists when len(sys.argv) > 6
# and sys.argv[7] only when len(sys.argv) > 7; the previous guards (> 5 and
# > 6) were off by one and raised IndexError when the flags were omitted.
minimum = int(sys.argv[6]) if len(sys.argv) > 6 else 0
android = int(sys.argv[7]) if len(sys.argv) > 7 else 0

# blacklist linear algebra headers when building without blas.
if minimum != 0:
    blacklist.append('linalg.h')

# Platform-specific headers are blacklisted everywhere except on the platform
# they are checked against here.
if platform.system() != 'Darwin':
    blacklist.append('TargetConditionals.h')
if platform.system() != 'Windows':
    blacklist.append('windows.h')
    blacklist.append('process.h')
    blacklist.append('Shlwapi.h')
if platform.system() == 'Windows':
    blacklist.append('unistd.h')
if 'freebsd' not in sys.platform:
    blacklist.append('sys/endian.h')
def get_sources(def_file):
    """Collect the project source files named in a dependency file.

    Parameters
    ----------
    def_file : str
        Path to a text file whose lines hold space-separated file names.
        Tokens ending in ``.o:`` and bare backslashes (line continuations)
        are ignored.

    Returns
    -------
    list of str
        Paths relative to the current working directory, in first-seen order
        and without duplicates. Only files whose resolved path starts with the
        parent directory of this script's directory are kept.
    """
    mxnet_path = os.path.realpath(
        os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir))
    sources = []
    visited = set()
    # The context manager closes the file deterministically; tokens are
    # consumed line by line instead of rebuilding one big list per line.
    with open(def_file) as dep_file:
        for line in dep_file:
            for token in line.strip().split(' '):
                token = token.strip()
                if not token or token.endswith('.o:') or token == '\\':
                    continue
                resolved = os.path.realpath(token)
                relative = os.path.relpath(resolved)
                if resolved.startswith(mxnet_path) and relative not in visited:
                    sources.append(relative)
                    visited.add(relative)
    return sources
# Project source files listed in the dependency file passed as the first
# command-line argument; find_source() searches this list.
sources = get_sources(sys.argv[1])
def find_source(name, start, stage):
    """Resolve an include name to an entry of the module-level ``sources`` list.

    Parameters
    ----------
    name : str
        File name as it appears in the ``#include`` directive.
    start : str
        Path of the file that contains the include; used to pick one
        candidate when several sources match.
    stage : str
        Current expansion stage. A ``.cc`` name matches by suffix only when
        the candidate path starts with ``"../" + stage``.

    Returns
    -------
    str
        The chosen source path, or ``''`` when nothing matches or the
        ambiguity between several candidates cannot be resolved.
    """
    def _is_candidate(path):
        # Exact match always counts.
        if path == name:
            return True
        if not path.endswith('/' + name):
            return False
        # Suffix match: .cc files must additionally live under the stage.
        if name.endswith(".cc"):
            return path.startswith("../" + stage)
        return True

    candidates = [path for path in sources if _is_candidate(path)]
    if len(candidates) < 2:
        return candidates[0] if candidates else ''

    for path in candidates:
        # For 3rdparty paths compare the directory name after 3rdparty
        # (component 2); otherwise compare component 1.
        depth = 2 if '3rdparty' in path else 1
        if path.split('/')[depth] == start.split('/')[depth]:
            return path
    return ''
# Include target written in angle brackets, e.g. <dmlc/base.h>.
re1 = re.compile(r'<([./a-zA-Z0-9_-]*)>')
# Include target written in double quotes, e.g. "./operator-inl.h".
re2 = re.compile(r'"([./a-zA-Z0-9_-]*)"')
# Macro-style include target; expand() maps it to 'execinfo.h'.
re3 = re.compile(r'DMLC_EXECINFO_H')

# Headers that could not be resolved to a project source, in discovery order.
sysheaders = list()
# Files that have been fully expanded.
history = set()
# In-memory buffer that accumulates the amalgamated source text.
out = BytesIO()
def expand(x, pending, stage):
    """
    Recursively inline file ``x`` into the module-level ``out`` buffer.

    Lines without ``#include`` are copied verbatim. Each include is resolved
    with ``find_source``; resolved files are expanded in place, unresolved
    ones are recorded in ``sysheaders`` unless filtered out. The function
    attributes ``expand.treeDepth`` and ``expand.fileCount`` must be
    initialised before the first call.

    Parameters
    ----------
    x: str
        The file to expand.
    pending : list of str
        The files currently being expanded further up the call chain; used
        to break include cycles.
    stage: str
        The current stage for file expansion, used for matching the prefix of files.
    """
    # MULTIPLE includes: expr_scalar-inl.h is expanded every time it is seen.
    seen_before = x in history and x not in ['mshadow/mshadow/expr_scalar-inl.h']
    if seen_before or x in pending:
        # Either already emitted, or an include loop.
        return

    expand.fileCount += 1
    indent = ' ' * expand.treeDepth
    banner = u"//=====[{:3d}] STAGE:{:>4} {}EXPANDING: {} =====\n\n".format(
        expand.fileCount, stage, indent, x)
    out.write(banner.encode('ascii'))
    print(banner)

    with open(x, 'rb') as handle:
        for raw in handle.readlines():
            text = raw.decode('utf-8')
            if '#define DMLC_LOG_STACK_TRACE 1' in text.strip():
                # Do not enable stacktrace logging
                continue
            if '#include' not in text:
                out.write(raw)
                continue
            if not text.strip().startswith('#include'):
                # '#include' appears mid-line: report the line and drop it.
                print(text)
                continue

            match = re1.search(text) or re2.search(text)
            if match:
                path = match.group(1)
            elif re3.search(text):
                path = 'execinfo.h'
            else:
                print(text + ' not found')
                continue

            h = path if "../3rdparty/" in path else path.strip('./')
            if h.endswith('complex.h') and x.endswith('openblas_config.h'):
                source = ''
            elif h.startswith('ps/'):
                source = '../3rdparty/ps-lite/include/' + h
            else:
                source = find_source(h, x, stage)

            if source:
                expand.treeDepth += 1
                expand(source, pending + [x], stage)
                expand.treeDepth -= 1
                continue

            # Unresolved include: remember it as a system header unless it is
            # blacklisted, already recorded, or belongs to an excluded family.
            excluded_family = any(tag in h for tag in ('mkl', 'nnpack', 'tensorrt'))
            if (h not in blacklist and h not in sysheaders
                    and not excluded_family and not h.endswith('.cuh')):
                sysheaders.append(h)

    out.write(u"//===== EXPANDED : {} =====\n\n".format(x).encode('ascii'))
    history.add(x)
# Counters stored on the expand function itself: how many files have been
# expanded so far and the current include nesting depth. Both are only used
# to format the informative banner comments.
expand.fileCount = 0
expand.treeDepth = 0

# Expand the three entry files, each with its own stage prefix.
expand(x=sys.argv[2], pending=[], stage="3rdparty/dmlc-core")
expand(x=sys.argv[3], pending=[], stage="3rdparty/tvm/nnvm")
expand(x=sys.argv[4], pending=[], stage="src")
# Write to amalgamation file: optional stand-alone defines, a fixed
# platform-specific include preamble, the collected system headers in sorted
# order, and finally the expanded source text.
with open(sys.argv[5], 'wb') as f:
    if minimum != 0:
        # list.remove raises ValueError for a missing item; cblas.h is only
        # present when some expanded file actually included it.
        if 'cblas.h' in sysheaders:
            sysheaders.remove('cblas.h')
        f.write(b"#define MSHADOW_STAND_ALONE 1\n")
        f.write(b"#define MSHADOW_USE_SSE 0\n")
        f.write(b"#define MSHADOW_USE_CBLAS 0\n")

    f.write(
        b"""
#if defined(__MACH__)
#include <mach/clock.h>
#include <mach/mach.h>
#endif
#if !defined(__WIN32__)
#include <sys/stat.h>
#include <sys/types.h>
#if !defined(__ANDROID__) && (!defined(MSHADOW_USE_SSE) || MSHADOW_USE_SSE == 1)
#include <emmintrin.h>
#endif
#endif
\n"""
    )

    # Minimal Android builds additionally include complex.h.
    if minimum != 0 and android != 0 and 'complex.h' not in sysheaders:
        sysheaders.append('complex.h')

    for k in sorted(sysheaders):
        f.write("#include <{}>\n".format(k).encode('ascii'))

    f.write(b'\n')
    f.write(out.getvalue())
    f.write(b'\n')
# Report every listed source that was never expanded (object files excluded).
for src in sources:
    if src in history or src.endswith('.o'):
        continue
    print('Not processed:', src)