blob: 9e8d0678a986e25261589f94ae5b888d6bc26f2b [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef IMPALA_MPFIT_UTIL_H
#define IMPALA_MPFIT_UTIL_H
#include <functional>
#include <memory>
#include <string>
#include "common/logging.h"
#include "common/status.h"
#include "thirdparty/mpfit/mpfit.h"
namespace impala {
/// Objective function to be fit using the MpFit library. An objective function has
/// exactly one 'x' variable and any number of additional non-variable parameters whose
/// values are unknown. The purpose of curve fitting is to find the best values for those
/// parameters given an objective function and a series of x/y data points.
///
/// This class contains the objective function as well as the data points because that's
/// the most natural setup for the MpFit API.
/// Calling LmsFit() determines the parameters for the objective function.
/// Calling GetY() computes the 'y' value for a given 'x' using the fitted parameters.
///
/// The objective function is of type std::function<double (double, const double*)>.
/// By convention, the first argument is the value of the 'x' variable and the second
/// argument is an array of 'num_params' function parameters which are determined
/// during fitting.
class ObjectiveFunction {
 public:
  ObjectiveFunction(std::string name, int num_params,
      std::function<double (double, const double*)> fn);

  /// Performs least-squares curve fitting using the MpFit library against the
  /// provided x/y data points. 'xs' and 'ys' must each point to 'num_points' values.
  /// The arrays are not owned by this object; they must stay valid for as long as
  /// GetDeltaY() may read them.
  /// Returns true if fitting was successful, false otherwise.
  bool LmsFit(const double* xs, const double* ys, int num_points) WARN_UNUSED_RESULT;

  /// Evaluates the objective function at 'x' using the fitted parameters.
  /// Must only be called once the parameters have been populated (see LmsFit()).
  /// 'x' is a double to match the objective function's argument type, so fractional
  /// values are no longer truncated; integral arguments still convert implicitly.
  double GetY(double x) const {
    DCHECK(params_ != nullptr);
    return fn_(x, params_.get());
  }

  /// Returns the difference between the y value of data point 'pidx' and the
  /// y value of the objective function, evaluated with the given 'params', at the
  /// x value of the same point.
  /// Requires 0 <= pidx < num_points_ and that data points have been stored.
  /// NOTE(review): presumably invoked from the MpFit residual callback during
  /// LmsFit() — confirm against the .cc file.
  double GetDeltaY(int pidx, const double* params) const {
    DCHECK(xs_ != nullptr);
    DCHECK(ys_ != nullptr);
    // Guard both ends of the range; a negative index would read before the arrays.
    DCHECK(pidx >= 0);
    DCHECK_LT(pidx, num_points_);
    return ys_[pidx] - fn_(xs_[pidx], params);
  }

  /// Returns the final chi-square ('bestnorm') reported by MpFit for the fit. This is
  /// an indication of how well the function fits the data points. Lower is better.
  /// Valid to call after LmsFit().
  double GetError() const {
    DCHECK(params_ != nullptr);
    return result_.bestnorm;
  }

 private:
  /// Human-readable name of this function. Used for debugging.
  std::string name_;

  /// Number of function parameters to be determined by fitting, and their values.
  /// The accessors DCHECK that 'params_' is non-null before reading it.
  const int num_params_;
  std::unique_ptr<double[]> params_;

  /// MpFit result structure. Populated by LmsFit(). All pointers in this structure
  /// are optional and must be allocated and owned by the caller of mpfit(). Passing
  /// nullptr indicates to MpFit that those fields should not be populated.
  /// Value-initialized so those pointers default to nullptr and no field is ever
  /// indeterminate.
  mp_result result_{};

  /// Objective function whose parameters should be fit to the data points.
  std::function<double (double, const double*)> fn_;

  /// Known x/y data points. Memory not owned. Default to "no data" so the members
  /// are never indeterminate before data points are stored.
  int num_points_ = 0;
  const double* xs_ = nullptr;
  const double* ys_ = nullptr;
};
}
#endif