blob: d07f84be16477b3593c5fa019e3da6013a91c6dc [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/vm/lm_support.cc
* \brief Runtime to support language model related task
*
* Including inplace attention kv cache for runtime and simple sampler.
*
* This file provides a simple implementation of inplace attention
* kv cache for relax runtime. The main goal here is to help us enable
* auto-regressive decoding quickly in relax.
*
 * This is not the only way to support attention kv-cache.
 * Our support of attention kv-cache may be subject to future
 * changes as we build more LM verticals.
*
 * We will keep the impact minimal by putting it as a private
 * runtime builtin provided in this file.
*
* We can evolve this implementation as we build more LM verticals.
*/
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ffi/error.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/tensor.h>
#include <tvm/runtime/vm/vm.h>
#include <cmath>
namespace tvm {
namespace runtime {
namespace vm {
//-------------------------------------------
// We keep the implementation private as
// it may be subject to future changes.
//
// Users can interact with it through the
// runtime API function calls
//-------------------------------------------
/*!
 * \brief An object representing an attention kv cache.
 *
 * The backing storage is a single tensor (`data`) whose first dimension
 * counts cache slots; the first `fill_count` slots are the valid contents.
 * The buffer is grown geometrically on demand, so `data->shape[0]` is the
 * reserved capacity, not the logical size.
 */
class AttentionKVCacheLegacyObj : public ffi::Object {
 public:
  /*!
   * \brief Underlying support data.
   */
  Tensor data;
  /*!
   * \brief number of slots already filled.
   */
  int64_t fill_count{0};
  /*!
   * \brief current cache position (windowed kv cache only).
   */
  int64_t window_attention_current_pos{0};
  /*!
   * \brief View all current cached values as one array.
   * \param shape The requested shape; dim 0 must equal fill_count and the
   *        trailing dims must match the backing storage exactly.
   * \return A zero-copy view over the filled prefix of `data`.
   */
  Tensor View(const ffi::Shape& shape) {
    TVM_FFI_ICHECK_EQ(shape[0], fill_count) << "Requested shape do not match the filled count";
    for (int i = 1; i < this->data->ndim; ++i) {
      TVM_FFI_ICHECK_EQ(shape[i], data->shape[i]) << "Dimension " << i << " mismatch";
    }
    return data.CreateView(shape, data->dtype);
  }
  /** Clear the cache (logical reset only; the backing buffer is kept). */
  void Clear() {
    this->fill_count = 0;
    this->window_attention_current_pos = 0;
  }
  /** pop n entries (logical removal from the tail; data is not erased) */
  void PopN(size_t n) {
    TVM_FFI_ICHECK_LE(n, fill_count);
    this->fill_count -= n;
  }
  /*!
   * \brief Overwrite the valid prefix of the cache with `value`.
   *
   * `value` must be contiguous, dtype-compatible, and have dim 0 equal to
   * the current fill count; reserved slots past the prefix are untouched.
   */
  void Update(Tensor value) {
    TVM_FFI_ICHECK(data.DataType() == value.DataType()) << "dtype mismatch";
    TVM_FFI_ICHECK_EQ(value->shape[0], fill_count)
        << "Requested shape do not match the filled count";
    TVM_FFI_ICHECK(data.IsContiguous());
    TVM_FFI_ICHECK(value.IsContiguous());
    // Build a raw DLTensor view at the start of the backing buffer that
    // borrows `value`'s shape, then bulk-copy into it.
    DLTensor copy_dst = *(data.operator->());
    copy_dst.byte_offset = 0;
    copy_dst.shape = value->shape;
    Tensor::CopyFromTo(value.operator->(), &copy_dst);
    this->fill_count = value->shape[0];
  }
  /*!
   * \brief Append value to the cache, overrides if full.
   * \param value The value to override previous elements.
   * \param max_cache_size max size of the cache.
   * \param num_attention_sinks number of sinks to store (https://arxiv.org/abs/2309.17453).
   */
  void WindowOverride(Tensor value, int64_t max_cache_size, int64_t num_attention_sinks = 0) {
    TVM_FFI_ICHECK(data.DataType() == value.DataType()) << "dtype mismatch";
    TVM_FFI_ICHECK_LE(value->shape[0], max_cache_size - num_attention_sinks)
        << "dim 0 of value too large";
    // reallocate cache: grow the backing buffer geometrically, but only
    // while the cache has not yet reached the sliding-window capacity.
    if (fill_count + value->shape[0] <= max_cache_size) {
      int64_t reserved_slots = data->shape[0];
      while (fill_count + value->shape[0] > reserved_slots) {
        reserved_slots *= 2;
      }
      if (reserved_slots != data->shape[0]) {
        std::vector<int64_t> new_shape(data->shape, data->shape + data->ndim);
        new_shape[0] = reserved_slots;
        Tensor new_data = Tensor::Empty(new_shape, data->dtype, data->device);
        // Preserve existing contents at the front of the enlarged buffer.
        new_data.CreateView(data.Shape(), data->dtype).CopyFrom(data);
        this->data = new_data;
      }
    }
    // copy into the current position.
    TVM_FFI_ICHECK(data.IsContiguous());
    // Leading rows of `value` that fit before the window position wraps.
    int64_t num_elements_to_copy =
        std::min(value->shape[0], max_cache_size - window_attention_current_pos);
    // Scalar elements per cache slot (product of trailing dims).
    int64_t num_elements_p_entry = 1;
    std::vector<int64_t> shape;
    shape.push_back(num_elements_to_copy);
    for (int i = 1; i < data->ndim; ++i) {
      TVM_FFI_ICHECK_EQ(value->shape[i], data->shape[i]) << "Dimension " << i << " mismatch";
      num_elements_p_entry *= data->shape[i];
      shape.push_back(data->shape[i]);
    }
    int64_t num_filled_elements = window_attention_current_pos * num_elements_p_entry;
    // Both counters saturate at the window capacity.
    this->fill_count = std::min(this->fill_count + value->shape[0], max_cache_size);
    this->window_attention_current_pos =
        std::min(this->window_attention_current_pos + value->shape[0], max_cache_size);
    if (num_elements_to_copy > 0) {
      // First chunk: copy in place at the current window position.
      // (bits * lanes + 7) / 8 == bytes per scalar element, rounded up.
      DLTensor copy_dst = *(data.operator->());
      copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
      copy_dst.shape = &shape[0];
      DLTensor copy_src = *(value.operator->());
      copy_src.byte_offset = 0;
      copy_src.shape = &shape[0];
      Tensor::CopyFromTo(&copy_src, &copy_dst);
    }
    // copy the remainder to the beginning of the cache, placed right after
    // the attention-sink slots (which are never overwritten)
    if (num_elements_to_copy < value->shape[0]) {
      TVM_FFI_ICHECK_EQ(this->fill_count, max_cache_size);
      TVM_FFI_ICHECK_EQ(this->fill_count, this->window_attention_current_pos);
      shape[0] = value->shape[0] - num_elements_to_copy;
      num_filled_elements = num_elements_to_copy * num_elements_p_entry;
      DLTensor copy_dst = *(data.operator->());
      copy_dst.byte_offset = (num_attention_sinks * num_elements_p_entry) *
                             ((data->dtype.bits * data->dtype.lanes + 7) / 8);
      copy_dst.shape = &shape[0];
      DLTensor copy_src = *(value.operator->());
      copy_src.byte_offset =
          num_filled_elements * ((value->dtype.bits * value->dtype.lanes + 7) / 8);
      copy_src.shape = &shape[0];
      Tensor::CopyFromTo(&copy_src, &copy_dst);
      // Wrap the write head to just past the sink region.
      this->window_attention_current_pos =
          value->shape[0] - num_elements_to_copy + num_attention_sinks;
    }
  }
  /*!
   * \brief Append value to the cache.
   * \param value The value to be appended.
   */
  void Append(Tensor value) {
    TVM_FFI_ICHECK(data.DataType() == value.DataType()) << "dtype mismatch";
    // reallocate cache: double the reserved slots until the new rows fit.
    int64_t reserved_slots = data->shape[0];
    while (fill_count + value->shape[0] > reserved_slots) {
      reserved_slots *= 2;
    }
    if (reserved_slots != data->shape[0]) {
      std::vector<int64_t> new_shape(data->shape, data->shape + data->ndim);
      new_shape[0] = reserved_slots;
      Tensor new_data = Tensor::Empty(new_shape, data->dtype, data->device);
      new_data.CreateView(data.Shape(), data->dtype).CopyFrom(data);
      this->data = new_data;
    }
    // copy into the fill count position.
    TVM_FFI_ICHECK_LE(fill_count + value->shape[0], data->shape[0]);
    TVM_FFI_ICHECK(data.IsContiguous());
    // Scalar elements already filled (rows * elements per row).
    int64_t num_filled_elements = fill_count;
    for (int i = 1; i < data->ndim; ++i) {
      TVM_FFI_ICHECK_EQ(value->shape[i], data->shape[i]) << "Dimension " << i << " mismatch";
      num_filled_elements *= data->shape[i];
    }
    // create a view of copy dest to copy the value into it.
    // (bits * lanes + 7) / 8 == bytes per scalar element, rounded up.
    DLTensor copy_dst = *(data.operator->());
    copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
    copy_dst.shape = value->shape;
    Tensor::CopyFromTo(value.operator->(), &copy_dst);
    this->fill_count += value->shape[0];
  }

  static constexpr const bool _type_mutable = true;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.vm.AttentionKVCacheLegacy", AttentionKVCacheLegacyObj,
                                    ffi::Object);
};
/*! \brief Managed reference to AttentionKVCacheLegacyObj. */
class AttentionKVCacheLegacy : public ffi::ObjectRef {
 public:
  /*!
   * \brief Create the attention kv cache.
   * \param init_data The initial data appended into the cache.
   * \param reserve_shape The shape reserved for the backing storage.
   * \param init_fill_count When non-negative, overrides the fill count (and
   *        the window position) after the initial append.
   * \return The created cache reference.
   */
  static AttentionKVCacheLegacy Create(Tensor init_data, ffi::Shape reserve_shape,
                                       int init_fill_count) {
    auto cache = ffi::make_object<AttentionKVCacheLegacyObj>();
    cache->data = Tensor::Empty(reserve_shape, init_data->dtype, init_data->device);
    cache->fill_count = 0;
    cache->Append(init_data);
    if (init_fill_count >= 0) {
      // The caller supplied an explicit fill count; trust it over the
      // appended length.
      cache->fill_count = init_fill_count;
      cache->window_attention_current_pos = init_fill_count;  // window attention only
    }
    return AttentionKVCacheLegacy(cache);
  }

  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AttentionKVCacheLegacy, ffi::ObjectRef,
                                             AttentionKVCacheLegacyObj);
};
//-------------------------------------------------
// Register runtime functions
//-------------------------------------------------
// Expose cache creation to the VM as a global builtin.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.attention_kv_cache_create", AttentionKVCacheLegacy::Create);
}
/*!
 * \brief VM builtin wrapper around AttentionKVCacheLegacyObj::Update.
 * \param cache The cache to mutate in place.
 * \param value The new contents; dim 0 must equal the current fill count.
 * \return The same cache handle, so generated code can chain calls.
 */
AttentionKVCacheLegacy AttentionKVCacheUpdate(AttentionKVCacheLegacy cache, Tensor value) {
  AttentionKVCacheLegacy handle = cache;
  handle->Update(value);
  return handle;
}
// Expose the in-place cache update builtin to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.attention_kv_cache_update", AttentionKVCacheUpdate);
}
/*!
 * \brief VM builtin wrapper around AttentionKVCacheLegacyObj::Append.
 * \param cache The cache to grow.
 * \param value The rows to append at the current fill position.
 * \return The same cache handle, so generated code can chain calls.
 */
