Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/cmake_modules/ThirdpartyToolchain.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,10 @@ endif()

if(DEFINED ENV{ARROW_SUBSTRAIT_URL})
set(SUBSTRAIT_SOURCE_URL "$ENV{ARROW_SUBSTRAIT_URL}")
if(DEFINED ENV{ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM})
set(ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM
"$ENV{ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM}")
endif()
else()
set_urls(SUBSTRAIT_SOURCE_URL
"https://github.com/substrait-io/substrait/archive/${ARROW_SUBSTRAIT_BUILD_VERSION}.tar.gz"
Expand Down
151 changes: 93 additions & 58 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ namespace engine {

namespace {

constexpr int64_t kMicrosPerSecond = 1000000;
constexpr int64_t kMicrosPerMilli = 1000;

Id NormalizeFunctionName(Id id) {
// Substrait plans encode the types into the function name so it might look like
// add:opt_i32_i32. We don't care about the :opt_i32_i32 so we just trim it
Expand Down Expand Up @@ -121,9 +118,19 @@ Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) {
Result<SubstraitCall> DecodeScalarFunction(
Id id, const substrait::Expression::ScalarFunction& scalar_fn,
const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
FromProto(scalar_fn.output_type(), ext_set, conversion_options));
SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second);
std::shared_ptr<DataType> output_type;
bool output_nullable = true;
if (scalar_fn.output_type().kind_case() == substrait::Type::kUnknown) {
output_nullable = scalar_fn.output_type().unknown().nullability() !=
substrait::Type::NULLABILITY_REQUIRED;
} else {
ARROW_ASSIGN_OR_RAISE(
auto output_type_and_nullable,
FromProto(scalar_fn.output_type(), ext_set, conversion_options));
output_type = std::move(output_type_and_nullable.first);
output_nullable = output_type_and_nullable.second;
}
SubstraitCall call(id, std::move(output_type), output_nullable);
for (int i = 0; i < scalar_fn.arguments_size(); i++) {
ARROW_RETURN_NOT_OK(
DecodeArg(scalar_fn.arguments(i), i, &call, ext_set, conversion_options));
Expand Down Expand Up @@ -296,6 +303,20 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
return FromProto(ref, ext_set, conversion_options, std::move(out));
}

case substrait::Expression::kNamedExpression: {
const auto& named_expr = expr.named_expression();
if (named_expr.names_size() == 0) {
return Status::Invalid(
"substrait::Expression::NamedExpression had no name components");
}
std::vector<FieldRef> refs;
refs.reserve(named_expr.names_size());
for (const auto& name : named_expr.names()) {
refs.emplace_back(std::string(name));
}
return compute::field_ref(FieldRef(std::move(refs)));
}

case substrait::Expression::kIfThen: {
const auto& if_then = expr.if_then();
if (!if_then.has_else_()) break;
Expand Down Expand Up @@ -360,7 +381,8 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
function_id = NormalizeFunctionName(function_id);
ExtensionIdRegistry::SubstraitCallToArrow function_converter;

if (function_id.uri.empty() || function_id.uri[0] == '/') {
if (function_id.uri.empty() || function_id.uri[0] == '/' ||
function_id.uri == kSubstraitUnknownFunctionsUri) {
// Currently the Substrait project has not aligned on a standard URI and often
// seems to use /. In that case we fall back to name-only matching.
ARROW_ASSIGN_OR_RAISE(
Expand Down Expand Up @@ -528,29 +550,34 @@ Result<Datum> FromProto(const substrait::Expression::Literal& lit,
case substrait::Expression::Literal::kBinary:
return Datum(BinaryScalar(lit.binary()));

ARROW_SUPPRESS_DEPRECATION_WARNING
case substrait::Expression::Literal::kTimestamp:
return Datum(
TimestampScalar(static_cast<int64_t>(lit.timestamp()), TimeUnit::MICRO));

case substrait::Expression::Literal::kTimestampTz:
return Datum(TimestampScalar(static_cast<int64_t>(lit.timestamp_tz()),
TimeUnit::MICRO, TimestampTzTimezoneString()));
ARROW_UNSUPPRESS_DEPRECATION_WARNING
case substrait::Expression::Literal::kPrecisionTimestamp: {
// https://github.com/substrait-io/substrait/issues/611
// TODO(GH-40741) don't break, return precision timestamp
break;
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<DataType> type,
precision_timestamp(lit.precision_timestamp().precision()));
return Datum(TimestampScalar(lit.precision_timestamp().value(), std::move(type)));
}
case substrait::Expression::Literal::kPrecisionTimestampTz: {
// https://github.com/substrait-io/substrait/issues/611
// TODO(GH-40741) don't break, return precision timestamp
break;
ARROW_ASSIGN_OR_RAISE(
std::shared_ptr<DataType> type,
precision_timestamp_tz(lit.precision_timestamp_tz().precision()));
return Datum(
TimestampScalar(lit.precision_timestamp_tz().value(), std::move(type)));
}
case substrait::Expression::Literal::kDate:
return Datum(Date32Scalar(lit.date()));
case substrait::Expression::Literal::kTime:
return Datum(Time64Scalar(lit.time(), TimeUnit::MICRO));
case substrait::Expression::Literal::kPrecisionTime: {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<DataType> type,
precision_time(lit.precision_time().precision()));
switch (type->id()) {
case Type::TIME32:
return Datum(
Time32Scalar(static_cast<int32_t>(lit.precision_time().value()), type));
case Type::TIME64:
return Datum(Time64Scalar(lit.precision_time().value(), type));
default:
return Status::Invalid("Unexpected Arrow type for Substrait precision_time: ",
type->ToString());
}
}

case substrait::Expression::Literal::kIntervalYearToMonth:
case substrait::Expression::Literal::kIntervalDayToSecond: {
Expand Down Expand Up @@ -887,60 +914,68 @@ struct ScalarToProtoImpl {
return EncodeUserDefined(*s.type, value);
}

Status Visit(const TimestampScalar& s) {
template <typename Sub>
Status VisitTimestamp(const TimestampScalar& s, void (Lit::*set_allocated_sub)(Sub*)) {
const auto& t = checked_cast<const TimestampType&>(*s.type);

uint64_t micros;
auto timestamp = std::make_unique<Sub>();
timestamp->set_value(s.value);
switch (t.unit()) {
case TimeUnit::SECOND:
micros = s.value * kMicrosPerSecond;
timestamp->set_precision(0);
break;
case TimeUnit::MILLI:
micros = s.value * kMicrosPerMilli;
timestamp->set_precision(3);
break;
case TimeUnit::MICRO:
micros = s.value;
timestamp->set_precision(6);
break;
case TimeUnit::NANO:
// TODO(GH-40741): can support nanos when
// https://github.com/substrait-io/substrait/issues/611 is resolved
return NotImplemented(s);
timestamp->set_precision(9);
break;
default:
return NotImplemented(s);
}
(lit_->*set_allocated_sub)(timestamp.release());
return Status::OK();
}

// Remove these and use precision timestamp once
// https://github.com/substrait-io/substrait/issues/611 is resolved
ARROW_SUPPRESS_DEPRECATION_WARNING

if (t.timezone() == "") {
lit_->set_timestamp(micros);
} else {
// Some loss of info here, Substrait doesn't store timezone
// in field data
lit_->set_timestamp_tz(micros);
Status Visit(const TimestampScalar& s) {
const auto& t = checked_cast<const TimestampType&>(*s.type);
if (t.timezone().empty()) {
return VisitTimestamp(s, &Lit::set_allocated_precision_timestamp);
}
ARROW_UNSUPPRESS_DEPRECATION_WARNING

return Status::OK();
return VisitTimestamp(s, &Lit::set_allocated_precision_timestamp_tz);
}

// Need to support parameterized UDTs
Status Visit(const Time32Scalar& s) {
google::protobuf::Int32Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
}
Status Visit(const Time64Scalar& s) {
if (checked_cast<const Time64Type&>(*s.type).unit() == TimeUnit::MICRO) {
return Primitive(&Lit::set_time, s);
} else {
google::protobuf::Int64Value value;
value.set_value(s.value);
return EncodeUserDefined(*s.type, value);
template <typename ScalarType, typename ValueType = typename ScalarType::ValueType>
Status VisitTime(const ScalarType& s) {
const auto& t = checked_cast<const typename ScalarType::TypeClass&>(*s.type);
auto time = std::make_unique<Lit::PrecisionTime>();
time->set_value(static_cast<int64_t>(s.value));
switch (t.unit()) {
case TimeUnit::SECOND:
time->set_precision(0);
break;
case TimeUnit::MILLI:
time->set_precision(3);
break;
case TimeUnit::MICRO:
time->set_precision(6);
break;
case TimeUnit::NANO:
time->set_precision(9);
break;
default:
return NotImplemented(s);
}
lit_->set_allocated_precision_time(time.release());
return Status::OK();
}

Status Visit(const Time32Scalar& s) { return VisitTime(s); }
Status Visit(const Time64Scalar& s) { return VisitTime(s); }

Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); }
Status Visit(const DayTimeIntervalScalar& s) { return NotImplemented(s); }

Expand Down
94 changes: 92 additions & 2 deletions cpp/src/arrow/engine/substrait/extended_expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,93 @@ Status AddExtensionSetToExtendedExpression(const ExtensionSet& ext_set,
return AddExtensionSetToMessage(ext_set, expr);
}

bool TypeContainsUnknown(const substrait::Type& type) {
switch (type.kind_case()) {
case substrait::Type::kUnknown:
return true;
case substrait::Type::kStruct:
for (const auto& child_type : type.struct_().types()) {
if (TypeContainsUnknown(child_type)) {
return true;
}
}
return false;
case substrait::Type::kList:
return type.list().has_type() && TypeContainsUnknown(type.list().type());
case substrait::Type::kMap:
return (type.map().has_key() && TypeContainsUnknown(type.map().key())) ||
(type.map().has_value() && TypeContainsUnknown(type.map().value()));
default:
return false;
}
}

bool NamedStructContainsUnknownTypes(const substrait::NamedStruct& named_struct) {
if (!named_struct.has_struct_()) {
return false;
}
for (const auto& type : named_struct.struct_().types()) {
if (TypeContainsUnknown(type)) {
return true;
}
}
return false;
}

void CollectDepthFirstFieldNames(const FieldVector& fields,
std::vector<std::string_view>* names) {
for (const auto& field : fields) {
names->push_back(field->name());
if (field->type()->id() == Type::STRUCT) {
CollectDepthFirstFieldNames(field->type()->fields(), names);
}
}
}

Status ValidateSchemaNames(const substrait::NamedStruct& named_struct,
const Schema& schema) {
std::vector<std::string_view> schema_names;
CollectDepthFirstFieldNames(schema.fields(), &schema_names);
if (schema_names.size() != static_cast<size_t>(named_struct.names_size())) {
return Status::Invalid(
"The supplied Arrow schema did not match the ExtendedExpression base_schema. "
"Expected ",
named_struct.names_size(), " depth-first field names but got ",
schema_names.size());
}
for (int i = 0; i < named_struct.names_size(); ++i) {
if (schema_names[static_cast<size_t>(i)] != named_struct.names(i)) {
return Status::Invalid(
"The supplied Arrow schema did not match the ExtendedExpression "
"base_schema. Expected field name ",
named_struct.names(i), " at depth-first position ", i, " but got ",
schema_names[static_cast<size_t>(i)]);
}
}
return Status::OK();
}

Result<std::shared_ptr<Schema>> ResolveInputSchema(
const substrait::NamedStruct& base_schema, const Schema* input_schema_override,
const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
if (input_schema_override == NULLPTR) {
return FromProto(base_schema, ext_set, conversion_options);
}

if (NamedStructContainsUnknownTypes(base_schema)) {
ARROW_RETURN_NOT_OK(ValidateSchemaNames(base_schema, *input_schema_override));
return std::make_shared<Schema>(*input_schema_override);
}

ARROW_ASSIGN_OR_RAISE(auto decoded_schema,
FromProto(base_schema, ext_set, conversion_options));
if (!decoded_schema->Equals(*input_schema_override, /*check_metadata=*/false)) {
return Status::Invalid(
"The supplied Arrow schema did not match the ExtendedExpression base_schema");
}
return std::make_shared<Schema>(*input_schema_override);
}

Status VisitNestedFields(const DataType& type,
std::function<Status(const Field&)> visitor) {
if (!is_nested(type.id())) {
Expand Down Expand Up @@ -149,6 +236,7 @@ Result<std::unique_ptr<substrait::ExpressionReference>> CreateExpressionReferenc
} // namespace

Result<BoundExpressions> FromProto(const substrait::ExtendedExpression& expression,
const Schema* input_schema_override,
ExtensionSet* ext_set_out,
const ConversionOptions& conversion_options,
const ExtensionIdRegistry* registry) {
Expand All @@ -162,8 +250,10 @@ Result<BoundExpressions> FromProto(const substrait::ExtendedExpression& expressi
ExtensionSet ext_set,
GetExtensionSetFromExtendedExpression(expression, conversion_options, registry));

ARROW_ASSIGN_OR_RAISE(bound_expressions.schema,
FromProto(expression.base_schema(), ext_set, conversion_options));
ARROW_ASSIGN_OR_RAISE(
bound_expressions.schema,
ResolveInputSchema(expression.base_schema(), input_schema_override, ext_set,
conversion_options));

bound_expressions.named_expressions.reserve(expression.referred_expr_size());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace engine {
/// Convert a Substrait ExtendedExpression to a vector of expressions and output names
ARROW_ENGINE_EXPORT
Result<BoundExpressions> FromProto(const substrait::ExtendedExpression& expression,
const Schema* input_schema_override,
ExtensionSet* ext_set_out,
const ConversionOptions& conversion_options,
const ExtensionIdRegistry* extension_id_registry);
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/engine/substrait/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ constexpr const char* kSubstraitAggregateGenericFunctionsUri =
/// and any options are ignored.
constexpr const char* kArrowSimpleExtensionFunctionsUri =
"urn:arrow:substrait_simple_extension_function";
constexpr const char* kSubstraitUnknownFunctionsUri = "extension:io.substrait:unknown";

struct ARROW_ENGINE_EXPORT Id {
std::string_view uri, name;
Expand Down
15 changes: 15 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ std::shared_ptr<DataType> interval_year() { return IntervalYearType::Make({}); }

std::shared_ptr<DataType> interval_day() { return IntervalDayType::Make({}); }

Result<std::shared_ptr<DataType>> precision_time(int precision) {
switch (precision) {
case 0:
return time32(TimeUnit::SECOND);
case 3:
return time32(TimeUnit::MILLI);
case 6:
return time64(TimeUnit::MICRO);
case 9:
return time64(TimeUnit::NANO);
default:
return Status::NotImplemented("Unrecognized time precision (", precision, ")");
}
}

Result<std::shared_ptr<DataType>> precision_timestamp(int precision) {
switch (precision) {
case 0:
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ std::shared_ptr<DataType> interval_year();
ARROW_ENGINE_EXPORT
std::shared_ptr<DataType> interval_day();

/// constructs the appropriate time type given the precision
ARROW_ENGINE_EXPORT
Result<std::shared_ptr<DataType>> precision_time(int precision);

/// constructs the appropriate timestamp type given the precision
/// no time zone
ARROW_ENGINE_EXPORT
Expand Down
Loading
Loading