/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file torch_base.h
* \brief Torch interface.
* \author Junyuan Xie
*/
#ifndef PLUGIN_TORCH_TORCH_BASE_H_
#define PLUGIN_TORCH_TORCH_BASE_H_
#include <mxnet/base.h>
extern "C" {
#include <lua.h>
#include <luaT.h>
#include <lualib.h>
#include <TH/THStorage.h>
#include <TH/THTensor.h>
}
#if MXNET_USE_CUDA
extern "C" {
#include <THC/THCStorage.h>
#include <THC/THCTensor.h>
#include <THC/THCTensorCopy.h>
}
#endif // MXNET_USE_CUDA
#include <vector>
namespace mxnet {
class TorchState {
public:
lua_State* L;
TorchState();
static TorchState* ThreadSharedLuaState();
#if MXNET_USE_CUDA
THCState* CudaState() {
lua_getglobal(L, "cutorch");
CHECK(!lua_isnil(L, -1));
lua_getfield(L, -1, "_state");
CHECK(!lua_isnil(L, -1));
THCState* state = reinterpret_cast<THCState*>(lua_touserdata(L, -1));
lua_pop(L, 2);
return state;
}
#endif // MXNET_USE_CUDA
template <typename xpu>
void SetStream(mshadow::Stream<xpu>* s);
void PrintState() {
int i;
int top = lua_gettop(L);
LOG(INFO) << "Stack height: " << top;
for (i = 1; i <= top; i++) { /* repeat for each level */
int t = lua_type(L, i);
switch (t) {
case LUA_TSTRING: /* strings */
LOG(INFO) << i << ": '" << lua_tostring(L, i) << "'";
break;
case LUA_TBOOLEAN: /* booleans */
LOG(INFO) << i << ": " << (lua_toboolean(L, i) ? "true" : "false");
break;
case LUA_TNUMBER: /* numbers */
LOG(INFO) << i << ": " << lua_tonumber(L, i);
break;
default: /* other values */
LOG(INFO) << i << ": " << lua_typename(L, t);
break;
}
}
}
int Deserialize(THCharStorage* chunk) { // read only to the chunk
CHECK_NE(chunk, NULL);
lua_getglobal(L, "Deserialize");
luaT_pushudata(L, chunk, "torch.CharStorage");
THCharStorage_retain(chunk); // keep it because read only
int err = lua_pcall(L, 1, 1, 0);
CHECK_EQ(err, 0);
return 1;
}
int Serialize(THCharStorage** chunk) {
lua_getglobal(L, "Serialize");
lua_pushvalue(L, -2);
int err = lua_pcall(L, 1, 1, 0);
CHECK_EQ(err, 0) << "Serialize failed " << lua_tostring(L, -1);
THCharStorage_free(*chunk); // free the original
*chunk = reinterpret_cast<THCharStorage*>(luaT_toudata(L, -1, "torch.CharStorage"));
THCharStorage_retain(*chunk); // keep the chunk even when lua side deletes
lua_pop(L, 2);
return 0;
}
};
// Type-erased handles covering both CPU and GPU TH objects
// (THFloatTensor*/THCudaTensor*, THFloatStorage*/THCudaStorage*), so a single
// alias can be stored regardless of device. `using` is the modern spelling.
using THGeneralTensor = void*;
using THGeneralStorage = void*;
class TorchTensor {
 public:
  /*!
   * \brief Torch tensor type name for a device.
   * \param dev_mask cpu::kDevMask or gpu::kDevMask.
   * \return "torch.FloatTensor" on CPU, "torch.CudaTensor" on GPU.
   */
  static const char* TensorType(int dev_mask) {
    switch (dev_mask) {
      case cpu::kDevMask:
        return "torch.FloatTensor";
      case gpu::kDevMask:
        return "torch.CudaTensor";
      default:
        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
        return NULL;
    }
  }

  /*!
   * \brief Lua conversion suffix that moves a Torch module to a device.
   * \param dev_mask cpu::kDevMask or gpu::kDevMask.
   * \return ":float()" on CPU, ":cuda()" on GPU.
   */
  static const char* ModuleType(int dev_mask) {
    switch (dev_mask) {
      case cpu::kDevMask:
        return ":float()";
      case gpu::kDevMask:
        return ":cuda()";
      default:
        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
        return NULL;
    }
  }

  /*! \brief Tensor type name for the device a TBlob lives on. */
  static const char* TensorType(TBlob data) {
    return TensorType(data.dev_mask());
  }

  /*! \brief Module conversion suffix for the device a TBlob lives on. */
  static const char* ModuleType(TBlob data) {
    // Bug fix: this overload previously forwarded to TensorType(), returning
    // "torch.FloatTensor"/"torch.CudaTensor" instead of ":float()"/":cuda()".
    return ModuleType(data.dev_mask());
  }

