blob: 0ca3d1680412f0a08380d6e7bbaf50366e8ff63c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef ENCODING_FIRE_H
#define ENCODING_FIRE_H
#include <cstdint>
/**
 * Abstract base for the predictors declared in this header.
 *
 * The base class only stores state; concrete subclasses supply predict()
 * and train().
 *
 * @tparam T value type handled by predict()/train(); delta_ is stored as T.
 */
template <typename T>
class Fire {
public:
    /**
     * @param learning_rate stored verbatim in learn_shift_. The subclasses
     *        visible in this header use it as a right-shift amount applied
     *        to accumulator_.
     */
    explicit Fire(int learning_rate)
        : learn_shift_{learning_rate}, delta_(0) {}

    virtual ~Fire() = default;

    /** Returns the prediction associated with `value`; defined by subclasses. */
    virtual T predict(T value) = 0;

    /** Updates the internal state from one observation; defined by subclasses. */
    virtual void train(T pre, T val, T err) = 0;

    /**
     * Zeroes accumulator_ and delta_. learn_shift_ and bit_width_ are left
     * as they are.
     */
    virtual void reset() {
        delta_ = 0;
        accumulator_ = 0;
    }

protected:
    int learn_shift_;      // taken from the constructor argument
    int bit_width_ = 0;    // 0 here; subclasses in this header overwrite it
    int accumulator_ = 0;  // always a plain int, whatever T is
    T delta_;              // zero-initialised by the constructor
};
/**
 * Fire predictor over `int` values, using an 8-bit fixed-point shift
 * (bit_width_ = 8).
 *
 * All additions, subtractions and negations go through unsigned arithmetic
 * so that extreme inputs wrap around (two's complement) instead of invoking
 * signed-overflow undefined behaviour. For inputs that do not overflow the
 * results are identical to plain `int` arithmetic.
 *
 * NOTE(review): wrap-around was chosen because it is what the previous code
 * produced in practice on mainstream compilers; confirm it matches any other
 * implementation of this encoding that must stay bit-compatible.
 */
class IntFire : public Fire<int> {
public:
    /** @param learning_rate right-shift applied to accumulator_ in predict(). */
    explicit IntFire(int learning_rate) : Fire(learning_rate) {
        // The base constructor already zeroes accumulator_ and delta_.
        bit_width_ = 8;
    }

    /** Zeroes accumulator_ and delta_. */
    void reset() override {
        accumulator_ = 0;
        delta_ = 0;
    }

    /**
     * @return value + (((accumulator_ >> learn_shift_) * delta_) >> bit_width_),
     *         where the product is truncated to `int` before the second shift
     *         and the final sum wraps on overflow.
     */
    int predict(int value) override {
        const int alpha = accumulator_ >> learn_shift_;
        // The product is formed in 64 bits (cannot overflow) and then
        // deliberately truncated to int *before* shifting.
        const int product =
            static_cast<int>(static_cast<int64_t>(alpha) * delta_);
        const int diff = product >> bit_width_;
        return wrap_add(value, diff);
    }

    /**
     * Adjusts the accumulator by the current delta_ and then replaces delta_.
     *
     * @param pre subtracted from `val` to form the new delta_
     * @param val minuend of the new delta_
     * @param err only its sign is used: when err > 0 the accumulator grows
     *            by delta_, otherwise it shrinks by delta_
     */
    void train(int pre, int val, int err) override {
        const int gradient = err > 0 ? wrap_sub(0, delta_) : delta_;
        accumulator_ = wrap_sub(accumulator_, gradient);
        delta_ = wrap_sub(val, pre);
    }

private:
    /** a + b with two's-complement wrap-around instead of overflow UB. */
    static int wrap_add(int a, int b) {
        return static_cast<int>(static_cast<unsigned int>(a) +
                                static_cast<unsigned int>(b));
    }

    /** a - b with two's-complement wrap-around instead of overflow UB. */
    static int wrap_sub(int a, int b) {
        return static_cast<int>(static_cast<unsigned int>(a) -
                                static_cast<unsigned int>(b));
    }
};
/**
 * Fire predictor over `int64_t` values, using a 16-bit fixed-point shift
 * (bit_width_ = 16).
 *
 * The accumulator inherited from Fire is a plain `int`; updates are computed
 * in 64 bits and then narrowed back to `int`.
 *
 * All additions, subtractions and negations go through unsigned arithmetic
 * so that extreme inputs wrap around (two's complement) instead of invoking
 * signed-overflow undefined behaviour. For inputs that do not overflow the
 * results are identical to plain signed arithmetic.
 *
 * NOTE(review): wrap-around was chosen because it is what the previous code
 * produced in practice on mainstream compilers; confirm it matches any other
 * implementation of this encoding that must stay bit-compatible.
 */
