blob: 2cfeacfcd71f72cc5818611456b8de1f8a8fc1ce [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rpc_local_session.cc
* \brief Local session that directs requests to local API.
*/
#include "rpc_local_session.h"
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/tensor.h>
#include <memory>
#include <vector>
namespace tvm {
namespace runtime {
/*!
 * \brief Look up a globally registered function by name.
 * \param name The name handed to ffi::Function::GetGlobal.
 * \return The raw object handle of the function, or nullptr when the lookup
 *         yields nothing.
 */
RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) {
  auto maybe_func = tvm::ffi::Function::GetGlobal(name);
  if (!maybe_func) {
    return nullptr;
  }
  // Hand back the raw handle instead of a managed reference: the value is
  // moved out of the Any, so the remote side has to manage it explicitly.
  Any holder = *maybe_func;
  TVMFFIAny raw = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(holder));
  return raw.v_obj;
}
/*!
 * \brief Pack a return value following the RPC return convention and pass it
 *        to the encoder callback.
 *
 * This is where the RPC-specific ABI for return values, layered on top of the
 * Any FFI, is implemented. Slot 0 always carries the type index of \p rv; the
 * remaining slots depend on what \p rv holds (see the per-case comments).
 *
 * \param rv The value to encode. It is moved from in the Tensor and
 *           ObjectRef cases.
 * \param encode_return Callback invoked exactly once with the packed slots.
 */
void LocalSession::EncodeReturn(ffi::Any rv, const FEncodeReturn& encode_return) {
  AnyView packed_args[3];
  // Shared tail of every case: forward the first `count` slots to the encoder.
  auto emit = [&](int count) { encode_return(ffi::PackedArgs(packed_args, count)); };

  // The first slot is always the type index.
  packed_args[0] = rv.type_index();

  // Null value: pass it through unchanged.
  if (rv == nullptr) {
    packed_args[1] = rv;
    emit(2);
    return;
  }

  // Tensor: special protocol with three slots. The value is moved into a raw
  // TVMFFIAny; slot 1 is the DLTensor pointer obtained from the handle and
  // slot 2 is the opaque object handle itself.
  if (rv.as<Tensor>()) {
    TVMFFIAny raw = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv));
    void* tensor_handle = raw.v_obj;
    packed_args[1] = TVMFFITensorGetDLTensorPtr(tensor_handle);
    packed_args[2] = tensor_handle;
    emit(3);
    return;
  }

  // Bytes: always passed as a pointer to a byte array describing data + size.
  // `view` must stay alive until emit() returns, hence the call inside this scope.
  if (const auto opt_bytes = rv.as<ffi::Bytes>()) {
    TVMFFIByteArray view;
    view.data = opt_bytes->data();
    view.size = opt_bytes->size();
    packed_args[1] = &view;
    emit(2);
    return;
  }

  // String: encoded through its character data pointer.
  if (auto opt_str = rv.as<ffi::String>()) {
    packed_args[1] = opt_str->data();
    emit(2);
    return;
  }

  // Any other object reference: moved into a raw TVMFFIAny and passed as an
  // opaque handle.
  if (rv.as<ffi::ObjectRef>()) {
    TVMFFIAny raw = ffi::details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(rv));
    void* object_handle = raw.v_obj;
    packed_args[1] = object_handle;
    emit(2);
    return;
  }

  // Everything else is forwarded as-is.
  packed_args[1] = rv;
  emit(2);
}
/*!
 * \brief Invoke a function handle with the given arguments and encode its result.
 * \param func Handle that is cast to ffi::FunctionObj* and called.
 * \param args Arguments for the call; any RPCObjectRef among them is replaced
 *             by the object handle it wraps.
 * \param encode_return Callback that receives the encoded return value.
 */
void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, ffi::PackedArgs args,
                            const FEncodeReturn& encode_return) {
  const auto num_args = args.size();
  std::vector<AnyView> unwrapped(num_args);
  // An RPCObjectRef may be handed directly to a LocalSession call; in that
  // case pass along the underlying object instead of the wrapper.
  for (int idx = 0; idx < num_args; ++idx) {
    auto maybe_rpc_obj = args[idx].as<RPCObjectRef>();
    if (maybe_rpc_obj) {
      unwrapped[idx] = static_cast<const Object*>(maybe_rpc_obj.value()->object_handle());
    } else {
      unwrapped[idx] = args[idx];
    }
  }

  ffi::FunctionObj* callee = static_cast<ffi::FunctionObj*>(func);
  Any result;
  callee->CallPacked(unwrapped.data(), unwrapped.size(), &result);
  this->EncodeReturn(std::move(result), encode_return);
}
void LocalSession::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) {
ICHECK_EQ(nbytes, GetDataSize(*to));
DLTensor from;
from.data = from_bytes;
from.device = {kDLCPU, 0};
from.ndim = to->ndim;
from.shape = to->shape;
from.dtype = to->dtype;
from.strides = nullptr;
from.byte_offset = 0;
Device dev_to = to->device;
this->GetDeviceAPI(dev_to)->CopyDataFromTo(&from, to, nullptr);
// Copy can happen asynchrously
// synchronize to make sure that copy is completed
this->GetDeviceAPI(dev_to)->StreamSync(dev_to, nullptr);
}
/*!
 * \brief Copy the contents of a tensor into a raw host buffer.
 * \param from Source tensor; its ndim, shape and dtype describe the destination too.
 * \param to_bytes Destination buffer, wrapped as a CPU tensor for the copy.
 * \param nbytes Expected byte count; checked against ffi::GetDataSize(*from).
 */
void LocalSession::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes) {
  ICHECK_EQ(nbytes, ffi::GetDataSize(*from));

  // Describe the destination buffer as a CPU (device 0) tensor that borrows
  // the source's shape metadata, with no explicit strides and no offset.
  DLTensor host_dst;
  host_dst.data = to_bytes;
  host_dst.byte_offset = 0;
  host_dst.strides = nullptr;
  host_dst.device = {kDLCPU, 0};
  host_dst.dtype = from->dtype;
  host_dst.ndim = from->ndim;
  host_dst.shape = from->shape;

  Device src_dev = from->device;
  this->GetDeviceAPI(src_dev)->CopyDataFromTo(from, &host_dst, nullptr);
  // The copy can happen asynchronously;
  // synchronize to make sure that it has completed.
  this->GetDeviceAPI(src_dev)->StreamSync(src_dev, nullptr);
}
void LocalSession::FreeHandle(void* handle) {
// NOTE: the type code is no longer need during free handle.
ffi::details::ObjectUnsafe::DecRefObjectHandle(handle);
}
/*!
 * \brief Fetch the device API for a device by delegating to DeviceAPI::Get.
 * \param dev The device to look up.
 * \param allow_missing Forwarded unchanged to DeviceAPI::Get.
 * \return Whatever DeviceAPI::Get returns for these arguments.
 */
DeviceAPI* LocalSession::GetDeviceAPI(Device dev, bool allow_missing) {
  DeviceAPI* api = DeviceAPI::Get(dev, allow_missing);
  return api;
}
// Registers the global function "rpc.LocalSession", which builds an RPC
// session module backed by a freshly created LocalSession.
TVM_FFI_STATIC_INIT_BLOCK() {
  auto make_local_session = []() {
    return CreateRPCSessionModule(std::make_shared<LocalSession>());
  };
  tvm::ffi::reflection::GlobalDef().def("rpc.LocalSession", std::move(make_local_session));
}
} // namespace runtime
} // namespace tvm