AttentionKVCacheLegacy AttentionKVCacheAppend(AttentionKVCacheLegacy cache, Tensor value) {
  AttentionKVCacheLegacy handle = cache;
  handle->Append(value);
  return handle;
}
// Expose the cache append builtin to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.attention_kv_cache_append", AttentionKVCacheAppend);
}
/*!
 * \brief VM builtin: windowed append that overwrites the oldest entries once
 *        `max_cache_size` is reached (no attention-sink slots).
 * \return The same cache handle, so generated code can chain calls.
 */
AttentionKVCacheLegacy AttentionKVCacheWindowOverride(AttentionKVCacheLegacy cache, Tensor value,
                                                      int64_t max_cache_size) {
  AttentionKVCacheLegacy handle = cache;
  handle->WindowOverride(value, max_cache_size);
  return handle;
}
// Expose the sliding-window override builtin to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.attention_kv_cache_window_override",
                        AttentionKVCacheWindowOverride);
}
/*!
 * \brief VM builtin: windowed append that keeps the first
 *        `num_attention_sinks` slots intact when the window wraps
 *        (see https://arxiv.org/abs/2309.17453).
 * \return The same cache handle, so generated code can chain calls.
 */
AttentionKVCacheLegacy AttentionKVCacheWindowOverrideWithSinks(AttentionKVCacheLegacy cache,
                                                               Tensor value, int64_t max_cache_size,
                                                               int64_t num_attention_sinks) {
  AttentionKVCacheLegacy handle = cache;
  handle->WindowOverride(value, max_cache_size, num_attention_sinks);
  return handle;
}
// Expose the sliding-window override with attention sinks to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.attention_kv_cache_window_override_with_sinks",
                        AttentionKVCacheWindowOverrideWithSinks);
}
/*! \brief VM builtin: view the filled portion of the cache with the given shape. */
Tensor AttentionKVCacheView(AttentionKVCacheLegacy cache, ffi::Shape shape) {
  Tensor view = cache->View(shape);
  return view;
}
// Registration for the view builtin. Packed because it accepts either
// (cache) or (cache, shape); the one-argument form derives the shape from
// the cache's fill count and trailing dimensions.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed(
      "vm.builtin.attention_kv_cache_view", [](ffi::PackedArgs args, ffi::Any* rv) {
        TVM_FFI_CHECK(args.size() == 1 || args.size() == 2, ValueError)
            << "`vm.builtin.attention_kv_cache_view` expects 1 or 2 arguments, but got "
            << args.size() << ".";
        AttentionKVCacheLegacy cache = args[0].cast<AttentionKVCacheLegacy>();
        if (args.size() == 1) {
          // No shape given: view the filled prefix with the cache's own
          // trailing dimensions.
          std::vector<ffi::Shape::index_type> dims;
          dims.push_back(cache->fill_count);
          for (int axis = 1; axis < cache->data->ndim; ++axis) {
            dims.push_back(cache->data->shape[axis]);
          }
          *rv = cache->View(ffi::Shape(dims));
        } else {
          *rv = cache->View(args[1].cast<ffi::Shape>());
        }
      });
}
/*!
 * \brief Pop the last `n` entries from every cache in the array.
 * \param caches The caches to shrink (typically one K and one V per layer).
 * \param n The number of trailing entries to drop from each.
 */
