Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 40 additions & 131 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::any::Any;
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array};
Expand All @@ -43,9 +43,9 @@ use datafusion_common::{
use datafusion_expr::expr::FieldMetadata;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody,
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue, CreateFunction,
CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
Expand Down Expand Up @@ -181,6 +181,7 @@ async fn scalar_udf() -> Result<()> {
Ok(())
}

#[derive(PartialEq, Hash)]
struct Simple0ArgsScalarUDF {
name: String,
signature: Signature,
Expand Down Expand Up @@ -218,33 +219,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self {
name,
signature,
return_type,
} = self;
name == &other.name
&& signature == &other.signature
&& return_type == &other.return_type
}

fn hash_value(&self) -> u64 {
let Self {
name,
signature,
return_type,
} = self;
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
name.hash(&mut hasher);
signature.hash(&mut hasher);
return_type.hash(&mut hasher);
hasher.finish()
}
udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down Expand Up @@ -517,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
}

/// Volatile UDF that should append a different value to each row
#[derive(Debug)]
#[derive(Debug, PartialEq, Hash)]
struct AddIndexToStringVolatileScalarUDF {
name: String,
signature: Signature,
Expand Down Expand Up @@ -586,33 +561,7 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self {
name,
signature,
return_type,
} = self;
name == &other.name
&& signature == &other.signature
&& return_type == &other.return_type
}

fn hash_value(&self) -> u64 {
let Self {
name,
signature,
return_type,
} = self;
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
name.hash(&mut hasher);
signature.hash(&mut hasher);
return_type.hash(&mut hasher);
hasher.finish()
}
udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down Expand Up @@ -992,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory {
//
// it also defines custom [ScalarUDFImpl::simplify()]
// to replace ScalarUDF expression with one instance contains.
#[derive(Debug)]
#[derive(Debug, PartialEq, Hash)]
struct ScalarFunctionWrapper {
name: String,
expr: Expr,
Expand Down Expand Up @@ -1031,37 +980,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
Ok(ExprSimplifyResult::Simplified(replacement))
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self {
name,
expr,
signature,
return_type,
} = self;
name == &other.name
&& expr == &other.expr
&& signature == &other.signature
&& return_type == &other.return_type
}

fn hash_value(&self) -> u64 {
let Self {
name,
expr,
signature,
return_type,
} = self;
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
name.hash(&mut hasher);
expr.hash(&mut hasher);
signature.hash(&mut hasher);
return_type.hash(&mut hasher);
hasher.finish()
}
udf_equals_hash!(ScalarUDFImpl);
}

impl ScalarFunctionWrapper {
Expand Down Expand Up @@ -1296,6 +1215,21 @@ struct MyRegexUdf {
regex: Regex,
}

impl PartialEq for MyRegexUdf {
fn eq(&self, other: &Self) -> bool {
let Self { signature, regex } = self;
signature == &other.signature && regex.as_str() == other.regex.as_str()
}
}

impl Hash for MyRegexUdf {
fn hash<H: Hasher>(&self, state: &mut H) {
let Self { signature, regex } = self;
signature.hash(state);
regex.as_str().hash(state);
}
}

impl MyRegexUdf {
fn new(pattern: &str) -> Self {
Self {
Expand Down Expand Up @@ -1348,19 +1282,7 @@ impl ScalarUDFImpl for MyRegexUdf {
}
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
self.regex.as_str() == other.regex.as_str()
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.regex.as_str().hash(hasher);
hasher.finish()
}
udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down Expand Up @@ -1458,13 +1380,25 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB
ctx.sql(sql).await?.collect().await
}

#[derive(Debug)]
#[derive(Debug, PartialEq)]
struct MetadataBasedUdf {
name: String,
signature: Signature,
metadata: HashMap<String, String>,
}

impl Hash for MetadataBasedUdf {
fn hash<H: Hasher>(&self, state: &mut H) {
let Self {
name,
signature,
metadata: _, // unhashable
} = self;
name.hash(state);
signature.hash(state);
}
}

impl MetadataBasedUdf {
fn new(metadata: HashMap<String, String>) -> Self {
// The name we return must be unique. Otherwise we will not call distinct
Expand Down Expand Up @@ -1537,32 +1471,7 @@ impl ScalarUDFImpl for MetadataBasedUdf {
}
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self {
name,
signature,
metadata,
} = self;
name == &other.name
&& signature == &other.signature
&& metadata == &other.metadata
}

fn hash_value(&self) -> u64 {
let Self {
name,
signature,
metadata: _, // unhashable
} = self;
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
name.hash(&mut hasher);
signature.hash(&mut hasher);
hasher.finish()
}
udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down
38 changes: 21 additions & 17 deletions datafusion/expr/src/async_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
use crate::utils::{arc_ptr_eq, arc_ptr_hash};
use crate::{
udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
};
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, FieldRef};
use async_trait::async_trait;
Expand All @@ -26,7 +29,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue;
use datafusion_expr_common::signature::Signature;
use std::any::Any;
use std::fmt::{Debug, Display};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// A scalar UDF that can invoke using async methods
Expand Down Expand Up @@ -62,6 +65,21 @@ pub struct AsyncScalarUDF {
inner: Arc<dyn AsyncScalarUDFImpl>,
}

impl PartialEq for AsyncScalarUDF {
fn eq(&self, other: &Self) -> bool {
let Self { inner } = self;
// TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting.
arc_ptr_eq(inner, &other.inner)
}
}

impl Hash for AsyncScalarUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
let Self { inner } = self;
arc_ptr_hash(inner, state);
}
}

impl AsyncScalarUDF {
pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self {
Self { inner }
Expand Down Expand Up @@ -113,21 +131,7 @@ impl ScalarUDFImpl for AsyncScalarUDF {
internal_err!("async functions should not be called directly")
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self { inner } = self;
// TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting
Arc::ptr_eq(inner, &other.inner)
}

fn hash_value(&self) -> u64 {
let Self { inner } = self;
let mut hasher = DefaultHasher::new();
Arc::as_ptr(inner).hash(&mut hasher);
hasher.finish()
}
udf_equals_hash!(ScalarUDFImpl);
}

impl Display for AsyncScalarUDF {
Expand Down
Loading