blob: 6c8cef35b3a71efc427b585dad312a619a89d1d0 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::f64::consts::LN_2;
use crate::interval_arithmetic::{Interval, apply_operator};
use crate::operator::Operator;
use crate::type_coercion::binary::binary_numeric_coercion;
use arrow::array::ArrowNativeTypeOp;
use arrow::datatypes::DataType;
use datafusion_common::rounding::alter_fp_rounding_mode;
use datafusion_common::{
Result, ScalarValue, assert_eq_or_internal_err, assert_ne_or_internal_err,
assert_or_internal_err, internal_err, not_impl_err,
};
/// This object defines probabilistic distributions that encode uncertain
/// information about a single, scalar value. Currently, we support five core
/// statistical distributions. New variants will be added over time.
///
/// This object is the lowest-level object in the statistics hierarchy, and it
/// is the main unit of calculus when evaluating expressions in a statistical
/// context. Notions like column and table statistics are built on top of this
/// object and the operations it supports.
#[derive(Clone, Debug, PartialEq)]
pub enum Distribution {
/// A uniform distribution, described by its range.
Uniform(UniformDistribution),
/// A (possibly shifted and/or mirrored) exponential distribution, described
/// by its rate, offset and tail direction.
Exponential(ExponentialDistribution),
/// A Gaussian (normal) distribution, described by its mean and variance.
Gaussian(GaussianDistribution),
/// A Bernoulli distribution, described by its success probability.
Bernoulli(BernoulliDistribution),
/// A distribution of unknown functional form, approximated by its mean,
/// median, variance and range.
Generic(GenericDistribution),
}
use Distribution::{Bernoulli, Exponential, Gaussian, Generic, Uniform};
impl Distribution {
    /// Constructs a new [`Uniform`] distribution from the given [`Interval`].
    pub fn new_uniform(interval: Interval) -> Result<Self> {
        UniformDistribution::try_new(interval).map(Uniform)
    }

    /// Constructs a new [`Exponential`] distribution from the given rate/offset
    /// pair, and validates the given parameters. The `positive_tail` flag
    /// selects whether the distribution extends towards positive infinity
    /// (`true`) or negative infinity (`false`).
    pub fn new_exponential(
        rate: ScalarValue,
        offset: ScalarValue,
        positive_tail: bool,
    ) -> Result<Self> {
        ExponentialDistribution::try_new(rate, offset, positive_tail).map(Exponential)
    }

    /// Constructs a new [`Gaussian`] distribution from the given mean/variance
    /// pair, and validates the given parameters.
    pub fn new_gaussian(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
        GaussianDistribution::try_new(mean, variance).map(Gaussian)
    }

    /// Constructs a new [`Bernoulli`] distribution from the given success
    /// probability, and validates the given parameters.
    pub fn new_bernoulli(p: ScalarValue) -> Result<Self> {
        BernoulliDistribution::try_new(p).map(Bernoulli)
    }

    /// Constructs a new [`Generic`] distribution from the given mean, median,
    /// variance, and range values after validating the given parameters.
    pub fn new_generic(
        mean: ScalarValue,
        median: ScalarValue,
        variance: ScalarValue,
        range: Interval,
    ) -> Result<Self> {
        GenericDistribution::try_new(mean, median, variance, range).map(Generic)
    }

    /// Constructs a new [`Generic`] distribution from the given range. Other
    /// parameters (mean, median and variance) are initialized with null values
    /// of the range's data type.
    pub fn new_from_interval(range: Interval) -> Result<Self> {
        let null = ScalarValue::try_from(range.data_type())?;
        Distribution::new_generic(null.clone(), null.clone(), null, range)
    }

    /// Extracts the mean value of this uncertain quantity, depending on its
    /// distribution:
    /// - A [`Uniform`] distribution's interval determines its mean value, which
    ///   is the arithmetic average of the interval endpoints.
    /// - An [`Exponential`] distribution's mean is calculable by the formula
    ///   `offset + 1 / λ` for a positive tail and `offset - 1 / λ` for a
    ///   negative tail, where `λ` is the (positive) rate.
    /// - A [`Gaussian`] distribution contains the mean explicitly.
    /// - A [`Bernoulli`] distribution's mean is equal to its success probability `p`.
    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
    ///   may be absent.
    pub fn mean(&self) -> Result<ScalarValue> {
        match &self {
            Uniform(u) => u.mean(),
            Exponential(e) => e.mean(),
            Gaussian(g) => Ok(g.mean().clone()),
            Bernoulli(b) => Ok(b.mean().clone()),
            Generic(u) => Ok(u.mean().clone()),
        }
    }

    /// Extracts the median value of this uncertain quantity, depending on its
    /// distribution:
    /// - A [`Uniform`] distribution's interval determines its median value, which
    ///   is the arithmetic average of the interval endpoints.
    /// - An [`Exponential`] distribution's median is calculable by the formula
    ///   `offset + ln(2) / λ` for a positive tail and `offset - ln(2) / λ` for
    ///   a negative tail, where `λ` is the (positive) rate.
    /// - A [`Gaussian`] distribution's median is equal to its mean, which is
    ///   specified explicitly.
    /// - A [`Bernoulli`] distribution's median is `1` if `p > 0.5` and `0`
    ///   otherwise, where `p` is the success probability.
    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
    ///   may be absent.
    pub fn median(&self) -> Result<ScalarValue> {
        match &self {
            Uniform(u) => u.median(),
            Exponential(e) => e.median(),
            Gaussian(g) => Ok(g.median().clone()),
            Bernoulli(b) => b.median(),
            Generic(u) => Ok(u.median().clone()),
        }
    }

    /// Extracts the variance value of this uncertain quantity, depending on
    /// its distribution:
    /// - A [`Uniform`] distribution's interval determines its variance value, which
    ///   is calculable by the formula `(upper - lower) ^ 2 / 12`.
    /// - An [`Exponential`] distribution's variance is calculable by the formula
    ///   `1 / (λ ^ 2)`, where `λ` is the (positive) rate.
    /// - A [`Gaussian`] distribution's variance is specified explicitly.
    /// - A [`Bernoulli`] distribution's variance is given by the formula
    ///   `p * (1 - p)` where `p` is the success probability.
    /// - A [`Generic`] distribution _may_ have it explicitly, or this information
    ///   may be absent.
    pub fn variance(&self) -> Result<ScalarValue> {
        match &self {
            Uniform(u) => u.variance(),
            Exponential(e) => e.variance(),
            // Go through the accessors, like `mean` and `median` do.
            Gaussian(g) => Ok(g.variance().clone()),
            Bernoulli(b) => b.variance(),
            Generic(u) => Ok(u.variance().clone()),
        }
    }

    /// Extracts the range of this uncertain quantity, depending on its
    /// distribution:
    /// - A [`Uniform`] distribution's range is simply its interval.
    /// - An [`Exponential`] distribution's range is `[offset, +∞)` for a
    ///   positive tail and `(-∞, offset]` for a negative tail.
    /// - A [`Gaussian`] distribution's range is unbounded.
    /// - A [`Bernoulli`] distribution's range is [`Interval::TRUE_OR_FALSE`], if
    ///   `p` is neither `0` nor `1`. Otherwise, it is [`Interval::FALSE`]
    ///   and [`Interval::TRUE`], respectively.
    /// - A [`Generic`] distribution is unbounded by default, but more information
    ///   may be present.
    pub fn range(&self) -> Result<Interval> {
        match &self {
            Uniform(u) => Ok(u.range().clone()),
            Exponential(e) => e.range(),
            Gaussian(g) => g.range(),
            Bernoulli(b) => Ok(b.range()),
            Generic(u) => Ok(u.range().clone()),
        }
    }