void AttentionKVCacheArrayPopN(ffi::Array<AttentionKVCacheLegacy> caches, int64_t n) {
  for (const AttentionKVCacheLegacy& kv_cache : caches) {
    kv_cache->PopN(static_cast<size_t>(n));
  }
}
// Expose batched cache pop to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.attention_kv_cache_array_popn", AttentionKVCacheArrayPopN);
}
/*!
 * \brief Reset every cache in the array to the empty state.
 * \param caches The caches to clear (backing buffers are kept).
 */
void AttentionKVCacheArrayClear(ffi::Array<AttentionKVCacheLegacy> caches) {
  for (const AttentionKVCacheLegacy& kv_cache : caches) {
    kv_cache->Clear();
  }
}
// Expose batched cache clear to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.attention_kv_cache_array_clear", AttentionKVCacheArrayClear);
}
// NOTE this is a built-in highly related to LM so we put it here.
/*!
 * \brief Sample a token index from raw logits with temperature scaling and
 *        top-p (nucleus) truncation.
 * \param logits Float32 logits; every leading dimension must be 1.
 * \param temperature Softmax temperature; values below 1e-6 degenerate to argmax.
 * \param top_p The nucleus probability mass to sample from.
 * \param uniform_sample A pre-drawn uniform random number, assumed in (0, 1).
 * \return The sampled vocabulary index.
 */