class LongFire : public Fire<int64_t> {
public:
    /** @param learning_rate right-shift applied to accumulator_ in predict(). */
    explicit LongFire(int learning_rate) : Fire(learning_rate) {
        // The base constructor already zeroes accumulator_ and delta_.
        bit_width_ = 16;
    }

    /** Zeroes accumulator_ and delta_. */
    void reset() override {
        accumulator_ = 0;
        delta_ = 0;
    }

    /**
     * @return value + (((accumulator_ >> learn_shift_) * delta_) >> bit_width_),
     *         with the product formed at 128-bit precision and the final sum
     *         wrapping on overflow.
     */
    int64_t predict(int64_t value) override {
        const int64_t alpha = accumulator_ >> learn_shift_;
        const int64_t diff = safe_mul_shift(alpha, delta_, bit_width_);
        return wrap_add(value, diff);
    }

    /**
     * Adjusts the accumulator by the current delta_ and then replaces delta_.
     *
     * @param pre subtracted from `val` to form the new delta_
     * @param val minuend of the new delta_
     * @param err only its sign is used: when err > 0 the accumulator grows
     *            by delta_, otherwise it shrinks by delta_
     */
    void train(int64_t pre, int64_t val, int64_t err) override {
        const int64_t gradient = err > 0 ? wrap_sub(0, delta_) : delta_;
        // Computed in 64 bits, then explicitly narrowed to the int accumulator
        // (same narrowing the compound assignment used to do implicitly).
        accumulator_ = static_cast<int>(wrap_sub(accumulator_, gradient));
        delta_ = wrap_sub(val, pre);
    }

private:
    /** a + b with two's-complement wrap-around instead of overflow UB. */
    static int64_t wrap_add(int64_t a, int64_t b) {
        return static_cast<int64_t>(static_cast<uint64_t>(a) +
                                    static_cast<uint64_t>(b));
    }

    /** a - b with two's-complement wrap-around instead of overflow UB. */
    static int64_t wrap_sub(int64_t a, int64_t b) {
        return static_cast<int64_t>(static_cast<uint64_t>(a) -
                                    static_cast<uint64_t>(b));
    }

    /**
     * Computes (alpha * delta) >> shift without signed overflow.
     *
     * The product is taken at full 128-bit precision, shifted right
     * arithmetically (rounding toward negative infinity), and the low 64 bits
     * of the shifted value are returned. Both code paths below produce the
     * same bits, so the result does not depend on whether the compiler
     * offers __int128.
     *
     * @param shift must satisfy 0 <= shift < 128
     */
    static int64_t safe_mul_shift(int64_t alpha, int64_t delta, int shift) {
#if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ >= 16
        const __int128 product = static_cast<__int128>(alpha) * delta;
        return static_cast<int64_t>(product >> shift);
#else
        /* Portable fallback: exact 64x64 -> 128-bit multiply built from
         * 32-bit limbs. A floating-point product is not usable here: it
         * rounds toward zero instead of flooring and is inexact beyond
         * 2^53. */
        const uint64_t ua = static_cast<uint64_t>(alpha);
        const uint64_t ub = static_cast<uint64_t>(delta);
        const uint64_t mask = 0xFFFFFFFFu;

        const uint64_t a_lo = ua & mask;
        const uint64_t a_hi = ua >> 32;
        const uint64_t b_lo = ub & mask;
        const uint64_t b_hi = ub >> 32;

        const uint64_t p0 = a_lo * b_lo;
        const uint64_t p1 = a_lo * b_hi;
        const uint64_t p2 = a_hi * b_lo;
        const uint64_t p3 = a_hi * b_hi;

        // Middle column; at most 3 * (2^32 - 1), so it cannot overflow.
        const uint64_t mid = (p0 >> 32) + (p1 & mask) + (p2 & mask);

        // (hi:lo) is the unsigned 128-bit product of ua and ub.
        const uint64_t lo = (p0 & mask) | (mid << 32);
        uint64_t hi = p3 + (p1 >> 32) + (p2 >> 32) + (mid >> 32);

        // Turn the unsigned high word into the signed one: a negative operand
        // x was read as x + 2^64, which added the other operand to hi.
        if (alpha < 0) {
            hi -= ub;
        }
        if (delta < 0) {
            hi -= ua;
        }

        if (shift == 0) {
            return static_cast<int64_t>(lo);
        }
        if (shift < 64) {
            return static_cast<int64_t>((lo >> shift) | (hi << (64 - shift)));
        }
        // 64 <= shift < 128: only the (sign-extended) high word remains.
        return static_cast<int64_t>(hi) >> (shift - 64);
#endif
    }
};
#endif // ENCODING_FIRE_H