  /*!
   * \brief Wrap a TBlob's memory in a new TH tensor without copying.
   *  The TH storage aliases the blob's buffer: TH_STORAGE_FREEMEM is cleared
   *  so Torch never frees memory it does not own.
   * \param torchState Lua/Torch state (needed for the THCState on GPU).
   * \param data blob to wrap; must be float32.
   * \return a THFloatTensor* or THCudaTensor* as a type-erased handle.
   */
  static THGeneralTensor TBlobToTHTensor(TorchState* torchState, TBlob data) {
    size_t size = data.Size();
    THGeneralTensor tensor = NULL;
    THLongStorage* thshape = THLongStorage_newWithSize(data.ndim());
    for (int i = 0; i < data.ndim(); ++i) {
      THLongStorage_set(thshape, i, data.shape_[i]);
    }
    CHECK_EQ(data.type_flag_, mshadow::kFloat32) << "Torch Interface only support float32";
    switch (data.dev_mask()) {
      case cpu::kDevMask: {
        THFloatStorage* storage =
            THFloatStorage_newWithData(static_cast<real_t*>(data.dptr_), size);
        // The blob owns the buffer; forbid Torch from freeing it.
        THFloatStorage_clearFlag(storage, TH_STORAGE_FREEMEM);
        tensor = (THGeneralTensor)THFloatTensor_newWithStorage(storage, 0, thshape, NULL);
        THFloatStorage_free(storage);  // drop our ref; the tensor retains it
        break;
      }
#if MXNET_USE_CUDA
      case gpu::kDevMask: {
        THCState* state = torchState->CudaState();
        THCudaStorage* storage =
            THCudaStorage_newWithData(state, static_cast<real_t*>(data.dptr_), size);
        // a bug in cutorch: no THCudaStorage_clearFlag, so reuse the CPU one
        THFloatStorage_clearFlag(reinterpret_cast<THFloatStorage*>(storage), TH_STORAGE_FREEMEM);
        tensor = (THGeneralTensor)THCudaTensor_newWithStorage(state, storage, 0, thshape, NULL);
        THCudaStorage_free(state, storage);  // drop our ref; the tensor retains it
        break;
      }
#endif
      default:
        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
    }
    THLongStorage_free(thshape);
    return tensor;
  }

  /*!
   * \brief Release the storage currently referenced by a TH tensor created by
   *  TBlobToTHTensor/SetInternal (balances the tensor's storage refcount).
   */
  static void FreeInternal(TorchState* torchState, THGeneralTensor tensor, int dev_mask) {
    switch (dev_mask) {
      case cpu::kDevMask: {
        THFloatStorage* original = static_cast<THFloatTensor*>(tensor)->storage;
        THFloatStorage_free(original);
        break;
      }
#if MXNET_USE_CUDA
      case gpu::kDevMask: {
        THCState* state = torchState->CudaState();
        THCudaStorage* original = static_cast<THCudaTensor*>(tensor)->storage;
        THCudaStorage_free(state, original);
        break;
      }
#endif
      default:
        LOG(FATAL) << "Unknown device type " << dev_mask;
    }
  }

  /*!
   * \brief Repoint an existing TH tensor at a TBlob's memory (no copy).
   *  Builds a non-owning storage over the blob's buffer, swaps it into the
   *  tensor, and releases the tensor's previous storage.
   */
  static void SetInternal(TorchState* torchState, THGeneralTensor tensor, const TBlob& blob) {
    size_t size = blob.Size();
    switch (blob.dev_mask()) {
      case cpu::kDevMask: {
        THFloatStorage* storage =
            THFloatStorage_newWithData(static_cast<real_t*>(blob.dptr_), size);
        THFloatStorage_clearFlag(storage, TH_STORAGE_FREEMEM);  // blob owns the buffer
        THFloatStorage* original = static_cast<THFloatTensor*>(tensor)->storage;
        static_cast<THFloatTensor*>(tensor)->storage = storage;
        THFloatStorage_free(original);  // release the replaced storage
        break;
      }
#if MXNET_USE_CUDA
      case gpu::kDevMask: {
        THCState* state = torchState->CudaState();
        THCudaStorage* storage =
            THCudaStorage_newWithData(state, static_cast<real_t*>(blob.dptr_), size);
        // TODO(min): torch bug Cuda version not implemented
        THFloatStorage_clearFlag(reinterpret_cast<THFloatStorage*>(storage), TH_STORAGE_FREEMEM);
        THCudaStorage* original = static_cast<THCudaTensor*>(tensor)->storage;
        static_cast<THCudaTensor*>(tensor)->storage = storage;
        THCudaStorage_free(state, original);  // release the replaced storage
        break;
      }
#endif
      default:
        LOG(FATAL) << "Unknown device type " << blob.dev_mask();
    }
  }