    /// Returns the data type of the statistical parameters comprising this
    /// distribution.
    pub fn data_type(&self) -> DataType {
        match &self {
            Uniform(u) => u.data_type(),
            Exponential(e) => e.data_type(),
            Gaussian(g) => g.data_type(),
            Bernoulli(b) => b.data_type(),
            Generic(u) => u.data_type(),
        }
    }

    /// Determines the common data type of the given statistical parameters.
    ///
    /// Untyped nulls (i.e. `ScalarValue::Null`) are ignored. The data types of
    /// the remaining arguments are folded pairwise, from left to right,
    /// through `binary_numeric_coercion`.
    ///
    /// Returns `DataType::Null` when no argument survives the filter, and an
    /// internal error when any coercion step fails.
    pub fn target_type(args: &[&ScalarValue]) -> Result<DataType> {
        let mut arg_types = args
            .iter()
            .filter(|&&arg| arg != &ScalarValue::Null)
            .map(|&arg| arg.data_type());
        // Nothing but untyped nulls: the result is untyped as well.
        let Some(first) = arg_types.next() else {
            return Ok(DataType::Null);
        };
        let coerced = arg_types
            .try_fold(first, |target, arg| binary_numeric_coercion(&target, &arg));
        match coerced {
            Some(dt) => Ok(dt),
            None => internal_err!("Can only evaluate statistics for numeric types"),
        }
    }
}
/// Uniform distribution, represented by its range. If the given range extends
/// towards infinity, the distribution will be improper -- which is OK. For a
/// more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Continuous_uniform_distribution>
/// <https://en.wikipedia.org/wiki/Prior_probability#Improper_priors>
#[derive(Clone, Debug, PartialEq)]
pub struct UniformDistribution {
/// The range of the distribution. The constructor rejects boolean intervals.
interval: Interval,
}
/// Exponential distribution with an optional shift. The probability density
/// function (PDF) is defined as follows:
///
/// For a positive tail (when `positive_tail` is `true`):
///
/// `f(x; λ, offset) = λ exp(-λ (x - offset)) for x ≥ offset`
///
/// For a negative tail (when `positive_tail` is `false`):
///
/// `f(x; λ, offset) = λ exp(-λ (offset - x)) for x ≤ offset`
///
/// In both cases, the PDF is `0` outside the specified domain.
///
/// For more information, see:
///
/// <https://en.wikipedia.org/wiki/Exponential_distribution>
#[derive(Clone, Debug, PartialEq)]
pub struct ExponentialDistribution {
/// The rate parameter `λ`. The constructor rejects null values and values
/// that are less than or equal to zero.
rate: ScalarValue,
/// The point at which the distribution starts (positive tail) or ends
/// (negative tail). The constructor rejects null values and requires the
/// same data type as `rate`.
offset: ScalarValue,
/// Indicates whether the exponential distribution has a positive tail;
/// i.e. it extends towards positive infinity.
positive_tail: bool,
}
/// Gaussian (normal) distribution, represented by its mean and variance.
/// For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Normal_distribution>
#[derive(Clone, Debug, PartialEq)]
pub struct GaussianDistribution {
/// The mean of the distribution, which doubles as its median.
mean: ScalarValue,
/// The variance of the distribution. The constructor rejects null and
/// negative values, and requires the same data type as `mean`.
variance: ScalarValue,
}
/// Bernoulli distribution with success probability `p`. If `p` has a null value,
/// the success probability is unknown. For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Bernoulli_distribution>
#[derive(Clone, Debug, PartialEq)]
pub struct BernoulliDistribution {
/// The success probability. A null value means the probability is unknown;
/// the constructor requires non-null values to lie in `[0, 1]`.
p: ScalarValue,
}
/// A generic distribution whose functional form is not available, which is
/// approximated via some summary statistics. For a more in-depth discussion, see:
///
/// <https://en.wikipedia.org/wiki/Summary_statistics>
#[derive(Clone, Debug, PartialEq)]
pub struct GenericDistribution {
/// The mean, or a null value if unknown. The constructor requires non-null
/// values to lie within `range`.
mean: ScalarValue,
/// The median, or a null value if unknown. The constructor requires non-null
/// values to lie within `range`.
median: ScalarValue,
/// The variance, or a null value if unknown. The constructor rejects
/// negative values.
variance: ScalarValue,
/// The range of values the quantity can take. The constructor rejects
/// boolean intervals.
range: Interval,
}
impl UniformDistribution {
    /// Wraps the given interval in a uniform distribution. Boolean intervals
    /// are rejected with an internal error, since boolean quantities are
    /// modeled by `Bernoulli` distributions.
    fn try_new(interval: Interval) -> Result<Self> {
        assert_ne_or_internal_err!(
            interval.data_type(),
            DataType::Boolean,
            "Construction of a boolean `Uniform` distribution is prohibited, create a `Bernoulli` distribution instead."
        );
        Ok(Self { interval })
    }

    /// Returns the data type of the underlying interval.
    pub fn data_type(&self) -> DataType {
        self.range().data_type()
    }

    /// Computes the mean value of this distribution, i.e. the midpoint of
    /// its interval. In case of improper distributions (i.e. when the range
    /// is unbounded), the function returns a `NULL` `ScalarValue`.
    pub fn mean(&self) -> Result<ScalarValue> {
        // TODO: Should we ensure that this always returns a real number data type?
        let divisor = ScalarValue::from(2).cast_to(&self.data_type())?;
        let (lower, upper) = (self.interval.lower(), self.interval.upper());
        let endpoint_sum = lower.add_checked(upper)?;
        let midpoint = endpoint_sum.div(divisor);
        debug_assert!(
            !self.interval.is_unbounded()
                || midpoint.as_ref().is_ok_and(|value| value.is_null())
        );
        midpoint
    }

    /// Computes the median value of this distribution, which coincides with
    /// its mean.
    pub fn median(&self) -> Result<ScalarValue> {
        self.mean()
    }

    /// Computes the variance value of this distribution via the formula
    /// `width ^ 2 / 12`. In case of improper distributions (i.e. when the
    /// range is unbounded), the function returns a `NULL` `ScalarValue`.
    pub fn variance(&self) -> Result<ScalarValue> {
        // TODO: Should we ensure that this always returns a real number data type?
        let span = self.interval.width()?;
        let divisor = ScalarValue::from(12).cast_to(&span.data_type())?;
        let span_squared = span.mul_checked(&span)?;
        let outcome = span_squared.div(divisor);
        debug_assert!(
            !self.interval.is_unbounded()
                || outcome.as_ref().is_ok_and(|value| value.is_null())
        );
        outcome
    }

    /// Returns the interval this distribution is defined over.
    pub fn range(&self) -> &Interval {
        &self.interval
    }
}
impl ExponentialDistribution {
    /// Validates the given parameters and builds the distribution. The rate
    /// and the offset must share a data type and be non-null, and the rate
    /// must not be less than or equal to zero; otherwise, an internal error
    /// is returned.
    fn try_new(
        rate: ScalarValue,
        offset: ScalarValue,
        positive_tail: bool,
    ) -> Result<Self> {
        let dt = rate.data_type();
        assert_eq_or_internal_err!(
            offset.data_type(),
            dt,
            "Rate and offset must have the same data type"
        );
        assert_or_internal_err!(
            !offset.is_null(),
            "Offset of an `ExponentialDistribution` cannot be null"
        );
        assert_or_internal_err!(
            !rate.is_null(),
            "Rate of an `ExponentialDistribution` cannot be null"
        );
        let lowest_excluded = ScalarValue::new_zero(&dt)?;
        assert_or_internal_err!(
            !rate.le(&lowest_excluded),
            "Rate of an `ExponentialDistribution` must be positive"
        );
        Ok(Self {
            rate,
            offset,
            positive_tail,
        })
    }