int SampleTopPFromLogits(Tensor logits, double temperature, double top_p, double uniform_sample) {
  TVM_FFI_ICHECK(logits.IsContiguous());
  TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32));
  // Sampling happens on host memory; pull device logits back to CPU first.
  if (logits->device.device_type != kDLCPU) {
    logits = logits.CopyTo(DLDevice{kDLCPU, 0});
  }
  TVM_FFI_ICHECK(logits->device.device_type == kDLCPU);
  for (int i = 0; i < logits->ndim - 1; ++i) {
    TVM_FFI_ICHECK_EQ(logits->shape[i], 1) << "The leading dimensions of logits must be 1";
  }
  // Pair each logit with its vocabulary index so indices survive sorting.
  std::vector<std::pair<float, int>> data;
  data.resize(logits->shape[logits->ndim - 1]);
  const float* plogits = static_cast<float*>(logits->data);
  for (size_t i = 0; i < data.size(); ++i) {
    data[i] = std::make_pair(plogits[i], static_cast<int>(i));
  }
  auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
    return lhs.first > rhs.first;
  };
  // sort by logits from largest to smallest
  std::sort(data.begin(), data.end(), fcmp);
  // argmax shortcut for (near-)zero temperature
  if (temperature < 1e-6f) {
    return data[0].second;
  }
  // compute expf scaled by temp; subtracting the max logit keeps expf stable
  float sum = 0.0f, logit_scale = 1.0f / temperature;
  float max_value = data[0].first;
  for (auto it = data.begin(); it != data.end(); ++it) {
    it->first = expf((it->first - max_value) * logit_scale);
    sum += it->first;
  }
  // do a cumsum in order of data; top_p_sum accumulates only the probability
  // mass that falls inside the top-p nucleus
  float cum_sum_prob = 0.0f;
  float top_p_sum = 0.0f;
  for (auto it = data.begin(); it != data.end(); ++it) {
    float prob = it->first / sum;
    if (cum_sum_prob < top_p) {
      top_p_sum += prob;
    }
    cum_sum_prob += prob;
    // entries now hold the running CDF instead of the raw probability
    it->first = cum_sum_prob;
  }
  // pick a number based on random in (0, 1): inverse-CDF renormalized to the
  // nucleus mass
  for (auto it = data.begin(); it != data.end(); ++it) {
    if (uniform_sample < it->first / top_p_sum) {
      return it->second;
    }
  }
  // Floating-point round-off fallback: return the argmax entry.
  TVM_FFI_ICHECK_LE(uniform_sample, data[0].first);
  return data[0].second;
}
// Expose top-p sampling from logits to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.sample_top_p_from_logits", SampleTopPFromLogits);
}
/*!
 * \brief Sample a token index from a probability distribution with top-p
 *        (nucleus) truncation.
 *
 * Optimization: entries below a cutoff cannot participate in the nucleus, so
 * the first pass filters with cutoff top_p / 1024 to sort a tiny candidate
 * set; only if that filtered mass is insufficient does it fall back to the
 * full distribution.
 *
 * \param prob Float32 probabilities; every leading dimension must be 1.
 * \param top_p The nucleus probability mass to sample from.
 * \param uniform_sample A pre-drawn uniform random number, assumed in (0, 1).
 * \return The sampled vocabulary index.
 */
