diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index e555081e4132c..3e0a23f1adc6c 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -163,7 +163,7 @@ impl ScalarUDFImpl for ArrowCastFunc { info: &SimplifyContext, ) -> Result<ExprSimplifyResult> { // convert this into a real cast - let target_type = data_type_from_args(&args)?; + let target_type = data_type_from_args(self.name(), &args)?; // remove second (type) argument args.pop().unwrap(); let arg = args.pop().unwrap(); @@ -189,12 +189,12 @@ impl ScalarUDFImpl for ArrowCastFunc { } /// Returns the requested type from the arguments -fn data_type_from_args(args: &[Expr]) -> Result<DataType> { - let [_, type_arg] = take_function_args("arrow_cast", args)?; +pub(crate) fn data_type_from_args(name: &str, args: &[Expr]) -> Result<DataType> { + let [_, type_arg] = take_function_args(name, args)?; let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( - "arrow_cast requires its second argument to be a constant string, got {:?}", + "{name} requires its second argument to be a constant string, got {:?}", type_arg ); }; diff --git a/datafusion/functions/src/core/arrow_try_cast.rs b/datafusion/functions/src/core/arrow_try_cast.rs new file mode 100644 index 0000000000000..a221c81e07f13 --- /dev/null +++ b/datafusion/functions/src/core/arrow_try_cast.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast` + +use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::error::ArrowError; +use datafusion_common::{ + Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err, + internal_err, types::logical_string, utils::take_function_args, +}; +use std::any::Any; + +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +use super::arrow_cast::data_type_from_args; + +/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring. +/// +/// This is implemented by simplifying `arrow_try_cast(expr, 'Type')` into +/// `Expr::TryCast` during optimization. +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts a value to a specific Arrow data type, returning NULL if the cast fails.", + syntax_example = "arrow_try_cast(expression, datatype)", + sql_example = r#"```sql +> select arrow_try_cast('123', 'Int64') as a, + arrow_try_cast('not_a_number', 'Int64') as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +```"#, + argument( + name = "expression", + description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators." 
+ ), + argument( + name = "datatype", + description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]" + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrowTryCastFunc { + signature: Signature, +} + +impl Default for ArrowTryCastFunc { + fn default() -> Self { + Self::new() + } +} + +impl ArrowTryCastFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ArrowTryCastFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "arrow_try_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> { + // TryCast can always return NULL (on cast failure), so always nullable + let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; + + type_arg + .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || { + exec_err!( + "{} requires its second argument to be a non-empty constant string", + self.name() + ) + }, + |casted_type| match casted_type.parse::<DataType>() { + Ok(data_type) => { + Ok(Field::new(self.name(), data_type, true).into()) + } + Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), + Err(e) => Err(arrow_datafusion_err!(e)), + }, + ) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { + internal_err!("arrow_try_cast should have been simplified to try_cast") + } + + fn simplify( + &self, + mut args: Vec<Expr>, + info: &SimplifyContext, + ) -> Result<ExprSimplifyResult> { + let target_type = data_type_from_args(self.name(),
&args)?; + // remove second (type) argument + args.pop().unwrap(); + let arg = args.pop().unwrap(); + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == target_type { + arg + } else { + Expr::TryCast(datafusion_expr::TryCast { + expr: Box::new(arg), + field: target_type.into_nullable_field_ref(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index a14d563737240..e8737612a1dcf 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; pub mod arrow_cast; pub mod arrow_metadata; +pub mod arrow_try_cast; pub mod arrowtypeof; pub mod coalesce; pub mod expr_ext; @@ -42,6 +43,7 @@ pub mod version; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); +make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast); make_udf_function!(nullif::NullIfFunc, nullif); make_udf_function!(nvl::NVLFunc, nvl); make_udf_function!(nvl2::NVL2Func, nvl2); @@ -67,7 +69,11 @@ pub mod expr_fn { arg1 arg2 ),( arrow_cast, - "Returns value2 if value1 is NULL; otherwise it returns value1", + "Casts a value to a specific Arrow data type", + arg1 arg2 + ),( + arrow_try_cast, + "Casts a value to a specific Arrow data type, returning NULL if the cast fails", arg1 arg2 ),( nvl, @@ -140,6 +146,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> { vec![ nullif(), arrow_cast(), + arrow_try_cast(), arrow_metadata(), nvl(), nvl2(), diff --git a/datafusion/sqllogictest/test_files/arrow_try_cast.slt b/datafusion/sqllogictest/test_files/arrow_try_cast.slt new file mode 100644 index 0000000000000..fffb340798634 --- /dev/null +++ b/datafusion/sqllogictest/test_files/arrow_try_cast.slt @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for arrow_try_cast: like arrow_cast but returns NULL on cast failure +########## + +# Successful cast to Float64 +query R +select arrow_try_cast(123, 'Float64'); +---- +123 + +# Successful cast to Int64 +query I +select arrow_try_cast('123', 'Int64'); +---- +123 + +# Failed cast returns NULL +query I +select arrow_try_cast('not_a_number', 'Int64'); +---- +NULL + +# Integer literal (Int64) narrowed to Int32; not a same-type passthrough +query I +select arrow_try_cast(1, 'Int32'); +---- +1 + +# Cast to LargeUtf8 +query T +select arrow_try_cast('foo', 'LargeUtf8'); +---- +foo + +# Cast integer to string +query T +select arrow_try_cast(42, 'Utf8'); +---- +42 + +# Cast to dictionary type +query T +select arrow_try_cast('bar', 'Dictionary(Int32, Utf8)'); +---- +bar + +# NULL input stays NULL +query I +select arrow_try_cast(NULL, 'Int64'); +---- +NULL + +# Error on invalid type string +statement error +select arrow_try_cast(1, 'NotAType'); + +# Error when second argument is not a string constant +statement error +select arrow_try_cast(1, 123); + +# Multiple arrow_try_cast in one query +query IT +select arrow_try_cast('456', 'Int64') as a, + arrow_try_cast(789, 'Utf8') as b; +---- +456 789 + +# Tests that exercise physical execution (not constant folding) + +# Cast column values to Int64, with mixed valid/null/invalid
inputs +query I +select arrow_try_cast(a, 'Int64') from (values('100'), (NULL), ('foo')) t(a); +---- +100 +NULL +NULL + +# Cast column values to Float64 +query R +select arrow_try_cast(a, 'Float64') from (values('3.14'), ('not_num'), (NULL)) t(a); +---- +3.14 +NULL +NULL + +# Cast integer column to Utf8 +query T +select arrow_try_cast(a, 'Utf8') from (values(1), (2), (NULL)) t(a); +---- +1 +2 +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 918bae0f7d1b5..5a8ef4db3d4b2 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -5185,6 +5185,7 @@ union_tag(union_expression) - [arrow_cast](#arrow_cast) - [arrow_metadata](#arrow_metadata) +- [arrow_try_cast](#arrow_try_cast) - [arrow_typeof](#arrow_typeof) - [get_field](#get_field) - [version](#version) @@ -5257,6 +5258,32 @@ arrow_metadata(expression[, key]) +-------------------------------+ ``` +### `arrow_try_cast` + +Casts a value to a specific Arrow data type, returning NULL if the cast fails. + +```sql +arrow_try_cast(expression, datatype) +``` + +#### Arguments + +- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators. +- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`] + +#### Example + +```sql +> select arrow_try_cast('123', 'Int64') as a, + arrow_try_cast('not_a_number', 'Int64') as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +``` + ### `arrow_typeof` Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.