    /// Returns the data type of the rate (which the offset shares).
    pub fn data_type(&self) -> DataType {
        self.rate.data_type()
    }

    /// Returns the rate parameter `λ`.
    pub fn rate(&self) -> &ScalarValue {
        &self.rate
    }

    /// Returns the offset, i.e. the finite endpoint of the distribution.
    pub fn offset(&self) -> &ScalarValue {
        &self.offset
    }

    /// Returns `true` if the distribution extends towards positive infinity.
    pub fn positive_tail(&self) -> bool {
        self.positive_tail
    }

    /// Moves the given distance away from the offset in the direction of the
    /// tail: it is added for a positive tail and subtracted otherwise.
    fn displace_along_tail(&self, distance: ScalarValue) -> Result<ScalarValue> {
        if self.positive_tail {
            self.offset.add_checked(distance)
        } else {
            self.offset.sub_checked(distance)
        }
    }

    /// Computes the mean, which lies `1 / λ` away from the offset in the
    /// direction of the tail.
    pub fn mean(&self) -> Result<ScalarValue> {
        // TODO: Should we ensure that this always returns a real number data type?
        let unit = ScalarValue::new_one(&self.data_type())?;
        let distance = unit.div(&self.rate)?;
        self.displace_along_tail(distance)
    }

    /// Computes the median, which lies `ln(2) / λ` away from the offset in
    /// the direction of the tail.
    pub fn median(&self) -> Result<ScalarValue> {
        // TODO: Should we ensure that this always returns a real number data type?
        let numerator = ScalarValue::from(LN_2).cast_to(&self.data_type())?;
        let distance = numerator.div(&self.rate)?;
        self.displace_along_tail(distance)
    }

    /// Computes the variance via the formula `1 / (λ ^ 2)`.
    pub fn variance(&self) -> Result<ScalarValue> {
        // TODO: Should we ensure that this always returns a real number data type?
        let unit = ScalarValue::new_one(&self.data_type())?;
        let squared_rate = self.rate.mul_checked(&self.rate)?;
        unit.div(squared_rate)
    }

    /// Computes the range, which has the offset at one end and a null
    /// endpoint at the other (the side the tail extends towards).
    pub fn range(&self) -> Result<Interval> {
        let open_end = ScalarValue::try_from(&self.data_type())?;
        let anchored_end = self.offset.clone();
        if self.positive_tail {
            Interval::try_new(anchored_end, open_end)
        } else {
            Interval::try_new(open_end, anchored_end)
        }
    }
}
impl GaussianDistribution {
    /// Validates the given parameters and builds the distribution. The mean
    /// and the variance must share a data type, and the variance must be
    /// non-null and non-negative (a zero variance is accepted); otherwise,
    /// an internal error is returned.
    fn try_new(mean: ScalarValue, variance: ScalarValue) -> Result<Self> {
        let dt = mean.data_type();
        assert_eq_or_internal_err!(
            variance.data_type(),
            dt,
            "Mean and variance must have the same data type"
        );
        assert_or_internal_err!(
            !variance.is_null(),
            "Variance of a `GaussianDistribution` cannot be null"
        );
        let zero = ScalarValue::new_zero(&dt)?;
        // Only strictly negative values are rejected, so the message talks
        // about non-negativity rather than positivity.
        assert_or_internal_err!(
            !variance.lt(&zero),
            "Variance of a `GaussianDistribution` must be non-negative"
        );
        Ok(Self { mean, variance })
    }

    /// Returns the data type of the mean (which the variance shares).
    pub fn data_type(&self) -> DataType {
        self.mean.data_type()
    }

    /// Returns the mean of the distribution.
    pub fn mean(&self) -> &ScalarValue {
        &self.mean
    }

    /// Returns the variance of the distribution.
    pub fn variance(&self) -> &ScalarValue {
        &self.variance
    }

    /// Returns the median of the distribution, which is equal to its mean.
    pub fn median(&self) -> &ScalarValue {
        self.mean()
    }

    /// Returns the range of the distribution, which is the unbounded
    /// interval of the distribution's data type.
    pub fn range(&self) -> Result<Interval> {
        Interval::make_unbounded(&self.data_type())
    }
}
impl BernoulliDistribution {
    /// Validates the given success probability and builds the distribution.
    /// A null `p` (of any data type) denotes an unknown probability and is
    /// accepted as is. A non-null `p` must lie in `[0, 1]`; otherwise, an
    /// internal error is returned.
    fn try_new(p: ScalarValue) -> Result<Self> {
        if p.is_null() {
            return Ok(Self { p });
        }
        let dt = p.data_type();
        let zero = ScalarValue::new_zero(&dt)?;
        let one = ScalarValue::new_one(&dt)?;
        assert_or_internal_err!(
            p.ge(&zero) && p.le(&one),
            "Success probability of a `BernoulliDistribution` must be in [0, 1]"
        );
        Ok(Self { p })
    }

    /// Returns the data type of the success probability.
    pub fn data_type(&self) -> DataType {
        self.p.data_type()
    }

    /// Returns the success probability, which is null if unknown.
    pub fn p_value(&self) -> &ScalarValue {
        &self.p
    }

    /// Returns the mean, which is equal to the success probability.
    pub fn mean(&self) -> &ScalarValue {
        &self.p
    }

    /// Computes the median value of this distribution, which is `1` if
    /// `p > 0.5` and `0` otherwise. In case of an unknown success
    /// probability, the function returns a `NULL` `ScalarValue`.
    pub fn median(&self) -> Result<ScalarValue> {
        let dt = self.data_type();
        if self.p.is_null() {
            ScalarValue::try_from(&dt)
        } else {
            let one = ScalarValue::new_one(&dt)?;
            // `1 - p < p` is equivalent to `p > 0.5`.
            if one.sub_checked(&self.p)?.lt(&self.p) {
                ScalarValue::new_one(&dt)
            } else {
                ScalarValue::new_zero(&dt)
            }
        }
    }

    /// Computes the variance value of this distribution via the formula
    /// `p * (1 - p)`. In case of an unknown success probability, the function
    /// returns a `NULL` `ScalarValue`.
    pub fn variance(&self) -> Result<ScalarValue> {
        let dt = self.data_type();
        // Handle the unknown case up front (like `median` does). The
        // constructor accepts a null `p` without validating its data type,
        // so building a "one" value is not guaranteed to succeed here.
        if self.p.is_null() {
            return ScalarValue::try_from(&dt);
        }
        let one = ScalarValue::new_one(&dt)?;
        one.sub_checked(&self.p)?.mul_checked(&self.p)
    }