int SampleTopPFromProb(Tensor prob, double top_p, double uniform_sample) {
  TVM_FFI_ICHECK(prob.IsContiguous());
  TVM_FFI_ICHECK(prob.DataType() == DataType::Float(32));
  // Sampling happens on host memory; pull device probabilities back to CPU.
  if (prob->device.device_type != kDLCPU) {
    prob = prob.CopyTo(DLDevice{kDLCPU, 0});
  }
  TVM_FFI_ICHECK(prob->device.device_type == kDLCPU);
  for (int i = 0; i < prob->ndim - 1; ++i) {
    TVM_FFI_ICHECK_EQ(prob->shape[i], 1) << "The leading dimensions of logits must be 1";
  }
  // Key observation: when we are doing top_p sampling
  // usually we only need to preserve some of the elements with
  // high probabilities before we do sort
  std::vector<std::pair<float, int>> data;
  int64_t ndata = prob->shape[prob->ndim - 1];
  const float* p_prob = static_cast<float*>(prob->data);
  // Try top-p sampling over entries with probability >= cuttoff.
  // Returns -1 when the filtered mass is insufficient and the caller must
  // retry with a smaller cutoff.
  auto sample_top_p_with_filter = [&](float cuttoff) -> int64_t {
    data.clear();
    // filter the data with cuttoff
    for (int64_t i = 0; i < ndata; ++i) {
      if (p_prob[i] >= cuttoff) {
        data.emplace_back(std::make_pair(p_prob[i], static_cast<int>(i)));
      }
    }
    if (data.size() == 0) return -1;
    auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {
      return lhs.first > rhs.first;
    };
    std::sort(data.begin(), data.end(), fcmp);
    // short cut, if we know that
    // uniform sample < p[0] / top_p
    // we know that uniform_sample < p[0] / top_p_sum
    // because top_p_sum guarantees to be smaller than top_p
    // so we can simply return the argmax sample
    // without computing anything
    if (uniform_sample < data[0].first / top_p) return data[0].second;
    // compute top_p_sum: accumulate until the nucleus mass top_p is covered;
    // entries visited here are rewritten in place to hold the running CDF
    float cum_sum_prob = 0.0f;
    float top_p_sum = 0.0f;
    for (auto it = data.begin(); it != data.end(); ++it) {
      float prob = it->first;
      if (cum_sum_prob < top_p) {
        top_p_sum += prob;
      } else {
        // we get to the right cutoff pt
        break;
      }
      cum_sum_prob += prob;
      it->first = cum_sum_prob;
    }
    // we find that the current total sum by the given cutoff
    // is not sufficient to cover everything
    // this means we might need to retry a smaller cutoff pt.
    if (cum_sum_prob < top_p && cuttoff != 0.0f) return -1;
    // inverse-CDF selection over the rewritten (cumulative) entries
    for (auto it = data.begin(); it != data.end(); ++it) {
      if (uniform_sample < it->first / top_p_sum) {
        return it->second;
      }
    }
    // round-off fallback: take the last surviving candidate
    return data[data.size() - 1].second;
  };
  // Detect a fully-NaN distribution so the error message can say so.
  auto is_all_nan = [&]() -> bool {
    return std::all_of(p_prob, p_prob + ndata, [](float x) { return std::isnan(x); });
  };
  if (top_p < 1) {
    // sample through cutoff by a number
    // by pigeonhole principle we will get at most 1024 elements
    // usually it is much less by applying this filtering(order of 10 - 20)
    data.reserve(128);
    int64_t sampled_index = sample_top_p_with_filter(top_p / 1024);
    if (sampled_index >= 0) return sampled_index;
  }
  // fallback via full prob, rare case
  data.reserve(ndata);
  int64_t sampled_index = sample_top_p_with_filter(0.0f);
  if (sampled_index < 0 && is_all_nan()) {
    TVM_FFI_THROW(InternalError) << "The output probabilities are all NaNs, can not sample from it";
  } else if (sampled_index < 0) {
    TVM_FFI_THROW(InternalError)
        << "Cannot sample from the given probability distribution due to unknown reason";
  }
  return sampled_index;
}
// Expose top-p sampling from probabilities to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.sample_top_p_from_prob", SampleTopPFromProb);
}
/*!
 * \brief Sample one index per batch row from a probability matrix using
 *        pre-drawn uniform samples (inverse-CDF scan).
 * \param prob Probabilities whose last dimension is the vocabulary; read as
 *        float32 (assumed — TODO confirm caller contract, no dtype check here).
 * \param uniform_sample One uniform draw per row, read as float32.
 * \return An int64 tensor of shape (batch, 1) with the sampled indices.
 */
