Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl ScalarUDFImpl for ArrowCastFunc {
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
// convert this into a real cast
let target_type = data_type_from_args(&args)?;
let target_type = data_type_from_args(self.name(), &args)?;
// remove second (type) argument
args.pop().unwrap();
let arg = args.pop().unwrap();
Expand All @@ -189,12 +189,12 @@ impl ScalarUDFImpl for ArrowCastFunc {
}

/// Returns the requested type from the arguments
fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
let [_, type_arg] = take_function_args("arrow_cast", args)?;
pub(crate) fn data_type_from_args(name: &str, args: &[Expr]) -> Result<DataType> {
let [_, type_arg] = take_function_args(name, args)?;

let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else {
return exec_err!(
"arrow_cast requires its second argument to be a constant string, got {:?}",
"{name} requires its second argument to be a constant string, got {:?}",
type_arg
);
};
Expand Down
158 changes: 158 additions & 0 deletions datafusion/functions/src/core/arrow_try_cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast`

use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::{
Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err,
internal_err, types::logical_string, utils::take_function_args,
};
use std::any::Any;

use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;

use super::arrow_cast::data_type_from_args;

/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring.
///
/// This is implemented by simplifying `arrow_try_cast(expr, 'Type')` into
/// `Expr::TryCast` during optimization.
#[user_doc(
doc_section(label = "Other Functions"),
description = "Casts a value to a specific Arrow data type, returning NULL if the cast fails.",
syntax_example = "arrow_try_cast(expression, datatype)",
sql_example = r#"```sql
> select arrow_try_cast('123', 'Int64') as a,
arrow_try_cast('not_a_number', 'Int64') as b;

+-----+------+
| a | b |
+-----+------+
| 123 | NULL |
+-----+------+
```"#,
argument(
name = "expression",
description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
),
argument(
name = "datatype",
description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrowTryCastFunc {
signature: Signature,
}

impl Default for ArrowTryCastFunc {
fn default() -> Self {
Self::new()
}
}

impl ArrowTryCastFunc {
pub fn new() -> Self {
Self {
signature: Signature::coercible(
vec![
Coercion::new_exact(TypeSignatureClass::Any),
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for ArrowTryCastFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"arrow_try_cast"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be called instead")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
// TryCast can always return NULL (on cast failure), so always nullable
let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;

type_arg
.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
.map_or_else(
|| {
exec_err!(
"{} requires its second argument to be a non-empty constant string",
self.name()
)
},
|casted_type| match casted_type.parse::<DataType>() {
Ok(data_type) => {
Ok(Field::new(self.name(), data_type, true).into())
}
Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
Err(e) => Err(arrow_datafusion_err!(e)),
},
)
}

fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
internal_err!("arrow_try_cast should have been simplified to try_cast")
}

fn simplify(
&self,
mut args: Vec<Expr>,
info: &SimplifyContext,
) -> Result<ExprSimplifyResult> {
let target_type = data_type_from_args(self.name(), &args)?;
// remove second (type) argument
args.pop().unwrap();
let arg = args.pop().unwrap();

let source_type = info.get_data_type(&arg)?;
let new_expr = if source_type == target_type {
arg
} else {
Expr::TryCast(datafusion_expr::TryCast {
expr: Box::new(arg),
field: target_type.into_nullable_field_ref(),
})
};
Ok(ExprSimplifyResult::Simplified(new_expr))
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}
9 changes: 8 additions & 1 deletion datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::sync::Arc;

pub mod arrow_cast;
pub mod arrow_metadata;
pub mod arrow_try_cast;
pub mod arrowtypeof;
pub mod coalesce;
pub mod expr_ext;
Expand All @@ -42,6 +43,7 @@ pub mod version;

// create UDFs
make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast);
make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast);
make_udf_function!(nullif::NullIfFunc, nullif);
make_udf_function!(nvl::NVLFunc, nvl);
make_udf_function!(nvl2::NVL2Func, nvl2);
Expand All @@ -67,7 +69,11 @@ pub mod expr_fn {
arg1 arg2
),(
arrow_cast,
"Returns value2 if value1 is NULL; otherwise it returns value1",
"Casts a value to a specific Arrow data type",
arg1 arg2
),(
arrow_try_cast,
"Casts a value to a specific Arrow data type, returning NULL if the cast fails",
arg1 arg2
),(
nvl,
Expand Down Expand Up @@ -140,6 +146,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
vec![
nullif(),
arrow_cast(),
arrow_try_cast(),
arrow_metadata(),
nvl(),
nvl2(),
Expand Down
109 changes: 109 additions & 0 deletions datafusion/sqllogictest/test_files/arrow_try_cast.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

##########
# Tests for arrow_try_cast: like arrow_cast but returns NULL on cast failure
##########

# Successful cast to Float64
query R
select arrow_try_cast(123, 'Float64');
----
123

# Successful cast to Int64
query I
select arrow_try_cast('123', 'Int64');
----
123

# Failed cast returns NULL
query I
select arrow_try_cast('not_a_number', 'Int64');
----
NULL

# Same-type passthrough
query I
select arrow_try_cast(1, 'Int32');
----
1

# Cast to LargeUtf8
query T
select arrow_try_cast('foo', 'LargeUtf8');
----
foo

# Cast integer to string
query T
select arrow_try_cast(42, 'Utf8');
----
42

# Cast to dictionary type
query T
select arrow_try_cast('bar', 'Dictionary(Int32, Utf8)');
----
bar

# NULL input stays NULL
query I
select arrow_try_cast(NULL, 'Int64');
----
NULL

# Error on invalid type string
statement error
select arrow_try_cast(1, 'NotAType');

# Error when second argument is not a string constant
statement error
select arrow_try_cast(1, 123);

# Multiple arrow_try_cast in one query
query IT
select arrow_try_cast('456', 'Int64') as a,
arrow_try_cast(789, 'Utf8') as b;
----
456 789

# Tests that exercise physical execution (not constant folding)

# Cast column values to Int64, with mixed valid/null/invalid inputs
query I
select arrow_try_cast(a, 'Int64') from (values('100'), (NULL), ('foo')) t(a);
----
100
NULL
NULL

# Cast column values to Float64
query R
select arrow_try_cast(a, 'Float64') from (values('3.14'), ('not_num'), (NULL)) t(a);
----
3.14
NULL
NULL

# Cast integer column to Utf8
query T
select arrow_try_cast(a, 'Utf8') from (values(1), (2), (NULL)) t(a);
----
1
2
NULL
27 changes: 27 additions & 0 deletions docs/source/user-guide/sql/scalar_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -5185,6 +5185,7 @@ union_tag(union_expression)

- [arrow_cast](#arrow_cast)
- [arrow_metadata](#arrow_metadata)
- [arrow_try_cast](#arrow_try_cast)
- [arrow_typeof](#arrow_typeof)
- [get_field](#get_field)
- [version](#version)
Expand Down Expand Up @@ -5257,6 +5258,32 @@ arrow_metadata(expression[, key])
+-------------------------------+
```

### `arrow_try_cast`

Casts a value to a specific Arrow data type, returning NULL if the cast fails.

```sql
arrow_try_cast(expression, datatype)
```

#### Arguments

- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators.
- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]

#### Example

```sql
> select arrow_try_cast('123', 'Int64') as a,
arrow_try_cast('not_a_number', 'Int64') as b;

+-----+------+
| a | b |
+-----+------+
| 123 | NULL |
+-----+------+
```

### `arrow_typeof`

Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.
Expand Down
Loading