blob: a8a3d26fac3d9a7ccb900dfc15f6d5c0c5156c31 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file mkldnn.cc
* \brief Unit tests for MKLDNN utility functions.
* \author Da Zheng
*/
#if MXNET_USE_MKLDNN == 1
#include <cstdint>
#include <memory>
#include <random>
#include "gtest/gtest.h"
#include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h"
/*!
 * \brief Checks that mxnet::AlignMem agrees with std::align on one input.
 *
 * Both implementations receive the same pointer, size, alignment and an
 * independent copy of the available space. The returned pointers and the
 * updated space counters are compared with EXPECT_EQ.
 *
 * \param mem       candidate buffer start (never dereferenced here)
 * \param size      number of bytes that must fit after alignment
 * \param alignment requested alignment in bytes
 * \param space     bytes available starting at mem
 * \return true if both implementations returned the same pointer
 */
bool test_mem_align(void *mem, size_t size, size_t alignment, size_t space) {
  // Each implementation gets its own copy of the space counter so that their
  // post-call values can be compared against each other.
  size_t mxnet_space = space;
  void *const mxnet_result =
      mxnet::AlignMem(mem, size, alignment, &mxnet_space);

  // std::align updates its pointer and space arguments in place. Here it only
  // touches this function's local copies, after AlignMem has already run.
  size_t std_space = space;
  void *const std_result = std::align(alignment, size, mem, std_space);

  EXPECT_EQ(mxnet_result, std_result);
  EXPECT_EQ(mxnet_space, std_space);
  return mxnet_result == std_result;
}
/*!
 * \brief Cross-checks mxnet::AlignMem against std::align with a 4096-byte
 *        alignment, on hand-picked cases plus pseudo-random inputs.
 */
TEST(MKLDNN_UTIL_FUNC, AlignMem) {
  const size_t alignment = 4096;

  // When mem has already been aligned.
  test_mem_align(reinterpret_cast<void *>(0x10000), 1000, alignment, 10000);
  // When mem isn't aligned and we have enough space for alignment.
  test_mem_align(reinterpret_cast<void *>(0x10010), 1000, alignment, 10000);
  // When mem isn't aligned and we don't have enough space for alignment.
  test_mem_align(reinterpret_cast<void *>(0x10010), 1000, alignment, 1001);
  // Exact boundary: 0x10010 needs 0xFF0 (4080) bytes of padding to reach
  // 0x11000. So 4080 + 1000 = 5080 bytes fit exactly, and 5079 do not.
  test_mem_align(reinterpret_cast<void *>(0x10010), 1000, alignment, 5080);
  test_mem_align(reinterpret_cast<void *>(0x10010), 1000, alignment, 5079);

  // Pseudo-random inputs over the same ranges the old random()-based loop
  // used: addresses in [0, 2^31 - 1], size and space in [0, 1999].
  // A fixed-seed std::mt19937 is portable (POSIX random() is not available
  // everywhere) and gives the same sequence on every platform, so any
  // failure is reproducible.
  std::mt19937 gen(0);
  std::uniform_int_distribution<uintptr_t> addr_dist(0, 0x7fffffff);
  std::uniform_int_distribution<size_t> len_dist(0, 1999);
  for (size_t i = 0; i < 10000; i++) {
    void *mem = reinterpret_cast<void *>(addr_dist(gen));
    const size_t size = len_dist(gen);
    const size_t space = len_dist(gen);
    // Attach the inputs to any EXPECT failure raised inside test_mem_align.
    SCOPED_TRACE(testing::Message() << "iteration " << i << ": mem=" << mem
                                    << " size=" << size
                                    << " space=" << space);
    test_mem_align(mem, size, alignment, space);
  }
}
#endif