Tensor MultinomialFromUniform(Tensor prob, Tensor uniform_sample) {
  TVM_FFI_ICHECK(prob.IsContiguous());
  TVM_FFI_ICHECK(uniform_sample.IsContiguous());
  // Work on host copies so we can scan the raw buffers directly.
  if (prob->device.device_type != kDLCPU) {
    prob = prob.CopyTo(DLDevice{kDLCPU, 0});
  }
  if (uniform_sample->device.device_type != kDLCPU) {
    uniform_sample = uniform_sample.CopyTo(DLDevice{kDLCPU, 0});
  }
  TVM_FFI_ICHECK(prob->device.device_type == kDLCPU);
  TVM_FFI_ICHECK(uniform_sample->device.device_type == kDLCPU);

  const int64_t batch_size = prob->shape[0];
  const int64_t vocab_size = prob->shape[prob->ndim - 1];
  const float* prob_data = static_cast<float*>(prob->data);
  const float* sample_data = static_cast<float*>(uniform_sample->data);

  Tensor sampled = Tensor::Empty({batch_size, 1}, DataType::Int(64), uniform_sample->device);
  int64_t* sampled_data = static_cast<int64_t*>(sampled->data);
  for (int64_t row = 0; row < batch_size; ++row) {
    // Walk the CDF until it exceeds the drawn sample; if it never does
    // (round-off), the last vocabulary index is kept.
    const float* row_prob = prob_data + row * vocab_size;
    float cdf = 0.0f;
    int64_t picked = 0;
    for (int64_t col = 0; col < vocab_size; ++col) {
      picked = col;
      cdf += row_prob[col];
      if (cdf > sample_data[row]) {
        break;
      }
    }
    sampled_data[row] = picked;
  }
  return sampled;
}
// Expose multinomial sampling from pre-drawn uniforms to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.multinomial_from_uniform", MultinomialFromUniform);
}
/*!
 * \brief Penalize the logits of previously generated tokens, in place.
 * \param logits Float32 CPU logits to adjust.
 * \param token_ids Int32 CPU tensor of token ids to penalize.
 * \param penalty The penalty factor; with penalty > 1 every listed token's
 *        likelihood is reduced regardless of its logit's sign.
 */
