Skip to content

Commit

Permalink
Migrate documentation for regr* aggregate functions to code (#12871)
Browse files Browse the repository at this point in the history
* Migrate documentation for regr* functions to code

* Fix double expression

* Fix logical conflict
  • Loading branch information
alamb authored Oct 22, 2024
1 parent ef1365a commit 227908f
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 205 deletions.
187 changes: 153 additions & 34 deletions datafusion/functions-aggregate/src/regr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

//! Defines physical expressions that can evaluated at runtime during query execution
use std::any::Any;
use std::fmt::Debug;

use arrow::array::Float64Array;
use arrow::{
array::{ArrayRef, UInt64Array},
Expand All @@ -29,10 +26,17 @@ use arrow::{
};
use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::OnceLock;

macro_rules! make_regr_udaf_expr_and_func {
($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => {
Expand Down Expand Up @@ -76,23 +80,7 @@ impl Regr {
}
}

/*
#[derive(Debug)]
pub struct Regr {
name: String,
regr_type: RegrType,
expr_y: Arc<dyn PhysicalExpr>,
expr_x: Arc<dyn PhysicalExpr>,
}
impl Regr {
pub fn get_regr_type(&self) -> RegrType {
self.regr_type.clone()
}
}
*/

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Hash, Eq)]
#[allow(clippy::upper_case_acronyms)]
pub enum RegrType {
/// Variant for `regr_slope` aggregate expression
Expand Down Expand Up @@ -135,6 +123,148 @@ pub enum RegrType {
SXY,
}

impl RegrType {
/// return the documentation for the `RegrType`
fn documentation(&self) -> Option<&Documentation> {
get_regr_docs().get(self)
}
}

static DOCUMENTATION: OnceLock<HashMap<RegrType, Documentation>> = OnceLock::new();
fn get_regr_docs() -> &'static HashMap<RegrType, Documentation> {
DOCUMENTATION.get_or_init(|| {
let mut hash_map = HashMap::new();
hash_map.insert(
RegrType::Slope,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Returns the slope of the linear regression line for non-null pairs in aggregate columns. \
Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.",
)
.with_syntax_example("regr_slope(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::Intercept,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \
this function returns b.",
)
.with_syntax_example("regr_intercept(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::Count,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Counts the number of non-null paired data points.",
)
.with_syntax_example("regr_count(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::R2,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Computes the square of the correlation coefficient between the independent and dependent variables.",
)
.with_syntax_example("regr_r2(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::AvgX,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Computes the average of the independent variable (input) expression_x for the non-null paired data points.",
)
.with_syntax_example("regr_avgx(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::AvgY,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Computes the average of the dependent variable (output) expression_y for the non-null paired data points.",
)
.with_syntax_example("regr_avgy(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::SXX,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Computes the sum of squares of the independent variable.",
)
.with_syntax_example("regr_sxx(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::SYY,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Computes the sum of squares of the dependent variable.",
)
.with_syntax_example("regr_syy(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);

hash_map.insert(
RegrType::SXY,
Documentation::builder()
.with_doc_section(DOC_SECTION_STATISTICAL)
.with_description(
"Computes the sum of products of paired data points.",
)
.with_syntax_example("regr_sxy(expression_y, expression_x)")
.with_standard_argument("expression_y", Some("Dependent variable"))
.with_standard_argument("expression_x", Some("Independent variable"))
.build()
.unwrap()
);
hash_map
})
}

impl AggregateUDFImpl for Regr {
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -198,22 +328,11 @@ impl AggregateUDFImpl for Regr {
),
])
}
}

/*
impl PartialEq<dyn Any> for Regr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.expr_y.eq(&x.expr_y)
&& self.expr_x.eq(&x.expr_x)
})
.unwrap_or(false)
fn documentation(&self) -> Option<&Documentation> {
self.regr_type.documentation()
}
}
*/

/// `RegrAccumulator` is used to compute linear regression aggregate functions
/// by maintaining statistics needed to compute them in an online fashion.
Expand Down
Loading

0 comments on commit 227908f

Please sign in to comment.