  /*!
   * \brief Push [begin, end) onto the Lua stack the way Torch modules expect:
   *  nil when empty, a single tensor when there is one blob, or a Lua array
   *  table of tensors otherwise.
   * \return the TH tensors created (one per blob), for later cleanup/copy-back.
   */
  static std::vector<THGeneralTensor> TBlobVectorAsTable(
      TorchState* torchState,
      const std::vector<TBlob>::const_iterator begin,
      const std::vector<TBlob>::const_iterator end) {
    lua_State* L = torchState->L;
    std::vector<THGeneralTensor> res;
    int num = end - begin;
    if (num > 1) {
      lua_createtable(L, num, 0);
      int index = 1;  // Lua arrays are 1-based
      for (std::vector<TBlob>::const_iterator it = begin; it != end; ++it) {
        THGeneralTensor th = TorchTensor::TBlobToTHTensor(torchState, *it);
        res.push_back(th);
        luaT_pushudata(L, th, TorchTensor::TensorType(*it));
        lua_rawseti(L, -2, index++);
      }
    } else if (num == 0) {
      lua_pushnil(L);
    } else {
      THGeneralTensor th = TorchTensor::TBlobToTHTensor(torchState, *begin);
      res.push_back(th);
      luaT_pushudata(L, th, TorchTensor::TensorType(*begin));
    }
    return res;
  }

  /*!
   * \brief Copy the tensor on top of the Lua stack into th_dst unless it
   *  already aliases th_dst's storage (i.e. Torch wrote in place).
   * \param dst blob backing th_dst; used only to verify the device matches.
   */
  static void CopyIfDifferent(TorchState* torchState, TBlob dst, THGeneralTensor th_dst) {
    lua_State* L = torchState->L;
    if (luaT_isudata(L, -1, TorchTensor::TensorType(cpu::kDevMask))) {
      CHECK_EQ(dst.dev_mask(), cpu::kDevMask) << "Device type mismatch.";
      THFloatTensor* src =
          static_cast<THFloatTensor*>(luaT_toudata(L, -1, TorchTensor::TensorType(cpu::kDevMask)));
      if (src->storage != static_cast<THFloatTensor*>(th_dst)->storage) {
        THFloatTensor_copy(static_cast<THFloatTensor*>(th_dst), src);
      }
#if MXNET_USE_CUDA
    } else if (luaT_isudata(L, -1, TorchTensor::TensorType(gpu::kDevMask))) {
      CHECK_EQ(dst.dev_mask(), gpu::kDevMask) << "Device type mismatch.";
      THCudaTensor* src =
          static_cast<THCudaTensor*>(luaT_toudata(L, -1, TorchTensor::TensorType(gpu::kDevMask)));
      if (src->storage != static_cast<THCudaTensor*>(th_dst)->storage) {
        THCudaTensor_copy(torchState->CudaState(), static_cast<THCudaTensor*>(th_dst), src);
      }
#endif  // MXNET_USE_CUDA
    } else {
      LOG(FATAL) << "Unsupported Torch tensor type " << luaT_typename(L, -1);
    }
  }

  /*!
   * \brief Copy Torch's outputs (on top of the Lua stack, in the layout
   *  produced by TBlobVectorAsTable) back into the output blobs where needed,
   *  popping the table when there are multiple outputs.
   */
  static void CheckOutput(TorchState* torchState,
                          std::vector<TBlob>::const_iterator begin,
                          std::vector<TBlob>::const_iterator end,
                          std::vector<THGeneralTensor>::const_iterator th_begin,
                          std::vector<THGeneralTensor>::const_iterator th_end) {
    lua_State* L = torchState->L;
    int num = end - begin;
    CHECK_EQ(th_end - th_begin, num);
    if (num == 0) {
      // nothing to copy back
    } else if (num == 1) {
      CopyIfDifferent(torchState, *begin, *th_begin);
    } else {
      CHECK(lua_istable(L, -1));
      lua_pushnil(L);  // first key for lua_next traversal
      for (; begin != end; ++begin, ++th_begin) {
        CHECK(lua_next(L, -2));
        CopyIfDifferent(torchState, *begin, *th_begin);
        lua_pop(L, 1);  // pop the value, keep the key for the next lua_next
      }
      lua_pop(L, 1);  // pop the table
    }
  }
};
} // namespace mxnet
#endif // PLUGIN_TORCH_TORCH_BASE_H_