void ApplyRepetitionPenalty(Tensor logits, Tensor token_ids, double penalty) {
  TVM_FFI_ICHECK(logits.IsContiguous());
  TVM_FFI_ICHECK(token_ids.IsContiguous());
  TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
  TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!";
  TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
  TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!";
  float* logit_buf = static_cast<float*>(logits->data);
  const int* id_buf = static_cast<int*>(token_ids->data);
  const size_t num_tokens = token_ids->shape[token_ids->ndim - 1];
  for (size_t i = 0; i < num_tokens; ++i) {
    const int token = id_buf[i];
    // Non-positive logits are pushed further down by multiplying, positive
    // ones by dividing.
    if (logit_buf[token] <= 0) {
      logit_buf[token] *= penalty;
    } else {
      logit_buf[token] /= penalty;
    }
  }
}
// Expose the repetition penalty builtin to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.apply_repetition_penalty", ApplyRepetitionPenalty);
}
/*!
 * \brief Apply presence and frequency penalty. This is an inplace operation.
 * \param logits The input logits before penalty (float32, CPU).
 * \param token_ids The appeared token ids (int32, CPU).
 * \param token_freqs The number of times each token has appeared since last PrefillStep.
 * token_freqs[i] is the frequency of token_ids[i], for all i. And all token_freqs should be >= 1.
 * \param presence_penalty The penalty factor, applied if a token appeared in an one-off manner.
 * \param frequency_penalty The penalty factor, contributes more the more frequent a token appears.
 */
