Commit 6b1e9c6

Implement trait based API for defining WindowUDF (#8719)
* Implement trait based API for defining WindowUDF
* add test case & docs
* fix docs
* rename WindowUDFImpl function
1 parent 9a6cc88 commit 6b1e9c6
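
For orientation, the sketch below condenses the trait-based API this commit introduces, using the same pieces as the `advanced_udwf.rs` example and the updated test further down: a struct implementing `WindowUDFImpl`, a `PartitionEvaluator`, and registration via `WindowUDF::from` plus `register_udwf`. The names (`MyUdwf`, `MyEvaluator`) are illustrative and the evaluator body is a do-nothing placeholder that returns NULL for every row, so treat this as a sketch of the trait surface rather than a complete window function.

use std::any::Any;

use arrow::array::ArrayRef;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_common::ScalarValue;
use datafusion_expr::{PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl};

/// A window UDF is now an ordinary struct implementing `WindowUDFImpl`,
/// instead of a bundle of closures passed to `WindowUDF::new`.
struct MyUdwf {
    signature: Signature,
}

impl WindowUDFImpl for MyUdwf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "my_udwf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
        Ok(Box::new(MyEvaluator {}))
    }
}

/// Placeholder evaluator: returns NULL for every input row.
#[derive(Debug)]
struct MyEvaluator {}

impl PartitionEvaluator for MyEvaluator {
    fn uses_window_frame(&self) -> bool {
        true
    }

    fn evaluate(
        &mut self,
        _values: &[ArrayRef],
        _range: &std::ops::Range<usize>,
    ) -> Result<ScalarValue> {
        Ok(ScalarValue::Float64(None))
    }
}

fn register(ctx: &SessionContext) {
    let udwf = MyUdwf {
        signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
    };
    // Any `WindowUDFImpl` can be wrapped into a `WindowUDF` for registration.
    ctx.register_udwf(WindowUDF::from(udwf));
}

The full example file below also shows how to invoke the registered function from SQL and through the DataFrame API.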

File tree

8 files changed, +498 -44 lines changed


datafusion-examples/README.md

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ cargo run --example csv_sql
 - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
 - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
 - [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
+- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF)

 ## Distributed


datafusion-examples/examples/advanced_udwf.rs

Lines changed: 230 additions & 0 deletions

@@ -0,0 +1,230 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use std::any::Any;

use arrow::{
    array::{ArrayRef, AsArray, Float64Array},
    datatypes::Float64Type,
};
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::ScalarValue;
use datafusion_expr::{
    PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl,
};

/// This example shows how to use the full WindowUDFImpl API to implement a user
/// defined window function. As in the `simple_udwf.rs` example, this struct implements
/// a function `partition_evaluator` that returns the `MyPartitionEvaluator` instance.
///
/// To do so, we must implement the `WindowUDFImpl` trait.
struct SmoothItUdf {
    signature: Signature,
}

impl SmoothItUdf {
    /// Create a new instance of the SmoothItUdf struct
    fn new() -> Self {
        Self {
            signature: Signature::exact(
                // this function will always take one argument of type f64
                vec![DataType::Float64],
                // this function is deterministic and will always return the same
                // result for the same input
                Volatility::Immutable,
            ),
        }
    }
}

impl WindowUDFImpl for SmoothItUdf {
    /// We implement as_any so that we can downcast the WindowUDFImpl trait object
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Return the name of this function
    fn name(&self) -> &str {
        "smooth_it"
    }

    /// Return the "signature" of this function -- namely what types of arguments it will take
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// What is the type of value that will be returned by this function.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    /// Create a `PartitionEvaluator` to evaluate this function on a new
    /// partition.
    fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
        Ok(Box::new(MyPartitionEvaluator::new()))
    }
}

/// This implements the lowest level evaluation for a window function
///
/// It handles calculating the value of the window function for each
/// distinct value of `PARTITION BY` (each car type in our example)
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
    fn new() -> Self {
        Self {}
    }
}

/// Different evaluation methods are called depending on the various
/// settings of WindowUDF. This example uses the simplest and most
/// general, `evaluate`. See `PartitionEvaluator` for the other more
/// advanced uses.
impl PartitionEvaluator for MyPartitionEvaluator {
    /// Tell DataFusion the window function varies based on the value
    /// of the window frame.
    fn uses_window_frame(&self) -> bool {
        true
    }

    /// This function is called once per input row.
    ///
    /// `range` specifies which indexes of `values` should be
    /// considered for the calculation.
    ///
    /// Note this is the SLOWEST, but simplest, way to evaluate a
    /// window function. It is much faster to implement
    /// evaluate_all or evaluate_all_with_rank, if possible
    fn evaluate(
        &mut self,
        values: &[ArrayRef],
        range: &std::ops::Range<usize>,
    ) -> Result<ScalarValue> {
        // Again, the input argument is an array of floating
        // point numbers to calculate a moving average
        let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();

        let range_len = range.end - range.start;

        // our smoothing function will average all the values in the range
        let output = if range_len > 0 {
            let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum();
            Some(sum / range_len as f64)
        } else {
            None
        };

        Ok(ScalarValue::Float64(output))
    }
}

// create local execution context with `cars.csv` registered as a table named `cars`
async fn create_context() -> Result<SessionContext> {
    // declare a new context. In the Spark API, this corresponds to a new Spark SQL session
    let ctx = SessionContext::new();

    // register the `cars.csv` file as a table. In the Spark API, this corresponds to createDataFrame(...).
    println!("pwd: {}", std::env::current_dir().unwrap().display());
    let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string();
    let read_options = CsvReadOptions::default().has_header(true);

    ctx.register_csv("cars", &csv_path, read_options).await?;
    Ok(ctx)
}

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context().await?;
    let smooth_it = WindowUDF::from(SmoothItUdf::new());
    ctx.register_udwf(smooth_it.clone());

    // Use SQL to show the contents of the `cars` table
    let df = ctx.sql("SELECT * from cars").await?;
    // print the results
    df.show().await?;

    // Use SQL to run the new window function:
    //
    // `PARTITION BY car`: each distinct value of car (red, and green)
    // should be treated as a separate partition (and will result in
    // creating a new `PartitionEvaluator`)
    //
    // `ORDER BY time`: within each partition ('green' or 'red') the
    // rows will be ordered by the value in the `time` column
    //
    // `evaluate` is invoked with a window defined by the
    // SQL. In this case:
    //
    // The first invocation will be passed row 0, the first row in the
    // partition.
    //
    // The second invocation will be passed rows 0 and 1, the first
    // two rows in the partition.
    //
    // etc.
    let df = ctx
        .sql(
            "SELECT \
               car, \
               speed, \
               smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\
               time \
               from cars \
             ORDER BY \
               car",
        )
        .await?;
    // print the results
    df.show().await?;

    // this time, call the new window function with an explicit
    // window so evaluate will be invoked with each window.
    //
    // `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation
    // sees at most 3 rows: the row before, the current row, and the 1
    // row afterward.
    let df = ctx.sql(
        "SELECT \
           car, \
           speed, \
           smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS smooth_speed,\
           time \
           from cars \
         ORDER BY \
           car",
    ).await?;
    // print the results
    df.show().await?;

    // Now, run the function using the DataFrame API:
    let window_expr = smooth_it.call(
        vec![col("speed")],                 // smooth_it(speed)
        vec![col("car")],                   // PARTITION BY car
        vec![col("time").sort(true, true)], // ORDER BY time ASC
        WindowFrame::new(false),
    );
    let df = ctx.table("cars").await?.window(vec![window_expr])?;

    // print the results
    df.show().await?;

    Ok(())
}

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 44 additions & 20 deletions
@@ -19,6 +19,7 @@
 //! user defined window functions

 use std::{
+    any::Any,
     ops::Range,
     sync::{
         atomic::{AtomicUsize, Ordering},
@@ -32,8 +33,7 @@ use arrow_schema::DataType;
 use datafusion::{assert_batches_eq, prelude::SessionContext};
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::{
-    function::PartitionEvaluatorFactory, PartitionEvaluator, ReturnTypeFunction,
-    Signature, Volatility, WindowUDF,
+    PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
 };

 /// A query with a window function evaluated over the entire partition
@@ -471,24 +471,48 @@ impl OddCounter {
     }

     fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) {
-        let name = "odd_counter";
-        let volatility = Volatility::Immutable;
-
-        let signature = Signature::exact(vec![DataType::Int64], volatility);
-
-        let return_type = Arc::new(DataType::Int64);
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::clone(&return_type)));
-
-        let partition_evaluator_factory: PartitionEvaluatorFactory =
-            Arc::new(move || Ok(Box::new(OddCounter::new(Arc::clone(&test_state)))));
-
-        ctx.register_udwf(WindowUDF::new(
-            name,
-            &signature,
-            &return_type,
-            &partition_evaluator_factory,
-        ))
+        struct SimpleWindowUDF {
+            signature: Signature,
+            return_type: DataType,
+            test_state: Arc<TestState>,
+        }
+
+        impl SimpleWindowUDF {
+            fn new(test_state: Arc<TestState>) -> Self {
+                let signature =
+                    Signature::exact(vec![DataType::Float64], Volatility::Immutable);
+                let return_type = DataType::Int64;
+                Self {
+                    signature,
+                    return_type,
+                    test_state,
+                }
+            }
+        }
+
+        impl WindowUDFImpl for SimpleWindowUDF {
+            fn as_any(&self) -> &dyn Any {
+                self
+            }
+
+            fn name(&self) -> &str {
+                "odd_counter"
+            }
+
+            fn signature(&self) -> &Signature {
+                &self.signature
+            }
+
+            fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+                Ok(self.return_type.clone())
+            }
+
+            fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+                Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state))))
+            }
+        }
+
+        ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
     }
 }

