Skip to content

Commit 2773d87

Browse files
migrate regr_* functions to UDAF
Ref: apache/datafusion#10898
1 parent 1e21ba7 commit 2773d87

File tree

2 files changed

+99
-18
lines changed

2 files changed

+99
-18
lines changed

python/datafusion/functions.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,50 +1376,50 @@ def regr_avgx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
13761376
13771377
Only non-null pairs of the inputs are evaluated.
13781378
"""
1379-
return Expr(f.regr_avgx[y.expr, x.expr], distinct)
1379+
return Expr(f.regr_avgx(y.expr, x.expr, distinct))
13801380

13811381

13821382
def regr_avgy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
13831383
"""Computes the average of the dependent variable ``y``.
13841384
13851385
Only non-null pairs of the inputs are evaluated.
13861386
"""
1387-
return Expr(f.regr_avgy[y.expr, x.expr], distinct)
1387+
return Expr(f.regr_avgy(y.expr, x.expr, distinct))
13881388

13891389

13901390
def regr_count(y: Expr, x: Expr, distinct: bool = False) -> Expr:
13911391
"""Counts the number of rows in which both expressions are not null."""
1392-
return Expr(f.regr_count[y.expr, x.expr], distinct)
1392+
return Expr(f.regr_count(y.expr, x.expr, distinct))
13931393

13941394

13951395
def regr_intercept(y: Expr, x: Expr, distinct: bool = False) -> Expr:
13961396
"""Computes the intercept from the linear regression."""
1397-
return Expr(f.regr_intercept[y.expr, x.expr], distinct)
1397+
return Expr(f.regr_intercept(y.expr, x.expr, distinct))
13981398

13991399

14001400
def regr_r2(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14011401
"""Computes the R-squared value from linear regression."""
1402-
return Expr(f.regr_r2[y.expr, x.expr], distinct)
1402+
return Expr(f.regr_r2(y.expr, x.expr, distinct))
14031403

14041404

14051405
def regr_slope(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14061406
"""Computes the slope from linear regression."""
1407-
return Expr(f.regr_slope[y.expr, x.expr], distinct)
1407+
return Expr(f.regr_slope(y.expr, x.expr, distinct))
14081408

14091409

14101410
def regr_sxx(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14111411
"""Computes the sum of squares of the independent variable `x`."""
1412-
return Expr(f.regr_sxx[y.expr, x.expr], distinct)
1412+
return Expr(f.regr_sxx(y.expr, x.expr, distinct))
14131413

14141414

14151415
def regr_sxy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14161416
"""Computes the sum of products of pairs of numbers."""
1417-
return Expr(f.regr_sxy[y.expr, x.expr], distinct)
1417+
return Expr(f.regr_sxy(y.expr, x.expr, distinct))
14181418

14191419

14201420
def regr_syy(y: Expr, x: Expr, distinct: bool = False) -> Expr:
14211421
"""Computes the sum of squares of the dependent variable `y`."""
1422-
return Expr(f.regr_syy[y.expr, x.expr], distinct)
1422+
return Expr(f.regr_syy(y.expr, x.expr, distinct))
14231423

14241424

14251425
def first_value(

src/functions.rs

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,96 @@ pub fn var_pop(expression: PyExpr, distinct: bool) -> PyResult<PyExpr> {
190190
}
191191
}
192192

193+
#[pyfunction]
194+
pub fn regr_avgx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
195+
let expr = functions_aggregate::expr_fn::regr_avgx(expr_y.expr, expr_x.expr);
196+
if distinct {
197+
Ok(expr.distinct().build()?.into())
198+
} else {
199+
Ok(expr.into())
200+
}
201+
}
202+
203+
#[pyfunction]
204+
pub fn regr_avgy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
205+
let expr = functions_aggregate::expr_fn::regr_avgy(expr_y.expr, expr_x.expr);
206+
if distinct {
207+
Ok(expr.distinct().build()?.into())
208+
} else {
209+
Ok(expr.into())
210+
}
211+
}
212+
213+
#[pyfunction]
214+
pub fn regr_count(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
215+
let expr = functions_aggregate::expr_fn::regr_count(expr_y.expr, expr_x.expr);
216+
if distinct {
217+
Ok(expr.distinct().build()?.into())
218+
} else {
219+
Ok(expr.into())
220+
}
221+
}
222+
223+
#[pyfunction]
224+
pub fn regr_intercept(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
225+
let expr = functions_aggregate::expr_fn::regr_intercept(expr_y.expr, expr_x.expr);
226+
if distinct {
227+
Ok(expr.distinct().build()?.into())
228+
} else {
229+
Ok(expr.into())
230+
}
231+
}
232+
233+
#[pyfunction]
234+
pub fn regr_r2(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
235+
let expr = functions_aggregate::expr_fn::regr_r2(expr_y.expr, expr_x.expr);
236+
if distinct {
237+
Ok(expr.distinct().build()?.into())
238+
} else {
239+
Ok(expr.into())
240+
}
241+
}
242+
243+
#[pyfunction]
244+
pub fn regr_slope(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
245+
let expr = functions_aggregate::expr_fn::regr_slope(expr_y.expr, expr_x.expr);
246+
if distinct {
247+
Ok(expr.distinct().build()?.into())
248+
} else {
249+
Ok(expr.into())
250+
}
251+
}
252+
253+
#[pyfunction]
254+
pub fn regr_sxx(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
255+
let expr = functions_aggregate::expr_fn::regr_sxx(expr_y.expr, expr_x.expr);
256+
if distinct {
257+
Ok(expr.distinct().build()?.into())
258+
} else {
259+
Ok(expr.into())
260+
}
261+
}
262+
263+
#[pyfunction]
264+
pub fn regr_sxy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
265+
let expr = functions_aggregate::expr_fn::regr_sxy(expr_y.expr, expr_x.expr);
266+
if distinct {
267+
Ok(expr.distinct().build()?.into())
268+
} else {
269+
Ok(expr.into())
270+
}
271+
}
272+
273+
#[pyfunction]
274+
pub fn regr_syy(expr_y: PyExpr, expr_x: PyExpr, distinct: bool) -> PyResult<PyExpr> {
275+
let expr = functions_aggregate::expr_fn::regr_syy(expr_y.expr, expr_x.expr);
276+
if distinct {
277+
Ok(expr.distinct().build()?.into())
278+
} else {
279+
Ok(expr.into())
280+
}
281+
}
282+
193283
#[pyfunction]
194284
#[pyo3(signature = (expr, distinct = false, filter = None, order_by = None, null_treatment = None))]
195285
pub fn first_value(
@@ -817,15 +907,6 @@ array_fn!(range, start stop step);
817907
aggregate_function!(array_agg, ArrayAgg);
818908
aggregate_function!(max, Max);
819909
aggregate_function!(min, Min);
820-
aggregate_function!(regr_avgx, RegrAvgx);
821-
aggregate_function!(regr_avgy, RegrAvgy);
822-
aggregate_function!(regr_count, RegrCount);
823-
aggregate_function!(regr_intercept, RegrIntercept);
824-
aggregate_function!(regr_r2, RegrR2);
825-
aggregate_function!(regr_slope, RegrSlope);
826-
aggregate_function!(regr_sxx, RegrSXX);
827-
aggregate_function!(regr_sxy, RegrSXY);
828-
aggregate_function!(regr_syy, RegrSYY);
829910
aggregate_function!(bit_and, BitAnd);
830911
aggregate_function!(bit_or, BitOr);
831912
aggregate_function!(bit_xor, BitXor);

0 commit comments

Comments
 (0)