void ApplyPresenceAndFrequencyPenalty(Tensor logits, Tensor token_ids, Tensor token_freqs,
                                      double presence_penalty, double frequency_penalty) {
  // See https://platform.openai.com/docs/guides/text-generation/frequency-and-presence-penalties
  TVM_FFI_ICHECK(logits.IsContiguous());
  TVM_FFI_ICHECK(token_ids.IsContiguous());
  TVM_FFI_ICHECK(token_freqs.IsContiguous());
  TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
  TVM_FFI_ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!";
  TVM_FFI_ICHECK(token_freqs.DataType() == DataType::Int(32)) << "token freqs must be int32!";
  TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
  TVM_FFI_ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!";
  // Fixed: this check is about token_freqs; the message previously said
  // "token_ids" (copy-paste error), which made failures misleading.
  TVM_FFI_ICHECK(token_freqs->device.device_type == kDLCPU) << "token_freqs device must be CPU!";
  float* logits_raw_data = static_cast<float*>(logits->data);
  int* token_ids_data = static_cast<int*>(token_ids->data);
  int* token_freqs_data = static_cast<int*>(token_freqs->data);
  size_t num_token_ids = token_ids->shape[token_ids->ndim - 1];
  for (size_t i = 0; i < num_token_ids; ++i) {
    int token_id = token_ids_data[i];
    int token_freq = token_freqs_data[i];
    // Linear penalty: a flat presence term plus a term that grows with the
    // token's frequency.
    logits_raw_data[token_id] -= (token_freq * frequency_penalty + presence_penalty);
  }
}
// Expose the presence/frequency penalty builtin to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.apply_presence_and_frequency_penalty",
                        ApplyPresenceAndFrequencyPenalty);
}
/*!
 * \brief In-place softmax with temperature over the last dimension of logits.
 *
 * Uses a single-pass online normalizer (running max `m`, running
 * denominator `d`) followed by a normalization pass.
 *
 * \param logits Float32 CPU tensor; overwritten with probabilities.
 * \param temperature The softmax temperature (assumed non-zero).
 */
void ApplySoftmaxWithTemperature(Tensor logits, double temperature) {
  TVM_FFI_ICHECK(logits.IsContiguous());
  TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
  TVM_FFI_ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
  int vocab_size = logits->shape[logits->ndim - 1];
  float* logits_raw_data = static_cast<float*>(logits->data);
  float inv_temp = 1.0f / temperature;
  // The running max must start at the most negative float. The previous
  // code used std::numeric_limits<float>::min(), which is the smallest
  // *positive* normal value: for rows whose scaled logits are all very
  // negative, the max stayed ~0 and the denominator could underflow to 0,
  // producing NaNs. lowest() is the correct identity for max().
  float m = std::numeric_limits<float>::lowest();
  double d = 0.0;
  for (int i = 0; i < vocab_size; ++i) {
    float x = logits_raw_data[i] * inv_temp;
    float m_prev = m;
    m = std::max(m, x);
    // Online update: rescale the accumulated denominator whenever the
    // running max changes, then add the new term.
    d = d * std::exp(m_prev - m) + std::exp(x - m);
  }
  for (int i = 0; i < vocab_size; ++i) {
    float x = logits_raw_data[i] * inv_temp;
    logits_raw_data[i] = std::exp(x - m) / d;
  }
}
// Expose the temperature softmax builtin to the VM.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("vm.builtin.apply_softmax_with_temperature", ApplySoftmaxWithTemperature);
}
} // namespace vm
} // namespace runtime
} // namespace tvm