    /// Computes the range of this distribution: [`Interval::FALSE`] if `p`
    /// is `0`, [`Interval::TRUE`] if `p` is `1`, and
    /// [`Interval::TRUE_OR_FALSE`] otherwise (including the case of an
    /// unknown success probability).
    pub fn range(&self) -> Interval {
        let dt = self.data_type();
        // A null `p` bypasses the constructor's validation, so the zero/one
        // values may not be constructible for its data type. Treat such a
        // failure as "not equal" instead of unwrapping, which could panic.
        let p_equals =
            |candidate: Result<ScalarValue>| candidate.is_ok_and(|value| value == self.p);
        if p_equals(ScalarValue::new_zero(&dt)) {
            Interval::FALSE
        } else if p_equals(ScalarValue::new_one(&dt)) {
            Interval::TRUE
        } else {
            Interval::TRUE_OR_FALSE
        }
    }
}
impl GenericDistribution {
    /// Validates the given summary statistics and builds the distribution.
    /// Boolean ranges are rejected. The mean and the median must be either
    /// null or contained in the range, and the variance must be either null
    /// or non-negative; otherwise, an internal error is returned.
    fn try_new(
        mean: ScalarValue,
        median: ScalarValue,
        variance: ScalarValue,
        range: Interval,
    ) -> Result<Self> {
        assert_ne_or_internal_err!(
            range.data_type(),
            DataType::Boolean,
            "Construction of a boolean `Generic` distribution is prohibited, create a `Bernoulli` distribution instead."
        );
        // An unknown (null) location estimate is always acceptable; a known
        // one has to fall inside the range.
        let within_range = |location: &ScalarValue| -> Result<bool> {
            if location.is_null() {
                return Ok(true);
            }
            range.contains_value(location)
        };
        let locations_valid = within_range(&mean)? && within_range(&median)?;
        // Likewise, an unknown variance is acceptable; a known one must not
        // be below zero.
        let variance_non_negative = variance.is_null() || {
            let floor = ScalarValue::new_zero(&variance.data_type())?;
            !variance.lt(&floor)
        };
        assert_or_internal_err!(
            locations_valid && variance_non_negative,
            "Tried to construct an invalid `GenericDistribution` instance"
        );
        Ok(Self {
            mean,
            median,
            variance,
            range,
        })
    }

    /// Returns the data type of the mean.
    pub fn data_type(&self) -> DataType {
        self.mean().data_type()
    }

    /// Returns the mean, which is null if unknown.
    pub fn mean(&self) -> &ScalarValue {
        &self.mean
    }

    /// Returns the median, which is null if unknown.
    pub fn median(&self) -> &ScalarValue {
        &self.median
    }

    /// Returns the variance, which is null if unknown.
    pub fn variance(&self) -> &ScalarValue {
        &self.variance
    }

