| /* | |
| * Licensed to the Apache Software Foundation (ASF) under one | |
| * or more contributor license agreements. See the NOTICE file | |
| * distributed with this work for additional information | |
| * regarding copyright ownership. The ASF licenses this file | |
| * to you under the Apache License, Version 2.0 (the | |
| * "License"); you may not use this file except in compliance | |
| * with the License. You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, | |
| * software distributed under the License is distributed on an | |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
| * KIND, either express or implied. See the License for the | |
| * specific language governing permissions and limitations | |
| * under the License. | |
| * | |
| */ | |
| #include <vector> | |
| #include "mxnet/c_api.h" | |
| #include "dmlc/logging.h" | |
| #include "mxnet-cpp/MxNetCpp.h" | |
| using namespace mxnet::cpp; | |
| enum TypeFlag { | |
| kFloat32 = 0, | |
| kFloat64 = 1, | |
| kFloat16 = 2, | |
| kUint8 = 3, | |
| kInt32 = 4, | |
| kInt8 = 5, | |
| kInt64 = 6, | |
| }; | |
| /* | |
| * The file is used for testing if there exist type inconsistency | |
| * when using Copy API to create a new NDArray. | |
| * By running: build/test_ndarray. | |
| */ | |
| int main(int argc, char** argv) { | |
| std::vector<mx_uint> shape1{128, 2, 32}; | |
| Shape shape2(32, 8, 64); | |
| int gpu_count = 0; | |
| if (MXGetGPUCount(&gpu_count) != 0) { | |
| LOG(ERROR) << "MXGetGPUCount failed"; | |
| return -1; | |
| } | |
| Context context = (gpu_count > 0) ? Context::gpu() : Context::cpu(); | |
| NDArray src1(shape1, context, true, kFloat16); | |
| NDArray src2(shape2, context, false, kInt8); | |
| NDArray dst1, dst2; | |
| dst1 = src1.Copy(context); | |
| dst2 = src2.Copy(context); | |
| NDArray::WaitAll(); | |
| CHECK_EQ(src1.GetDType(), dst1.GetDType()); | |
| CHECK_EQ(src2.GetDType(), dst2.GetDType()); | |
| return 0; | |
| } |