    /// Returns the range of the distribution.
    pub fn range(&self) -> &Interval {
        &self.range
    }
}
/// Combines two Bernoulli distributions through a logical operator and
/// returns the Bernoulli distribution of the result. Only `AND` and `OR` are
/// supported; any other operator yields a "not implemented" error.
///
/// When both success probabilities are known, the result is computed as
/// `p * q` for `AND` and `p + q - p * q` for `OR`. When exactly one of them
/// is known and it already decides the outcome (`0` for `AND`, `1` for `OR`),
/// that operand is returned as is. In all remaining cases, the resulting
/// success probability is unknown (null).
pub fn combine_bernoullis(
    op: &Operator,
    left: &BernoulliDistribution,
    right: &BernoulliDistribution,
) -> Result<BernoulliDistribution> {
    let (lhs_p, rhs_p) = (left.p_value(), right.p_value());
    let (lhs_known, rhs_known) = (!lhs_p.is_null(), !rhs_p.is_null());
    // Builds a distribution with an unknown success probability, typed
    // according to the operands:
    let unknown = || -> Result<BernoulliDistribution> {
        let dt = Distribution::target_type(&[lhs_p, rhs_p])?;
        BernoulliDistribution::try_new(ScalarValue::try_from(&dt)?)
    };
    match op {
        Operator::And => {
            if lhs_known && rhs_known {
                BernoulliDistribution::try_new(lhs_p.mul_checked(rhs_p)?)
            } else if lhs_known && *lhs_p == ScalarValue::new_zero(&lhs_p.data_type())? {
                // A certainly-false operand decides the conjunction.
                Ok(left.clone())
            } else if rhs_known && *rhs_p == ScalarValue::new_zero(&rhs_p.data_type())? {
                Ok(right.clone())
            } else {
                unknown()
            }
        }
        Operator::Or => {
            if lhs_known && rhs_known {
                let sum = lhs_p.add_checked(rhs_p)?;
                let overlap = lhs_p.mul_checked(rhs_p)?;
                BernoulliDistribution::try_new(sum.sub_checked(overlap)?)
            } else if lhs_known && *lhs_p == ScalarValue::new_one(&lhs_p.data_type())? {
                // A certainly-true operand decides the disjunction.
                Ok(left.clone())
            } else if rhs_known && *rhs_p == ScalarValue::new_one(&rhs_p.data_type())? {
                Ok(right.clone())
            } else {
                unknown()
            }
        }
        _ => {
            not_impl_err!("Statistical evaluation only supports AND and OR operators")
        }
    }
}
/// Applies the given operation to the given Gaussian distributions. Only
/// addition and subtraction are handled: the means are added/subtracted and
/// the variances are added. For any other operator, the function returns
/// `None`. For details, see:
///
/// <https://en.wikipedia.org/wiki/Sum_of_normally_distributed_random_variables>
pub fn combine_gaussians(
    op: &Operator,
    left: &GaussianDistribution,
    right: &GaussianDistribution,
) -> Result<Option<GaussianDistribution>> {
    let combined_mean = match op {
        Operator::Plus => left.mean().add_checked(right.mean())?,
        Operator::Minus => left.mean().sub_checked(right.mean())?,
        _ => return Ok(None),
    };
    // The variances add up for both the sum and the difference:
    let combined_variance = left.variance().add_checked(right.variance())?;
    GaussianDistribution::try_new(combined_mean, combined_variance).map(Some)
}
/// Creates a new `Bernoulli` distribution by computing the resulting probability.
/// Expects `op` to be a comparison operator, with `left` and `right` having
/// numeric distributions. The resulting distribution has the `Float64` data
/// type.
///
/// For (in)equality comparisons between two [`Uniform`] distributions whose
/// ranges have known cardinalities, the success probability is computed
/// exactly. In all other cases, the function falls back to comparing the
/// ranges of the operands: the probability is `0` or `1` if the ranges decide
/// the comparison, and unknown (null) otherwise. An internal error is
/// returned if the range comparison does not produce a boolean interval.
pub fn create_bernoulli_from_comparison(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<Distribution> {
    match (left, right) {
        (Uniform(left), Uniform(right)) => {
            match op {
                Operator::Eq | Operator::NotEq => {
                    let (li, ri) = (left.range(), right.range());
                    if let Some(intersection) = li.intersect(ri)? {
                        // If the ranges are not disjoint, calculate the probability
                        // of equality using cardinalities:
                        if let (Some(lc), Some(rc), Some(ic)) = (
                            li.cardinality(),
                            ri.cardinality(),
                            intersection.cardinality(),
                        ) {
                            // Avoid overflow by widening the type temporarily:
                            let pairs = ((lc as u128) * (rc as u128)) as f64;
                            let p = (ic as f64).div_checked(pairs)?;
                            // Alternative approach that may be more stable:
                            // let p = (ic as f64)
                            //     .div_checked(lc as f64)?
                            //     .div_checked(rc as f64)?;
                            let mut p_value = ScalarValue::from(p);
                            if op == &Operator::NotEq {
                                let one = ScalarValue::from(1.0);
                                p_value = alter_fp_rounding_mode::<false, _>(
                                    &one,
                                    &p_value,
                                    |lhs, rhs| lhs.sub_checked(rhs),
                                )?;
                                // NOTE(review): the subtraction is rounded
                                // downwards, so for `p == 1` (identical
                                // singleton ranges) the result can presumably
                                // land marginally below zero, which
                                // `new_bernoulli` would reject. Clamp to keep
                                // the probability inside `[0, 1]`.
                                let zero = ScalarValue::from(0.0);
                                if p_value.lt(&zero) {
                                    p_value = zero;
                                }
                            };
                            return Distribution::new_bernoulli(p_value);
                        }
                    } else if op == &Operator::Eq {
                        // If the ranges are disjoint, probability of equality is 0.
                        return Distribution::new_bernoulli(ScalarValue::from(0.0));
                    } else {
                        // If the ranges are disjoint, probability of not-equality is 1.
                        return Distribution::new_bernoulli(ScalarValue::from(1.0));
                    }
                }
                Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => {
                    // TODO: We can handle inequality operators and calculate a
                    // `p` value instead of falling back to an unknown Bernoulli
                    // distribution. Note that the strict and non-strict inequalities
                    // may require slightly different logic in case of real vs.
                    // integral data types.
                }
                _ => {}
            }
        }
        (Gaussian(_), Gaussian(_)) => {
            // TODO: We can handle Gaussian comparisons and calculate a `p` value
            // instead of falling back to an unknown Bernoulli distribution.
        }
        _ => {}
    }
    // Fallback: let the ranges of the operands decide the outcome, if they can.
    let (li, ri) = (left.range()?, right.range()?);
    let range_evaluation = apply_operator(op, &li, &ri)?;
    if range_evaluation.eq(&Interval::FALSE) {
        Distribution::new_bernoulli(ScalarValue::from(0.0))
    } else if range_evaluation.eq(&Interval::TRUE) {
        Distribution::new_bernoulli(ScalarValue::from(1.0))
    } else if range_evaluation.eq(&Interval::TRUE_OR_FALSE) {
        Distribution::new_bernoulli(ScalarValue::try_from(&DataType::Float64)?)
    } else {
        internal_err!("This function must be called with a comparison operator")
    }
}
/// Builds the [`Generic`] distribution describing the result of applying the
/// given binary operation to two unknown quantities, each represented by its
/// [`Distribution`]. The mean, median and variance are computed where
/// possible, and the range is obtained by applying the operator to the
/// operand ranges.
pub fn new_generic_from_binary_op(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<Distribution> {
    let mean = compute_mean(op, left, right)?;
    let median = compute_median(op, left, right)?;
    let variance = compute_variance(op, left, right)?;
    let (left_range, right_range) = (left.range()?, right.range()?);
    let range = apply_operator(op, &left_range, &right_range)?;
    Distribution::new_generic(mean, median, variance, range)
}
/// Computes the mean of the result of the given binary operation on two
/// unknown quantities represented by their [`Distribution`] objects. The mean
/// is calculable for addition, subtraction and multiplication; for all other
/// operators, a null value of the coerced operand type is returned.
pub fn compute_mean(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<ScalarValue> {
    let lhs = left.mean()?;
    let rhs = right.mean()?;
    match op {
        Operator::Plus => lhs.add_checked(rhs),
        Operator::Minus => lhs.sub_checked(rhs),
        // Note the independence assumption below:
        Operator::Multiply => lhs.mul_checked(rhs),
        // TODO: We can calculate the mean for division when we support reciprocals,
        // or know the distributions of the operands. For details, see:
        //
        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
        //
        // Fall back to an unknown mean value for division and all other cases:
        _ => {
            let target_type = Distribution::target_type(&[&lhs, &rhs])?;
            ScalarValue::try_from(target_type)
        }
    }
}
/// Computes the median of the result of the given binary operation on two
/// unknown quantities represented by their [`Distribution`] objects. Currently,
/// the median is calculable only for addition and subtraction operations on:
/// - [`Uniform`] and [`Uniform`] distributions, and
/// - [`Gaussian`] and [`Gaussian`] distributions.
///
/// In all other cases, a null value of the coerced operand type is returned.
pub fn compute_median(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<ScalarValue> {
    if let (Uniform(lu), Uniform(ru)) = (left, right) {
        let (lhs, rhs) = (lu.median()?, ru.median()?);
        // Under the independence assumption, the result is a symmetric
        // triangular distribution, so we can simply add/subtract the
        // median values:
        match op {
            Operator::Plus => return lhs.add_checked(rhs),
            Operator::Minus => return lhs.sub_checked(rhs),
            // Fall back to an unknown median value for other operators:
            _ => {}
        }
    } else if let (Gaussian(lg), Gaussian(rg)) = (left, right) {
        // Under the independence assumption, the result is another Gaussian
        // distribution, so we can simply add/subtract the median values
        // (which coincide with the means):
        match op {
            Operator::Plus => return lg.mean().add_checked(rg.mean()),
            Operator::Minus => return lg.mean().sub_checked(rg.mean()),
            // Fall back to an unknown median value for other operators:
            _ => {}
        }
    }
    // Fall back to an unknown median value of the appropriate type:
    let lhs = left.median()?;
    let rhs = right.median()?;
    let target_type = Distribution::target_type(&[&lhs, &rhs])?;
    ScalarValue::try_from(target_type)
}
/// Computes the variance of the result of the given binary operation on two
/// unknown quantities represented by their [`Distribution`] objects. The
/// variance is calculable for addition, subtraction and multiplication (all
/// under an independence assumption); for all other operators, a null value
/// of the coerced operand type is returned.
pub fn compute_variance(
    op: &Operator,
    left: &Distribution,
    right: &Distribution,
) -> Result<ScalarValue> {
    let lhs_var = left.variance()?;
    let rhs_var = right.variance()?;
    match op {
        // Note the independence assumption: variances add up for both the
        // sum and the difference.
        Operator::Plus | Operator::Minus => lhs_var.add_checked(rhs_var),
        // Note the independence assumption below:
        Operator::Multiply => {
            // For more details, along with an explanation of the formula below, see:
            //
            // <https://en.wikipedia.org/wiki/Distribution_of_the_product_of_two_random_variables>
            let (lhs_mean, rhs_mean) = (left.mean()?, right.mean()?);
            let lhs_mean_sq = lhs_mean.mul_checked(&lhs_mean)?;
            let rhs_mean_sq = rhs_mean.mul_checked(&rhs_mean)?;
            // Second moments, i.e. `variance + mean ^ 2`:
            let lhs_second_moment = lhs_var.add_checked(&lhs_mean_sq)?;
            let rhs_second_moment = rhs_var.add_checked(&rhs_mean_sq)?;
            let squared_mean_product = lhs_mean_sq.mul_checked(rhs_mean_sq)?;
            let moment_product = lhs_second_moment.mul_checked(rhs_second_moment)?;
            moment_product.sub_checked(squared_mean_product)
        }
        // TODO: We can calculate the variance for division when we support reciprocals,
        // or know the distributions of the operands. For details, see:
        //
        // <https://en.wikipedia.org/wiki/Algebra_of_random_variables>
        // <https://stats.stackexchange.com/questions/185683/distribution-of-ratio-between-two-independent-uniform-random-variables>
        //
        // Fall back to an unknown variance value for division and all other cases:
        _ => {
            let target_type = Distribution::target_type(&[&lhs_var, &rhs_var])?;
            ScalarValue::try_from(target_type)
        }
    }
}
#[cfg(test)]
mod tests {
use super::{
BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution,
combine_bernoullis, combine_gaussians, compute_mean, compute_median,
compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op,
};
use crate::interval_arithmetic::{Interval, apply_operator};
use crate::operator::Operator;
use arrow::datatypes::DataType;
use datafusion_common::{HashSet, Result, ScalarValue};
#[test]
fn uniform_dist_is_valid_test() -> Result<()> {
assert_eq!(
Distribution::new_uniform(Interval::make_zero(&DataType::Int8)?)?,
Distribution::Uniform(UniformDistribution {
interval: Interval::make_zero(&DataType::Int8)?,
})
);
assert!(Distribution::new_uniform(Interval::TRUE_OR_FALSE).is_err());
Ok(())
}
#[test]
fn exponential_dist_is_valid_test() {
// This array collects test cases of the form (distribution, validity).
let exponentials = vec![
(
Distribution::new_exponential(ScalarValue::Null, ScalarValue::Null, true),
false,
),
(
Distribution::new_exponential(
ScalarValue::from(0_f32),
ScalarValue::from(1_f32),
true,
),
false,
),
(
Distribution::new_exponential(
ScalarValue::from(100_f32),
ScalarValue::from(1_f32),
true,
),
true,
),
(
Distribution::new_exponential(
ScalarValue::from(-100_f32),
ScalarValue::from(1_f32),
true,
),
false,
),
];
for case in exponentials {
assert_eq!(case.0.is_ok(), case.1);
}
}
#[test]
fn gaussian_dist_is_valid_test() {
// This array collects test cases of the form (distribution, validity).
let gaussians = vec![
(
Distribution::new_gaussian(ScalarValue::Null, ScalarValue::Null),
false,
),
(
Distribution::new_gaussian(
ScalarValue::from(0_f32),
ScalarValue::from(0_f32),
),
true,
),
(
Distribution::new_gaussian(
ScalarValue::from(0_f32),
ScalarValue::from(0.5_f32),
),
true,
),
(
Distribution::new_gaussian(
ScalarValue::from(0_f32),
ScalarValue::from(-0.5_f32),
),
false,
),
];
for case in gaussians {
assert_eq!(case.0.is_ok(), case.1);
}
}
#[test]
fn bernoulli_dist_is_valid_test() {
// This array collects test cases of the form (distribution, validity).
let bernoullis = vec![
(Distribution::new_bernoulli(ScalarValue::Null), true),
(Distribution::new_bernoulli(ScalarValue::from(0.)), true),
(Distribution::new_bernoulli(ScalarValue::from(0.25)), true),
(Distribution::new_bernoulli(ScalarValue::from(1.)), true),
(Distribution::new_bernoulli(ScalarValue::from(11.)), false),
(Distribution::new_bernoulli(ScalarValue::from(-11.)), false),
(Distribution::new_bernoulli(ScalarValue::from(0_i64)), true),
(Distribution::new_bernoulli(ScalarValue::from(1_i64)), true),
(
Distribution::new_bernoulli(ScalarValue::from(11_i64)),
false,
),
(
Distribution::new_bernoulli(ScalarValue::from(-11_i64)),
false,
),
];
for case in bernoullis {
assert_eq!(case.0.is_ok(), case.1);
}
}
#[test]
fn generic_dist_is_valid_test() -> Result<()> {
// This array collects test cases of the form (distribution, validity).
let generic_dists = vec![
// Using a boolean range to construct a Generic distribution is prohibited.
(
Distribution::new_generic(
ScalarValue::Null,
ScalarValue::Null,
ScalarValue::Null,
Interval::TRUE_OR_FALSE,
),
false,
),
(
Distribution::new_generic(
ScalarValue::Null,
ScalarValue::Null,
ScalarValue::Null,
Interval::make_zero(&DataType::Float32)?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::from(0_f32),
ScalarValue::Float32(None),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::Float64(None),
ScalarValue::from(0.),
ScalarValue::Float64(None),
Interval::make_zero(&DataType::Float32)?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::from(-10_f32),
ScalarValue::Float32(None),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
false,
),
(
Distribution::new_generic(
ScalarValue::Float32(None),
ScalarValue::from(10_f32),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
false,
),
(
Distribution::new_generic(
ScalarValue::Null,
ScalarValue::Null,
ScalarValue::Null,
Interval::make_zero(&DataType::Float32)?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::from(0),
ScalarValue::from(0),
ScalarValue::Int32(None),
Interval::make_zero(&DataType::Int32)?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::from(0_f32),
ScalarValue::from(0_f32),
ScalarValue::Float32(None),
Interval::make_zero(&DataType::Float32)?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::from(50.),
ScalarValue::from(50.),
ScalarValue::Float64(None),
Interval::make(Some(0.), Some(100.))?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::from(50.),
ScalarValue::from(50.),
ScalarValue::Float64(None),
Interval::make(Some(-100.), Some(0.))?,
),
false,
),
(
Distribution::new_generic(
ScalarValue::Float64(None),
ScalarValue::Float64(None),
ScalarValue::from(1.),
Interval::make_zero(&DataType::Float64)?,
),
true,
),
(
Distribution::new_generic(
ScalarValue::Float64(None),
ScalarValue::Float64(None),
ScalarValue::from(-1.),
Interval::make_zero(&DataType::Float64)?,
),
false,
),
];
for case in generic_dists {
assert_eq!(case.0.is_ok(), case.1, "{:?}", case.0);
}
Ok(())
}
/// Verifies that `Distribution::mean` returns the expected value for
/// uniform, exponential, Gaussian, Bernoulli and generic distributions.
#[test]
fn mean_extraction_test() -> Result<()> {
    // Every entry pairs a (fallible) distribution constructor call with the
    // mean value that distribution is expected to report.
    let cases = [
        // Uniform distributions over zero-width and integer/float ranges.
        (
            Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
            ScalarValue::from(0_i64),
        ),
        (
            Distribution::new_uniform(Interval::make_zero(&DataType::Float64)?),
            ScalarValue::from(0.),
        ),
        (
            Distribution::new_uniform(Interval::make(Some(1), Some(100))?),
            ScalarValue::from(50),
        ),
        (
            Distribution::new_uniform(Interval::make(Some(-100), Some(-1))?),
            ScalarValue::from(-50),
        ),
        (
            Distribution::new_uniform(Interval::make(Some(-100), Some(100))?),
            ScalarValue::from(0),
        ),
        // Exponential distributions.
        (
            Distribution::new_exponential(
                ScalarValue::from(2.),
                ScalarValue::from(0.),
                true,
            ),
            ScalarValue::from(0.5),
        ),
        (
            Distribution::new_exponential(
                ScalarValue::from(2.),
                ScalarValue::from(1.),
                true,
            ),
            ScalarValue::from(1.5),
        ),
        // Gaussian distributions.
        (
            Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
            ScalarValue::from(0.),
        ),
        (
            Distribution::new_gaussian(
                ScalarValue::from(-2.),
                ScalarValue::from(0.5),
            ),
            ScalarValue::from(-2.),
        ),
        // Bernoulli distribution.
        (
            Distribution::new_bernoulli(ScalarValue::from(0.5)),
            ScalarValue::from(0.5),
        ),
        // Generic distribution with an explicitly supplied mean.
        (
            Distribution::new_generic(
                ScalarValue::from(42.),
                ScalarValue::from(42.),
                ScalarValue::Float64(None),
                Interval::make(Some(25.), Some(50.))?,
            ),
            ScalarValue::from(42.),
        ),
    ];
    for (dist, expected_mean) in cases {
        let actual_mean = dist?.mean()?;
        assert_eq!(actual_mean, expected_mean);
    }
    Ok(())
}
/// Verifies that `Distribution::median` returns the expected value for
/// uniform, exponential, Gaussian, Bernoulli and generic distributions.
#[test]
fn median_extraction_test() -> Result<()> {
    // Every entry pairs a (fallible) distribution constructor call with the
    // median value that distribution is expected to report. Each case appears
    // exactly once; a verbatim duplicate of the Gaussian case was dropped.
    let cases = [
        (
            Distribution::new_uniform(Interval::make_zero(&DataType::Int64)?),
            ScalarValue::from(0_i64),
        ),
        (
            Distribution::new_uniform(Interval::make(Some(25.), Some(75.))?),
            ScalarValue::from(50.),
        ),
        (
            // The first argument is ln(2), for which a median of exactly 1 is
            // expected.
            Distribution::new_exponential(
                ScalarValue::from(2_f64.ln()),
                ScalarValue::from(0.),
                true,
            ),
            ScalarValue::from(1.),
        ),
        (
            Distribution::new_gaussian(ScalarValue::from(2.), ScalarValue::from(1.)),
            ScalarValue::from(2.),
        ),
        // Bernoulli: the expected median is 0 for p = 0.25 and 1 for p = 0.75.
        (
            Distribution::new_bernoulli(ScalarValue::from(0.25)),
            ScalarValue::from(0.),
        ),
        (
            Distribution::new_bernoulli(ScalarValue::from(0.75)),
            ScalarValue::from(1.),
        ),
        // Generic distribution with an explicitly supplied median.
        (
            Distribution::new_generic(
                ScalarValue::from(12.),
                ScalarValue::from(12.),
                ScalarValue::Float64(None),
                Interval::make(Some(0.), Some(25.))?,
            ),
            ScalarValue::from(12.),
        ),
    ];
    for (dist, expected_median) in cases {
        let actual_median = dist?.median()?;
        assert_eq!(actual_median, expected_median);
    }
    Ok(())
}
/// Verifies that `Distribution::variance` returns the expected value for
/// uniform, exponential, Gaussian, Bernoulli and generic distributions.
#[test]
fn variance_extraction_test() -> Result<()> {
    // Every entry pairs a (fallible) distribution constructor call with the
    // variance that distribution is expected to report.
    let cases = [
        (
            Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?),
            ScalarValue::from(12.),
        ),
        (
            Distribution::new_exponential(
                ScalarValue::from(10.),
                ScalarValue::from(0.),
                true,
            ),
            ScalarValue::from(0.01),
        ),
        (
            Distribution::new_gaussian(ScalarValue::from(0.), ScalarValue::from(1.)),
            ScalarValue::from(1.),
        ),
        (
            Distribution::new_bernoulli(ScalarValue::from(0.5)),
            ScalarValue::from(0.25),
        ),
        // Generic distribution where only the variance is supplied.
        (
            Distribution::new_generic(
                ScalarValue::Float64(None),
                ScalarValue::Float64(None),
                ScalarValue::from(0.02),
                Interval::make_zero(&DataType::Float64)?,
            ),
            ScalarValue::from(0.02),
        ),
    ];
    for (dist, expected_variance) in cases {
        let actual_variance = dist?.variance()?;
        assert_eq!(actual_variance, expected_variance);
    }
    Ok(())
}
/// Checks `compute_mean` and `compute_median` for the sum and difference of
/// two Gaussian distributions.
#[test]
fn test_calculate_generic_properties_gauss_gauss() -> Result<()> {
    let dist_a =
        Distribution::new_gaussian(ScalarValue::from(10.), ScalarValue::from(0.0))?;
    let dist_b =
        Distribution::new_gaussian(ScalarValue::from(20.), ScalarValue::from(0.0))?;

    // Compute every statistic up front so that any computation error surfaces
    // before the first assertion runs.
    let mean_of_sum = compute_mean(&Operator::Plus, &dist_a, &dist_b)?;
    let mean_of_diff = compute_mean(&Operator::Minus, &dist_a, &dist_b)?;
    let median_of_sum = compute_median(&Operator::Plus, &dist_a, &dist_b)?;
    let median_of_diff = compute_median(&Operator::Minus, &dist_a, &dist_b)?;

    // Mean:
    assert_eq!(mean_of_sum, ScalarValue::from(30.));
    assert_eq!(mean_of_diff, ScalarValue::from(-10.));
    // Median:
    assert_eq!(median_of_sum, ScalarValue::from(30.));
    assert_eq!(median_of_diff, ScalarValue::from(-10.));
    Ok(())
}
/// Checks `combine_bernoullis` under `Operator::And` for known and unknown
/// (null) success probabilities.
#[test]
fn test_combine_bernoullis_and_op() -> Result<()> {
    let op = Operator::And;
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.5))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
    let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
    let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

    // Both probabilities known: the expected result is their product.
    let both_known = combine_bernoullis(&op, &left, &right)?;
    assert_eq!(both_known.p_value(), &ScalarValue::from(0.5 * 0.4));

    // Exactly one side unknown: the expected result is a typed null.
    let left_unknown = combine_bernoullis(&op, &left_null, &right)?;
    assert_eq!(left_unknown.p_value(), &ScalarValue::Float64(None));
    let right_unknown = combine_bernoullis(&op, &left, &right_null)?;
    assert_eq!(right_unknown.p_value(), &ScalarValue::Float64(None));

    // Both sides unknown: the expected result is an untyped null. The right
    // operand is `right_null` so that both operand slots are exercised.
    let both_unknown = combine_bernoullis(&op, &left_null, &right_null)?;
    assert_eq!(both_unknown.p_value(), &ScalarValue::Null);
    Ok(())
}
/// Checks `combine_bernoullis` under `Operator::Or` for known and unknown
/// (null) success probabilities.
#[test]
fn test_combine_bernoullis_or_op() -> Result<()> {
    let op = Operator::Or;
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;
    let left_null = BernoulliDistribution::try_new(ScalarValue::Null)?;
    let right_null = BernoulliDistribution::try_new(ScalarValue::Null)?;

    // Both probabilities known: the expected result is p + q - p * q.
    let both_known = combine_bernoullis(&op, &left, &right)?;
    assert_eq!(
        both_known.p_value(),
        &ScalarValue::from(0.6 + 0.4 - (0.6 * 0.4))
    );

    // Exactly one side unknown: the expected result is a typed null.
    let left_unknown = combine_bernoullis(&op, &left_null, &right)?;
    assert_eq!(left_unknown.p_value(), &ScalarValue::Float64(None));
    let right_unknown = combine_bernoullis(&op, &left, &right_null)?;
    assert_eq!(right_unknown.p_value(), &ScalarValue::Float64(None));

    // Both sides unknown: the expected result is an untyped null. The right
    // operand is `right_null` so that both operand slots are exercised.
    let both_unknown = combine_bernoullis(&op, &left_null, &right_null)?;
    assert_eq!(both_unknown.p_value(), &ScalarValue::Null);
    Ok(())
}
/// Checks that `combine_bernoullis` returns an error for every operator in
/// `operator_set()` other than `And` and `Or`.
#[test]
fn test_combine_bernoullis_unsupported_ops() -> Result<()> {
    let left = BernoulliDistribution::try_new(ScalarValue::from(0.6))?;
    let right = BernoulliDistribution::try_new(ScalarValue::from(0.4))?;

    // `And` and `Or` are covered by their own tests; everything else in the
    // operator set is expected to be rejected.
    let unsupported_ops = operator_set()
        .into_iter()
        .filter(|candidate| !matches!(candidate, Operator::And | Operator::Or));
    for op in unsupported_ops {
        let outcome = combine_bernoullis(&op, &left, &right);
        assert!(
            outcome.is_err(),
            "Operator {op} should not be supported for Bernoulli distributions"
        );
    }
    Ok(())
}
/// Checks that adding two Gaussian distributions via `combine_gaussians`
/// yields a Gaussian whose mean and variance are the sums of the inputs'.
#[test]
fn test_combine_gaussians_addition() -> Result<()> {
    let lhs = GaussianDistribution::try_new(
        ScalarValue::from(3.0),
        ScalarValue::from(2.0),
    )?;
    let rhs = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    let sum = combine_gaussians(&Operator::Plus, &lhs, &rhs)?.unwrap();

    // Means add: 3.0 + 4.0
    assert_eq!(sum.mean(), &ScalarValue::from(7.0));
    // Variances add: 2.0 + 1.0
    assert_eq!(sum.variance(), &ScalarValue::from(3.0));
    Ok(())
}
/// Checks that subtracting two Gaussian distributions via `combine_gaussians`
/// yields a Gaussian whose mean is the difference of the inputs' means and
/// whose variance is the sum of their variances.
#[test]
fn test_combine_gaussians_subtraction() -> Result<()> {
    let lhs = GaussianDistribution::try_new(
        ScalarValue::from(7.0),
        ScalarValue::from(2.0),
    )?;
    let rhs = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    let difference = combine_gaussians(&Operator::Minus, &lhs, &rhs)?.unwrap();

    // Means subtract: 7.0 - 4.0
    assert_eq!(difference.mean(), &ScalarValue::from(3.0));
    // Variances still add: 2.0 + 1.0
    assert_eq!(difference.variance(), &ScalarValue::from(3.0));
    Ok(())
}
/// Checks that `combine_gaussians` returns `None` for every operator in
/// `operator_set()` other than `Plus` and `Minus`.
#[test]
fn test_combine_gaussians_unsupported_ops() -> Result<()> {
    let left = GaussianDistribution::try_new(
        ScalarValue::from(7.0),
        ScalarValue::from(2.0),
    )?;
    let right = GaussianDistribution::try_new(
        ScalarValue::from(4.0),
        ScalarValue::from(1.0),
    )?;

    // `Plus` and `Minus` are covered by their own tests; everything else in
    // the operator set is expected to produce no Gaussian result.
    let unsupported_ops = operator_set()
        .into_iter()
        .filter(|candidate| !matches!(candidate, Operator::Plus | Operator::Minus));
    for op in unsupported_ops {
        let combined = combine_gaussians(&op, &left, &right)?;
        assert!(
            combined.is_none(),
            "Operator {op} should not be supported for Gaussian distributions"
        );
    }
    Ok(())
}
// Expected test results were calculated in Wolfram Mathematica, by using:
//
// *METHOD_NAME*[TransformedDistribution[
// x *op* y,
// {x ~ *DISTRIBUTION_X*[..], y ~ *DISTRIBUTION_Y*[..]}
// ]]
/// Checks `compute_mean`, `compute_median` and `compute_variance` for
/// arithmetic combinations of two uniform distributions over `[0, 12]` and
/// `[12, 36]`.
#[test]
fn test_calculate_generic_properties_uniform_uniform() -> Result<()> {
    let dist_a = Distribution::new_uniform(Interval::make(Some(0.), Some(12.))?)?;
    let dist_b = Distribution::new_uniform(Interval::make(Some(12.), Some(36.))?)?;

    // Compute every statistic up front so that any computation error surfaces
    // before the first assertion runs.
    let mean_of_sum = compute_mean(&Operator::Plus, &dist_a, &dist_b)?;
    let mean_of_diff = compute_mean(&Operator::Minus, &dist_a, &dist_b)?;
    let mean_of_product = compute_mean(&Operator::Multiply, &dist_a, &dist_b)?;
    let median_of_sum = compute_median(&Operator::Plus, &dist_a, &dist_b)?;
    let median_of_diff = compute_median(&Operator::Minus, &dist_a, &dist_b)?;
    let variance_of_sum = compute_variance(&Operator::Plus, &dist_a, &dist_b)?;
    let variance_of_diff = compute_variance(&Operator::Minus, &dist_a, &dist_b)?;
    let variance_of_product =
        compute_variance(&Operator::Multiply, &dist_a, &dist_b)?;

    // Mean:
    assert_eq!(mean_of_sum, ScalarValue::from(30.));
    assert_eq!(mean_of_diff, ScalarValue::from(-18.));
    assert_eq!(mean_of_product, ScalarValue::from(144.));
    // Median:
    assert_eq!(median_of_sum, ScalarValue::from(30.));
    assert_eq!(median_of_diff, ScalarValue::from(-18.));
    // Variance:
    assert_eq!(variance_of_sum, ScalarValue::from(60.));
    assert_eq!(variance_of_diff, ScalarValue::from(60.));
    assert_eq!(variance_of_product, ScalarValue::from(9216.));
    Ok(())
}
/// Exercises range propagation for every `Uniform`/`Generic` operand pairing
/// (`Uniform`-`Uniform`, `Generic`-`Uniform`, `Uniform`-`Generic`,
/// `Generic`-`Generic`), where both operands always carry a range.
///
/// For each pairing, the range of the resulting distribution must equal the
/// result of applying the operator directly to the operand intervals.
#[test]
fn test_compute_range_where_present() -> Result<()> {
    use super::Operator::{
        Divide, Eq, Gt, GtEq, Lt, LtEq, Minus, Multiply, NotEq, Plus,
    };

    let a = Interval::make(Some(0.), Some(12.0))?;
    let b = Interval::make(Some(0.), Some(12.0))?;
    let mean = ScalarValue::from(6.0);

    // Build each flavor of operand once; the pairings below reuse clones.
    let uniform_a = Distribution::new_uniform(a.clone())?;
    let uniform_b = Distribution::new_uniform(b.clone())?;
    let generic_a = Distribution::new_generic(
        mean.clone(),
        mean.clone(),
        ScalarValue::Float64(None),
        a.clone(),
    )?;
    let generic_b = Distribution::new_generic(
        mean.clone(),
        mean.clone(),
        ScalarValue::Float64(None),
        b.clone(),
    )?;

    let operand_pairs = [
        (uniform_a.clone(), uniform_b.clone()),
        (generic_a.clone(), uniform_b),
        (uniform_a, generic_b.clone()),
        (generic_a, generic_b),
    ];

    for (dist_a, dist_b) in operand_pairs {
        // Arithmetic operators go through `new_generic_from_binary_op`.
        for op in [Plus, Minus, Multiply, Divide] {
            let combined = new_generic_from_binary_op(&op, &dist_a, &dist_b)?;
            assert_eq!(
                combined.range()?,
                apply_operator(&op, &a, &b)?,
                "Failed for {dist_a:?} {op} {dist_b:?}"
            );
        }
        // Comparison operators go through `create_bernoulli_from_comparison`.
        for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] {
            let compared = create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?;
            assert_eq!(
                compared.range()?,
                apply_operator(&op, &a, &b)?,
                "Failed for {dist_a:?} {op} {dist_b:?}"
            );
        }
    }
    Ok(())
}
/// Returns the set of operators listed below.
///
/// The "unsupported operator" tests in this module start from this set and
/// exclude the operators a given combination routine supports, expecting
/// every remaining operator to be rejected.
fn operator_set() -> HashSet<Operator> {
    use super::Operator::*;

    [
        // Logical operators.
        And,
        Or,
        // Comparison operators.
        Eq,
        NotEq,
        Gt,
        GtEq,
        Lt,
        LtEq,
        // Arithmetic operators.
        Plus,
        Minus,
        Multiply,
        Divide,
        Modulo,
        // Distinctness checks.
        IsDistinctFrom,
        IsNotDistinctFrom,
        // Regular expression and `LIKE`-style matching.
        RegexMatch,
        RegexIMatch,
        RegexNotMatch,
        RegexNotIMatch,
        LikeMatch,
        ILikeMatch,
        NotLikeMatch,
        NotILikeMatch,
        // Bitwise operators.
        BitwiseAnd,
        BitwiseOr,
        BitwiseXor,
        BitwiseShiftRight,
        BitwiseShiftLeft,
        // Remaining operators.
        StringConcat,
        AtArrow,
        ArrowAt,
    ]
    .into_iter()
    .collect()
}
}