diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index cc632a62fe4d..2cd4b41abeea 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -782,6 +782,10 @@ endif() if(DEFINED ENV{ARROW_SUBSTRAIT_URL}) set(SUBSTRAIT_SOURCE_URL "$ENV{ARROW_SUBSTRAIT_URL}") + if(DEFINED ENV{ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM}) + set(ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM + "$ENV{ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM}") + endif() else() set_urls(SUBSTRAIT_SOURCE_URL "https://github.com/substrait-io/substrait/archive/${ARROW_SUBSTRAIT_BUILD_VERSION}.tar.gz" diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 09b295632b82..16d86e22ec02 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -74,9 +74,6 @@ namespace engine { namespace { -constexpr int64_t kMicrosPerSecond = 1000000; -constexpr int64_t kMicrosPerMilli = 1000; - Id NormalizeFunctionName(Id id) { // Substrait plans encode the types into the function name so it might look like // add:opt_i32_i32. We don't care about the :opt_i32_i32 so we just trim it @@ -121,9 +118,19 @@ Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) { Result DecodeScalarFunction( Id id, const substrait::Expression::ScalarFunction& scalar_fn, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { - ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable, - FromProto(scalar_fn.output_type(), ext_set, conversion_options)); - SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second); + std::shared_ptr output_type; + bool output_nullable = true; + if (scalar_fn.output_type().kind_case() == substrait::Type::kUnknown) { + output_nullable = scalar_fn.output_type().unknown().nullability() != + substrait::Type::NULLABILITY_REQUIRED; + } else { + ARROW_ASSIGN_OR_RAISE( + auto output_type_and_nullable, + FromProto(scalar_fn.output_type(), ext_set, conversion_options)); + output_type = std::move(output_type_and_nullable.first); + output_nullable = output_type_and_nullable.second; + } + SubstraitCall call(id, std::move(output_type), output_nullable); for (int i = 0; i < scalar_fn.arguments_size(); i++) { ARROW_RETURN_NOT_OK( DecodeArg(scalar_fn.arguments(i), i, &call, ext_set, conversion_options)); @@ -296,6 +303,20 @@ Result FromProto(const substrait::Expression& expr, return FromProto(ref, ext_set, conversion_options, std::move(out)); } + case substrait::Expression::kNamedExpression: { + const auto& named_expr = expr.named_expression(); + if (named_expr.names_size() == 0) { + return Status::Invalid( + "substrait::Expression::NamedExpression had no name components"); + } + std::vector refs; + refs.reserve(named_expr.names_size()); + for (const auto& name : named_expr.names()) { + refs.emplace_back(std::string(name)); + } + return compute::field_ref(FieldRef(std::move(refs))); + } + case substrait::Expression::kIfThen: { const auto& if_then = expr.if_then(); if (!if_then.has_else_()) break; @@ -360,7 +381,8 @@ Result FromProto(const substrait::Expression& expr, function_id = NormalizeFunctionName(function_id); ExtensionIdRegistry::SubstraitCallToArrow function_converter; - if (function_id.uri.empty() || function_id.uri[0] == '/') { + if (function_id.uri.empty() || function_id.uri[0] == '/' || + function_id.uri == kSubstraitUnknownFunctionsUri) { // Currently the Substrait project has not aligned on a standard URI and often // seems to use /. In that case we fall back to name-only matching. ARROW_ASSIGN_OR_RAISE( @@ -528,29 +550,34 @@ Result FromProto(const substrait::Expression::Literal& lit, case substrait::Expression::Literal::kBinary: return Datum(BinaryScalar(lit.binary())); - ARROW_SUPPRESS_DEPRECATION_WARNING - case substrait::Expression::Literal::kTimestamp: - return Datum( - TimestampScalar(static_cast(lit.timestamp()), TimeUnit::MICRO)); - - case substrait::Expression::Literal::kTimestampTz: - return Datum(TimestampScalar(static_cast(lit.timestamp_tz()), - TimeUnit::MICRO, TimestampTzTimezoneString())); - ARROW_UNSUPPRESS_DEPRECATION_WARNING case substrait::Expression::Literal::kPrecisionTimestamp: { - // https://github.com/substrait-io/substrait/issues/611 - // TODO(GH-40741) don't break, return precision timestamp - break; + ARROW_ASSIGN_OR_RAISE(std::shared_ptr type, + precision_timestamp(lit.precision_timestamp().precision())); + return Datum(TimestampScalar(lit.precision_timestamp().value(), std::move(type))); } case substrait::Expression::Literal::kPrecisionTimestampTz: { - // https://github.com/substrait-io/substrait/issues/611 - // TODO(GH-40741) don't break, return precision timestamp - break; + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr type, + precision_timestamp_tz(lit.precision_timestamp_tz().precision())); + return Datum( + TimestampScalar(lit.precision_timestamp_tz().value(), std::move(type))); } case substrait::Expression::Literal::kDate: return Datum(Date32Scalar(lit.date())); - case substrait::Expression::Literal::kTime: - return Datum(Time64Scalar(lit.time(), TimeUnit::MICRO)); + case substrait::Expression::Literal::kPrecisionTime: { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr type, + precision_time(lit.precision_time().precision())); + switch (type->id()) { + case Type::TIME32: + return Datum( + Time32Scalar(static_cast(lit.precision_time().value()), type)); + case Type::TIME64: + return Datum(Time64Scalar(lit.precision_time().value(), type)); + default: + return Status::Invalid("Unexpected Arrow type for Substrait precision_time: ", + type->ToString()); + } + } case substrait::Expression::Literal::kIntervalYearToMonth: case substrait::Expression::Literal::kIntervalDayToSecond: { @@ -887,60 +914,68 @@ struct ScalarToProtoImpl { return EncodeUserDefined(*s.type, value); } - Status Visit(const TimestampScalar& s) { + template + Status VisitTimestamp(const TimestampScalar& s, void (Lit::*set_allocated_sub)(Sub*)) { const auto& t = checked_cast(*s.type); - - uint64_t micros; + auto timestamp = std::make_unique(); + timestamp->set_value(s.value); switch (t.unit()) { case TimeUnit::SECOND: - micros = s.value * kMicrosPerSecond; + timestamp->set_precision(0); break; case TimeUnit::MILLI: - micros = s.value * kMicrosPerMilli; + timestamp->set_precision(3); break; case TimeUnit::MICRO: - micros = s.value; + timestamp->set_precision(6); break; case TimeUnit::NANO: - // TODO(GH-40741): can support nanos when - // https://github.com/substrait-io/substrait/issues/611 is resolved - return NotImplemented(s); + timestamp->set_precision(9); + break; default: return NotImplemented(s); } + (lit_->*set_allocated_sub)(timestamp.release()); + return Status::OK(); + } - // Remove these and use precision timestamp once - // https://github.com/substrait-io/substrait/issues/611 is resolved - ARROW_SUPPRESS_DEPRECATION_WARNING - - if (t.timezone() == "") { - lit_->set_timestamp(micros); - } else { - // Some loss of info here, Substrait doesn't store timezone - // in field data - lit_->set_timestamp_tz(micros); + Status Visit(const TimestampScalar& s) { + const auto& t = checked_cast(*s.type); + if (t.timezone().empty()) { + return VisitTimestamp(s, &Lit::set_allocated_precision_timestamp); } - ARROW_UNSUPPRESS_DEPRECATION_WARNING - - return Status::OK(); + return VisitTimestamp(s, &Lit::set_allocated_precision_timestamp_tz); } // Need to support parameterized UDTs - Status Visit(const Time32Scalar& s) { - google::protobuf::Int32Value value; - value.set_value(s.value); - return EncodeUserDefined(*s.type, value); - } - Status Visit(const Time64Scalar& s) { - if (checked_cast(*s.type).unit() == TimeUnit::MICRO) { - return Primitive(&Lit::set_time, s); - } else { - google::protobuf::Int64Value value; - value.set_value(s.value); - return EncodeUserDefined(*s.type, value); + template + Status VisitTime(const ScalarType& s) { + const auto& t = checked_cast(*s.type); + auto time = std::make_unique(); + time->set_value(static_cast(s.value)); + switch (t.unit()) { + case TimeUnit::SECOND: + time->set_precision(0); + break; + case TimeUnit::MILLI: + time->set_precision(3); + break; + case TimeUnit::MICRO: + time->set_precision(6); + break; + case TimeUnit::NANO: + time->set_precision(9); + break; + default: + return NotImplemented(s); } + lit_->set_allocated_precision_time(time.release()); + return Status::OK(); } + Status Visit(const Time32Scalar& s) { return VisitTime(s); } + Status Visit(const Time64Scalar& s) { return VisitTime(s); } + Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); } Status Visit(const DayTimeIntervalScalar& s) { return NotImplemented(s); } diff --git a/cpp/src/arrow/engine/substrait/extended_expression_internal.cc b/cpp/src/arrow/engine/substrait/extended_expression_internal.cc index e2e6d934372d..7301f2f7adcb 100644 --- a/cpp/src/arrow/engine/substrait/extended_expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/extended_expression_internal.cc @@ -43,6 +43,93 @@ Status AddExtensionSetToExtendedExpression(const ExtensionSet& ext_set, return AddExtensionSetToMessage(ext_set, expr); } +bool TypeContainsUnknown(const substrait::Type& type) { + switch (type.kind_case()) { + case substrait::Type::kUnknown: + return true; + case substrait::Type::kStruct: + for (const auto& child_type : type.struct_().types()) { + if (TypeContainsUnknown(child_type)) { + return true; + } + } + return false; + case substrait::Type::kList: + return type.list().has_type() && TypeContainsUnknown(type.list().type()); + case substrait::Type::kMap: + return (type.map().has_key() && TypeContainsUnknown(type.map().key())) || + (type.map().has_value() && TypeContainsUnknown(type.map().value())); + default: + return false; + } +} + +bool NamedStructContainsUnknownTypes(const substrait::NamedStruct& named_struct) { + if (!named_struct.has_struct_()) { + return false; + } + for (const auto& type : named_struct.struct_().types()) { + if (TypeContainsUnknown(type)) { + return true; + } + } + return false; +} + +void CollectDepthFirstFieldNames(const FieldVector& fields, + std::vector* names) { + for (const auto& field : fields) { + names->push_back(field->name()); + if (field->type()->id() == Type::STRUCT) { + CollectDepthFirstFieldNames(field->type()->fields(), names); + } + } +} + +Status ValidateSchemaNames(const substrait::NamedStruct& named_struct, + const Schema& schema) { + std::vector schema_names; + CollectDepthFirstFieldNames(schema.fields(), &schema_names); + if (schema_names.size() != static_cast(named_struct.names_size())) { + return Status::Invalid( + "The supplied Arrow schema did not match the ExtendedExpression base_schema. " + "Expected ", + named_struct.names_size(), " depth-first field names but got ", + schema_names.size()); + } + for (int i = 0; i < named_struct.names_size(); ++i) { + if (schema_names[static_cast(i)] != named_struct.names(i)) { + return Status::Invalid( + "The supplied Arrow schema did not match the ExtendedExpression " + "base_schema. Expected field name ", + named_struct.names(i), " at depth-first position ", i, " but got ", + schema_names[static_cast(i)]); + } + } + return Status::OK(); +} + +Result> ResolveInputSchema( + const substrait::NamedStruct& base_schema, const Schema* input_schema_override, + const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { + if (input_schema_override == NULLPTR) { + return FromProto(base_schema, ext_set, conversion_options); + } + + if (NamedStructContainsUnknownTypes(base_schema)) { + ARROW_RETURN_NOT_OK(ValidateSchemaNames(base_schema, *input_schema_override)); + return std::make_shared(*input_schema_override); + } + + ARROW_ASSIGN_OR_RAISE(auto decoded_schema, + FromProto(base_schema, ext_set, conversion_options)); + if (!decoded_schema->Equals(*input_schema_override, /*check_metadata=*/false)) { + return Status::Invalid( + "The supplied Arrow schema did not match the ExtendedExpression base_schema"); + } + return std::make_shared(*input_schema_override); +} + Status VisitNestedFields(const DataType& type, std::function visitor) { if (!is_nested(type.id())) { @@ -149,6 +236,7 @@ Result> CreateExpressionReferenc } // namespace Result FromProto(const substrait::ExtendedExpression& expression, + const Schema* input_schema_override, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options, const ExtensionIdRegistry* registry) { @@ -162,8 +250,10 @@ Result FromProto(const substrait::ExtendedExpression& expressi ExtensionSet ext_set, GetExtensionSetFromExtendedExpression(expression, conversion_options, registry)); - ARROW_ASSIGN_OR_RAISE(bound_expressions.schema, - FromProto(expression.base_schema(), ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + bound_expressions.schema, + ResolveInputSchema(expression.base_schema(), input_schema_override, ext_set, + conversion_options)); bound_expressions.named_expressions.reserve(expression.referred_expr_size()); diff --git a/cpp/src/arrow/engine/substrait/extended_expression_internal.h b/cpp/src/arrow/engine/substrait/extended_expression_internal.h index 45f89c8610b5..5219cc8de836 100644 --- a/cpp/src/arrow/engine/substrait/extended_expression_internal.h +++ b/cpp/src/arrow/engine/substrait/extended_expression_internal.h @@ -41,6 +41,7 @@ namespace engine { /// Convert a Substrait ExtendedExpression to a vector of expressions and output names ARROW_ENGINE_EXPORT Result FromProto(const substrait::ExtendedExpression& expression, + const Schema* input_schema_override, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options, const ExtensionIdRegistry* extension_id_registry); diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 4f631e0f193d..1bb9eb4bcbd9 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -72,6 +72,7 @@ constexpr const char* kSubstraitAggregateGenericFunctionsUri = /// and any options are ignored. constexpr const char* kArrowSimpleExtensionFunctionsUri = "urn:arrow:substrait_simple_extension_function"; +constexpr const char* kSubstraitUnknownFunctionsUri = "extension:io.substrait:unknown"; struct ARROW_ENGINE_EXPORT Id { std::string_view uri, name; diff --git a/cpp/src/arrow/engine/substrait/extension_types.cc b/cpp/src/arrow/engine/substrait/extension_types.cc index f71b5f7185d0..6bd4fab4dffd 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.cc +++ b/cpp/src/arrow/engine/substrait/extension_types.cc @@ -114,6 +114,21 @@ std::shared_ptr interval_year() { return IntervalYearType::Make({}); } std::shared_ptr interval_day() { return IntervalDayType::Make({}); } +Result> precision_time(int precision) { + switch (precision) { + case 0: + return time32(TimeUnit::SECOND); + case 3: + return time32(TimeUnit::MILLI); + case 6: + return time64(TimeUnit::MICRO); + case 9: + return time64(TimeUnit::NANO); + default: + return Status::NotImplemented("Unrecognized time precision (", precision, ")"); + } +} + Result> precision_timestamp(int precision) { switch (precision) { case 0: diff --git a/cpp/src/arrow/engine/substrait/extension_types.h b/cpp/src/arrow/engine/substrait/extension_types.h index ae71ad83f7e5..2e38823e4913 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.h +++ b/cpp/src/arrow/engine/substrait/extension_types.h @@ -56,6 +56,10 @@ std::shared_ptr interval_year(); ARROW_ENGINE_EXPORT std::shared_ptr interval_day(); +/// constructs the appropriate time type given the precision +ARROW_ENGINE_EXPORT +Result> precision_time(int precision); + /// constructs the appropriate timestamp type given the precision /// no time zone ARROW_ENGINE_EXPORT diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index b9e663ed7b11..fd934d720290 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -691,12 +691,18 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& case substrait::JoinRel::JOIN_TYPE_RIGHT: join_type = acero::JoinType::RIGHT_OUTER; break; - case substrait::JoinRel::JOIN_TYPE_SEMI: + case substrait::JoinRel::JOIN_TYPE_LEFT_SEMI: join_type = acero::JoinType::LEFT_SEMI; break; - case substrait::JoinRel::JOIN_TYPE_ANTI: + case substrait::JoinRel::JOIN_TYPE_RIGHT_SEMI: + join_type = acero::JoinType::RIGHT_SEMI; + break; + case substrait::JoinRel::JOIN_TYPE_LEFT_ANTI: join_type = acero::JoinType::LEFT_ANTI; break; + case substrait::JoinRel::JOIN_TYPE_RIGHT_ANTI: + join_type = acero::JoinType::RIGHT_ANTI; + break; default: return Status::Invalid("Unsupported join type"); } @@ -867,12 +873,17 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::vector keys; if (aggregate.groupings_size() > 0) { const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); - int grouping_expr_size = group.grouping_expressions_size(); + int grouping_expr_size = group.expression_references_size(); keys.reserve(grouping_expr_size); for (int exp_id = 0; exp_id < grouping_expr_size; exp_id++) { - ARROW_ASSIGN_OR_RAISE( - compute::Expression expr, - FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options)); + uint32_t ref = group.expression_references(exp_id); + if (ref >= static_cast(aggregate.grouping_expressions_size())) { + return Status::Invalid("Aggregate grouping expression reference ", ref, + " was out of range"); + } + ARROW_ASSIGN_OR_RAISE(compute::Expression expr, + FromProto(aggregate.grouping_expressions(ref), ext_set, + conversion_options)); const FieldRef* field_ref = expr.field_ref(); if (field_ref) { keys.emplace_back(std::move(*field_ref)); diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 5ce97cb0ccfd..684e405a2454 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -248,7 +248,17 @@ Result DeserializeExpressions( const ConversionOptions& conversion_options, ExtensionSet* ext_set_out) { ARROW_ASSIGN_OR_RAISE(auto extended_expression, ParseFromBuffer(buf)); - return FromProto(extended_expression, ext_set_out, conversion_options, registry); + return FromProto(extended_expression, /*input_schema_override=*/NULLPTR, ext_set_out, + conversion_options, registry); +} + +Result DeserializeExpressions( + const Buffer& buf, const Schema& input_schema, const ExtensionIdRegistry* registry, + const ConversionOptions& conversion_options, ExtensionSet* ext_set_out) { + ARROW_ASSIGN_OR_RAISE(auto extended_expression, + ParseFromBuffer(buf)); + return FromProto(extended_expression, &input_schema, ext_set_out, conversion_options, + registry); } namespace { diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index ab749f4a64b0..88fda211bda0 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -183,6 +183,27 @@ ARROW_ENGINE_EXPORT Result DeserializeExpressions( const ConversionOptions& conversion_options = {}, ExtensionSet* ext_set_out = NULLPTR); +/// \brief Deserialize a Substrait ExtendedExpression message using a supplied schema +/// +/// This overload is intended for partially bound Substrait expressions whose +/// `base_schema` uses unresolved types or whose expressions use unresolved +/// names. The supplied Arrow schema is used to bind the deserialized expressions +/// before they are returned. +/// +/// \param[in] buf a buffer containing the protobuf serialization of a collection of +/// expressions +/// \param[in] input_schema the Arrow schema to bind unresolved expressions against +/// \param[in] registry an extension-id-registry to use, or null for the default one +/// \param[in] conversion_options options to control how the conversion is done +/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait +/// message is returned here. +/// \return A collection of expressions bound to the supplied schema +ARROW_ENGINE_EXPORT Result DeserializeExpressions( + const Buffer& buf, const Schema& input_schema, + const ExtensionIdRegistry* registry = NULLPTR, + const ConversionOptions& conversion_options = {}, + ExtensionSet* ext_set_out = NULLPTR); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 138d03b24791..fe149326483b 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -267,10 +267,11 @@ TEST(Substrait, SupportedTypes) { ExpectEq(R"({"string": {}})", utf8()); ExpectEq(R"({"binary": {}})", binary()); - ExpectEq(R"({"timestamp": {}})", timestamp(TimeUnit::MICRO)); + ExpectEq(R"({"precision_timestamp": {"precision": 6}})", timestamp(TimeUnit::MICRO)); ExpectEq(R"({"date": {}})", date32()); - ExpectEq(R"({"time": {}})", time64(TimeUnit::MICRO)); - ExpectEq(R"({"timestamp_tz": {}})", timestamp(TimeUnit::MICRO, "UTC")); + ExpectEq(R"({"precision_time": {"precision": 6}})", time64(TimeUnit::MICRO)); + ExpectEq(R"({"precision_timestamp_tz": {"precision": 6}})", + timestamp(TimeUnit::MICRO, "UTC")); ExpectEq(R"({"interval_year": {}})", interval_year()); ExpectEq(R"({"interval_day": {}})", interval_day()); @@ -532,11 +533,13 @@ TEST(Substrait, SupportedLiterals) { ExpectEq(R"({"binary": "enp6"})", BinaryScalar(Buffer::FromString("zzz"))); - ExpectEq(R"({"timestamp": "579"})", TimestampScalar(579, TimeUnit::MICRO)); + ExpectEq(R"({"precision_timestamp": {"precision": 6, "value": "579"}})", + TimestampScalar(579, TimeUnit::MICRO)); ExpectEq(R"({"date": "5"})", Date32Scalar(5)); - ExpectEq(R"({"time": "64"})", Time64Scalar(64, TimeUnit::MICRO)); + ExpectEq(R"({"precision_time": {"precision": 6, "value": "64"}})", + Time64Scalar(64, TimeUnit::MICRO)); ExpectEq(R"({"interval_year_to_month": {"years": 34, "months": 3}})", ExtensionScalar(FixedSizeListScalar(ArrayFromJSON(int32(), "[34, 3]")), @@ -561,7 +564,8 @@ TEST(Substrait, SupportedLiterals) { R"({"decimal": {"value": "0gKWSQAAAAAAAAAAAAAAAA==", "precision": 27, "scale": 5}})", Decimal128Scalar(Decimal128("123456789.0"), decimal128(27, 5))); - ExpectEq(R"({"timestamp_tz": "579"})", TimestampScalar(579, TimeUnit::MICRO, "UTC")); + ExpectEq(R"({"precision_timestamp_tz": {"precision": 6, "value": "579"}})", + TimestampScalar(579, TimeUnit::MICRO, "UTC")); // special case for empty lists ExpectEq(R"({"empty_list": {"type": {"i32": {}}}})", @@ -1197,26 +1201,26 @@ TEST(Substrait, ExtensionSetFromPlan) { } }} ], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 7, - "uri": ")" + default_extension_types_uri() + + "extensionUrnAnchor": 7, + "urn": ")" + default_extension_types_uri() + R"(" }, { - "extension_uri_anchor": 18, - "uri": ")" + kSubstraitArithmeticFunctionsUri + + "extensionUrnAnchor": 18, + "urn": ")" + kSubstraitArithmeticFunctionsUri + R"(" } ], "extensions": [ {"extension_type": { - "extension_uri_reference": 7, + "extensionUrnReference": 7, "type_anchor": 42, "name": "null" }}, {"extension_function": { - "extension_uri_reference": 18, + "extensionUrnReference": 18, "function_anchor": 42, "name": "add" }} @@ -1247,16 +1251,16 @@ TEST(Substrait, ExtensionSetFromPlan) { TEST(Substrait, ExtensionSetFromPlanMissingFunc) { std::string substrait_json = R"({ "relations": [], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 7, - "uri": ")" + default_extension_types_uri() + + "extensionUrnAnchor": 7, + "urn": ")" + default_extension_types_uri() + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 7, + "extensionUrnReference": 7, "function_anchor": 42, "name": "does_not_exist" }} @@ -1295,16 +1299,16 @@ TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) { } }} ], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 7, - "uri": ")" + default_extension_types_uri() + + "extensionUrnAnchor": 7, + "urn": ")" + default_extension_types_uri() + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 7, + "extensionUrnReference": 7, "function_anchor": 42, "name": "add" }} @@ -1335,16 +1339,16 @@ TEST(Substrait, ExtensionSetFromPlanRegisterFunc) { std::string substrait_json = R"({ "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 7, - "uri": ")" + default_extension_types_uri() + + "extensionUrnAnchor": 7, + "urn": ")" + default_extension_types_uri() + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 7, + "extensionUrnReference": 7, "function_anchor": 42, "name": "new_func" }} @@ -1545,7 +1549,7 @@ TEST(Substrait, InvalidMinimumVersion) { } } }], - "extensionUris": [], + "extensionUrns": [], "extensions": [], })")); @@ -1643,16 +1647,16 @@ TEST(Substrait, JoinPlanBasic) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" }} @@ -1789,16 +1793,16 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitArithmeticFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitArithmeticFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "add" }} @@ -1992,8 +1996,7 @@ TEST(Substrait, AggregateBasic) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -2001,7 +2004,9 @@ TEST(Substrait, AggregateBasic) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ "measure": { @@ -2027,13 +2032,13 @@ TEST(Substrait, AggregateBasic) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "sum" } @@ -2066,13 +2071,13 @@ TEST(Substrait, AggregateInvalidRel) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "sum" } @@ -2114,8 +2119,7 @@ TEST(Substrait, AggregateInvalidFunction) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -2123,20 +2127,22 @@ TEST(Substrait, AggregateInvalidFunction) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ }] } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "sum" } @@ -2178,8 +2184,7 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -2187,7 +2192,9 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ "measure": { @@ -2204,13 +2211,13 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "sum" } @@ -2252,8 +2259,7 @@ TEST(Substrait, AggregateWithFilter) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -2261,7 +2267,9 @@ TEST(Substrait, AggregateWithFilter) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ "measure": { @@ -2278,13 +2286,13 @@ TEST(Substrait, AggregateWithFilter) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" } @@ -2326,8 +2334,7 @@ TEST(Substrait, AggregateBadPhase) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -2335,7 +2342,9 @@ TEST(Substrait, AggregateBadPhase) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ "measure": { @@ -2352,13 +2361,13 @@ TEST(Substrait, AggregateBadPhase) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/apache/arrow/blob/main/format/substrait/extension_types.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" } @@ -2719,16 +2728,16 @@ TEST(SubstraitRoundTrip, ProjectRel) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" }} @@ -2838,16 +2847,16 @@ TEST(SubstraitRoundTrip, ProjectRelOnFunctionWithEmit) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" }} @@ -3020,16 +3029,16 @@ TEST(SubstraitRoundTrip, ProjectRelOnFunctionWithAllEmit) { } } ], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" }} @@ -3200,16 +3209,16 @@ TEST(SubstraitRoundTrip, FilterRelWithEmit) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" }} @@ -3329,16 +3338,16 @@ TEST(SubstraitRoundTrip, JoinRel) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" }} @@ -3483,16 +3492,16 @@ TEST(SubstraitRoundTrip, JoinRelWithEmit) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 0, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 0, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], "extensions": [ {"extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "equal" }} @@ -3574,8 +3583,7 @@ TEST(SubstraitRoundTrip, AggregateRel) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -3583,7 +3591,9 @@ TEST(SubstraitRoundTrip, AggregateRel) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ "measure": { @@ -3610,13 +3620,13 @@ TEST(SubstraitRoundTrip, AggregateRel) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "sum" } @@ -3681,8 +3691,7 @@ TEST(SubstraitRoundTrip, AggregateRelOptions) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -3690,7 +3699,9 @@ TEST(SubstraitRoundTrip, AggregateRelOptions) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ "measure": { @@ -3723,13 +3734,13 @@ TEST(SubstraitRoundTrip, AggregateRelOptions) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "variance" } @@ -3800,8 +3811,7 @@ TEST(SubstraitRoundTrip, AggregateRelEmit) { } } }, - "groupings": [{ - "groupingExpressions": [{ + "groupingExpressions": [{ "selection": { "directReference": { "structField": { @@ -3809,7 +3819,9 @@ TEST(SubstraitRoundTrip, AggregateRelEmit) { } } } - }] + }], + "groupings": [{ + "expressionReferences": [0] }], "measures": [{ "measure": { @@ -3836,13 +3848,13 @@ TEST(SubstraitRoundTrip, AggregateRelEmit) { } } }], - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "sum" } @@ -3874,14 +3886,14 @@ TEST(Substrait, IsthmusPlan) { // isthmus -c "CREATE TABLE T1(foo int)" "SELECT foo + 1 FROM T1" std::string substrait_json = R"({ "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, - "extensionUris": [{ - "extensionUriAnchor": 1, - "uri": "/functions_arithmetic.yaml" + "extensionUrns": [{ + "extension_urn_anchor": 1, + "urn": "/functions_arithmetic.yaml" }], "extensions": [{ - "extensionFunction": { - "extensionUriReference": 1, - "functionAnchor": 0, + "extension_function": { + "extension_urn_reference": 1, + "function_anchor": 0, "name": "add:i32_i32" } }], @@ -3996,13 +4008,13 @@ TEST(Substrait, ProjectWithMultiFieldExpressions) { ])"}); const std::string substrait_json = R"({ - "extensionUris": [{ - "extensionUriAnchor": 1, - "uri": "/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 1, + "urn": "/functions_arithmetic.yaml" }], "extensions": [{ "extensionFunction": { - "extensionUriReference": 1, + "extensionUrnReference": 1, "functionAnchor": 0, "name": "add:i32_i32" } @@ -4148,16 +4160,16 @@ TEST(Substrait, NestedProjectWithMultiFieldExpressions) { ])"}); const std::string substrait_json = R"({ - "extensionUris": [ + "extensionUrns": [ { - "extensionUriAnchor": 1, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrnAnchor": 1, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" } ], "extensions": [ { "extensionFunction": { - "extensionUriReference": 1, + "extensionUrnReference": 1, "functionAnchor": 2, "name": "add" } @@ -4234,16 +4246,16 @@ TEST(Substrait, NestedEmitProjectWithMultiFieldExpressions) { ])"}); const std::string substrait_json = R"({ - "extensionUris": [ + "extensionUrns": [ { - "extensionUriAnchor": 1, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrnAnchor": 1, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" } ], "extensions": [ { "extensionFunction": { - "extensionUriReference": 1, + "extensionUrnReference": 1, "functionAnchor": 2, "name": "add" } @@ -4640,7 +4652,7 @@ TEST(Substrait, PlanWithAsOfJoinExtension) { #endif // This demos an extension relation std::string substrait_json = R"({ - "extensionUris": [], + "extensionUrns": [], "extensions": [], "relations": [{ "root": { @@ -5038,29 +5050,29 @@ TEST(Substrait, CompoundEmitFilterless) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 42, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 42, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" }, { - "extension_uri_anchor": 72, - "uri": ")" + std::string(kSubstraitArithmeticFunctionsUri) + + "extensionUrnAnchor": 72, + "urn": ")" + std::string(kSubstraitArithmeticFunctionsUri) + R"(" } ], "extensions": [ { "extension_function": { - "extension_uri_reference": 42, + "extensionUrnReference": 42, "function_anchor": 14, "name": "equal" } }, { "extension_function": { - "extension_uri_reference": 72, + "extensionUrnReference": 72, "function_anchor": 32, "name": "add" } @@ -5363,36 +5375,36 @@ TEST(Substrait, CompoundEmitWithFilter) { } } }], - "extension_uris": [ + "extensionUrns": [ { - "extension_uri_anchor": 42, - "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + + "extensionUrnAnchor": 42, + "urn": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" }, { - "extension_uri_anchor": 72, - "uri": ")" + std::string(kSubstraitArithmeticFunctionsUri) + + "extensionUrnAnchor": 72, + "urn": ")" + std::string(kSubstraitArithmeticFunctionsUri) + R"(" } ], "extensions": [ { "extension_function": { - "extension_uri_reference": 42, + "extensionUrnReference": 42, "function_anchor": 14, "name": "equal" } }, { "extension_function": { - "extension_uri_reference": 42, + "extensionUrnReference": 42, "function_anchor": 25, "name": "lt" } }, { "extension_function": { - "extension_uri_reference": 72, + "extensionUrnReference": 72, "function_anchor": 32, "name": "add" } @@ -5518,7 +5530,7 @@ TEST(Substrait, SortAndFetch) { } } ], - "extension_uris": [], + "extensionUrns": [], "extensions": [] })"; @@ -5623,7 +5635,7 @@ TEST(Substrait, MixedSort) { } } ], - "extension_uris": [], + "extensionUrns": [], "extensions": [] })"; @@ -5663,7 +5675,7 @@ TEST(Substrait, PlanWithExtension) { // This demos an extension relation std::string substrait_json = R"({ - "extensionUris": [], + "extensionUrns": [], "extensions": [], "relations": [{ "root": { @@ -5853,7 +5865,7 @@ TEST(Substrait, AsOfJoinDefaultEmit) { GTEST_SKIP() << "ASOF join requires threading"; #endif std::string substrait_json = R"({ - "extensionUris": [], + "extensionUrns": [], "extensions": [], "relations": [{ "root": { @@ -6035,7 +6047,7 @@ TEST(Substrait, AsOfJoinDefaultEmit) { TEST(Substrait, PlanWithNamedTapExtension) { // This demos an extension relation std::string substrait_json = R"({ - "extensionUris": [], + "extensionUrns": [], "extensions": [], "relations": [{ "root": { @@ -6122,13 +6134,13 @@ TEST(Substrait, PlanWithNamedTapExtension) { TEST(Substrait, PlanWithSegmentedAggregateExtension) { // This demos an extension relation std::string substrait_json = R"({ - "extensionUris": [{ - "extension_uri_anchor": 0, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + "extensionUrns": [{ + "extensionUrnAnchor": 0, + "urn": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { - "extension_uri_reference": 0, + "extensionUrnReference": 0, "function_anchor": 0, "name": "sum" } @@ -6341,5 +6353,164 @@ TEST(Substrait, ExtendedExpressionInvalidPlans) { Raises(StatusCode::Invalid, testing::HasSubstr("Ambiguous plan"))); } +TEST(Substrait, ExtendedExpressionDeserializeUnboundWithSchema) { + constexpr std::string_view kUnboundNamedAdd = R"( + { + "version": {"majorNumber": 9999}, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:unknown" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "functionAnchor": 1, + "name": "add:unknown_unknown" + } + } + ], + "referredExpr": [ + { + "expression": { + "scalarFunction": { + "functionReference": 1, + "arguments": [ + { + "value": { + "namedExpression": { + "names": ["a"] + } + } + }, + { + "value": { + "namedExpression": { + "names": ["b"] + } + } + } + ], + "outputType": { + "unknown": {} + } + } + }, + "outputNames": ["sum"] + } + ], + "baseSchema": { + "names": ["a", "b"], + "struct": { + "types": [ + { + "unknown": {} + }, + { + "unknown": {} + } + ] + } + } + } + )"; + + ASSERT_OK_AND_ASSIGN(auto buf, + internal::SubstraitFromJSON("ExtendedExpression", kUnboundNamedAdd, + /*ignore_unknown_fields=*/false)); + + auto binding_schema = schema({field("a", int32()), field("b", int32())}); + ASSERT_OK_AND_ASSIGN( + auto expected_expression, + compute::call("add", {compute::field_ref("a"), compute::field_ref("b")}) + .Bind(*binding_schema)); + + ASSERT_OK_AND_ASSIGN(auto bound_expressions, + DeserializeExpressions(*buf, *binding_schema)); + AssertSchemaEqual(*binding_schema, *bound_expressions.schema); + ASSERT_EQ(1, bound_expressions.named_expressions.size()); + EXPECT_EQ("sum", bound_expressions.named_expressions[0].name); + EXPECT_EQ(expected_expression, bound_expressions.named_expressions[0].expression); + + ASSERT_THAT(DeserializeExpressions(*buf), + Raises(StatusCode::Invalid, testing::HasSubstr("unknown"))); +} + +TEST(Substrait, ExtendedExpressionDeserializeUnboundSchemaMismatch) { + constexpr std::string_view kUnboundNamedAdd = R"( + { + "version": {"majorNumber": 9999}, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:unknown" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "functionAnchor": 1, + "name": "add:unknown_unknown" + } + } + ], + "referredExpr": [ + { + "expression": { + "scalarFunction": { + "functionReference": 1, + "arguments": [ + { + "value": { + "namedExpression": { + "names": ["a"] + } + } + }, + { + "value": { + "namedExpression": { + "names": ["b"] + } + } + } + ], + "outputType": { + "unknown": {} + } + } + }, + "outputNames": ["sum"] + } + ], + "baseSchema": { + "names": ["a", "b"], + "struct": { + "types": [ + { + "unknown": {} + }, + { + "unknown": {} + } + ] + } + } + } + )"; + + ASSERT_OK_AND_ASSIGN(auto buf, + internal::SubstraitFromJSON("ExtendedExpression", kUnboundNamedAdd, + /*ignore_unknown_fields=*/false)); + + auto mismatched_schema = schema({field("x", int32()), field("b", int32())}); + ASSERT_THAT(DeserializeExpressions(*buf, *mismatched_schema), + Raises(StatusCode::Invalid, + testing::HasSubstr("did not match the ExtendedExpression"))); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc index a8302145f548..d3f9f15a724c 100644 --- a/cpp/src/arrow/engine/substrait/test_plan_builder.cc +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc @@ -131,9 +131,11 @@ Result> CreateAgg(Id function_id, if (!keys.empty()) { substrait::AggregateRel::Grouping* grouping = agg->add_groupings(); + uint32_t grouping_ref = 0; for (int key : keys) { - substrait::Expression* key_expr = grouping->add_grouping_expressions(); + substrait::Expression* key_expr = agg->add_grouping_expressions(); CreateDirectReference(key, key_expr); + grouping->add_expression_references(grouping_ref++); } } diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index 3e8c0dda765b..8aed3b934fef 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -126,13 +126,11 @@ Result, bool>> FromProto( case substrait::Type::kBinary: return FromProtoImpl(type.binary()); - ARROW_SUPPRESS_DEPRECATION_WARNING - case substrait::Type::kTimestamp: - return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); - case substrait::Type::kTimestampTz: - return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, - TimestampTzTimezoneString()); - ARROW_UNSUPPRESS_DEPRECATION_WARNING + case substrait::Type::kPrecisionTime: { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr time_type, + precision_time(type.precision_time().precision())); + return std::make_pair(time_type, IsNullable(type.precision_time())); + } case substrait::Type::kPrecisionTimestamp: { ARROW_ASSIGN_OR_RAISE(std::shared_ptr ts_type, precision_timestamp(type.precision_timestamp().precision())); @@ -147,9 +145,6 @@ Result, bool>> FromProto( case substrait::Type::kDate: return FromProtoImpl(type.date()); - case substrait::Type::kTime: - return FromProtoImpl(type.time(), TimeUnit::MICRO); - case substrait::Type::kIntervalYear: return FromProtoImpl(type.interval_year(), interval_year); @@ -233,6 +228,11 @@ Result, bool>> FromProto( return std::make_pair(std::move(type_record.type), IsNullable(user_defined)); } + case substrait::Type::kUnknown: + return Status::Invalid( + "Substrait type 'unknown' cannot be deserialized to an Arrow type " + "without binding to a concrete schema"); + default: break; } @@ -335,15 +335,31 @@ struct DataTypeToProtoImpl { } } - Status Visit(const Time32Type& t) { return EncodeUserDefined(t); } - Status Visit(const Time64Type& t) { - if (t.unit() == TimeUnit::MICRO) { - return SetWith(&substrait::Type::set_allocated_time); - } else { - return EncodeUserDefined(t); + template + Status VisitTime(const TimeType& t) { + auto time = SetWithThen(&substrait::Type::set_allocated_precision_time); + switch (t.unit()) { + case TimeUnit::SECOND: + time->set_precision(0); + break; + case TimeUnit::MILLI: + time->set_precision(3); + break; + case TimeUnit::MICRO: + time->set_precision(6); + break; + case TimeUnit::NANO: + time->set_precision(9); + break; + default: + return NotImplemented(t); } + return Status::OK(); } + Status Visit(const Time32Type& t) { return VisitTime(t); } + Status Visit(const Time64Type& t) { return VisitTime(t); } + Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); } Status Visit(const DayTimeIntervalType& t) { return EncodeUserDefined(t); } diff --git a/cpp/src/arrow/engine/substrait/util_internal.h b/cpp/src/arrow/engine/substrait/util_internal.h index d812bbf7b85f..bc8950ba6e72 100644 --- a/cpp/src/arrow/engine/substrait/util_internal.h +++ b/cpp/src/arrow/engine/substrait/util_internal.h @@ -49,9 +49,9 @@ Result GetExtensionSetFromMessage( registry = default_extension_id_registry(); } std::unordered_map uris; - uris.reserve(message.extension_uris_size()); - for (const auto& uri : message.extension_uris()) { - uris[uri.extension_uri_anchor()] = uri.uri(); + uris.reserve(message.extension_urns_size()); + for (const auto& uri : message.extension_urns()) { + uris[uri.extension_urn_anchor()] = uri.urn(); } // NOTE: it's acceptable to use views to memory owned by message; ExtensionSet::Make @@ -66,14 +66,14 @@ Result GetExtensionSetFromMessage( case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: { const auto& type = ext.extension_type(); - std::string_view uri = uris[type.extension_uri_reference()]; + std::string_view uri = uris[type.extension_urn_reference()]; type_ids[type.type_anchor()] = Id{uri, type.name()}; break; } case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: { const auto& fn = ext.extension_function(); - std::string_view uri = uris[fn.extension_uri_reference()]; + std::string_view uri = uris[fn.extension_urn_reference()]; function_ids[fn.function_anchor()] = Id{uri, fn.name()}; break; } @@ -89,19 +89,19 @@ Result GetExtensionSetFromMessage( template Status AddExtensionSetToMessage(const ExtensionSet& ext_set, Message* message) { - message->clear_extension_uris(); + message->clear_extension_urns(); std::unordered_map map; - auto uris = message->mutable_extension_uris(); + auto uris = message->mutable_extension_urns(); uris->Reserve(static_cast(ext_set.uris().size())); for (uint32_t anchor = 0; anchor < ext_set.uris().size(); ++anchor) { auto uri = ext_set.uris().at(anchor); if (uri.empty()) continue; - auto ext_uri = std::make_unique(); - ext_uri->set_uri(std::string(uri)); - ext_uri->set_extension_uri_anchor(anchor); + auto ext_uri = std::make_unique(); + ext_uri->set_urn(std::string(uri)); + ext_uri->set_extension_urn_anchor(anchor); uris->AddAllocated(ext_uri.release()); map[uri] = anchor; @@ -119,7 +119,7 @@ Status AddExtensionSetToMessage(const ExtensionSet& ext_set, Message* message) { auto ext_decl = std::make_unique(); auto type = std::make_unique(); - type->set_extension_uri_reference(map[type_record.id.uri]); + type->set_extension_urn_reference(map[type_record.id.uri]); type->set_type_anchor(anchor); type->set_name(std::string(type_record.id.name)); ext_decl->set_allocated_extension_type(type.release()); @@ -130,7 +130,7 @@ Status AddExtensionSetToMessage(const ExtensionSet& ext_set, Message* message) { ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(anchor)); auto fn = std::make_unique(); - fn->set_extension_uri_reference(map[function_id.uri]); + fn->set_extension_urn_reference(map[function_id.uri]); fn->set_function_anchor(anchor); fn->set_name(std::string(function_id.name)); diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 137b034d6ffc..323bcdebec16 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2647,7 +2647,7 @@ cdef class Expression(_Weakrefable): return f"" @staticmethod - def from_substrait(object message not None): + def from_substrait(object message not None, schema=None): """ Deserialize an expression from Substrait @@ -2661,13 +2661,17 @@ cdef class Expression(_Weakrefable): ---------- message : bytes or Buffer or a protobuf Message The Substrait message to deserialize + schema : Schema, optional + The input schema to use when the Substrait message contains + unresolved field names or unknown types. Returns ------- Expression The deserialized expression """ - expressions = _pas().BoundExpressions.from_substrait(message).expressions + expressions = _pas().BoundExpressions.from_substrait( + message, schema=schema).expressions if len(expressions) == 0: raise ValueError("Substrait message did not contain any expressions") if len(expressions) > 1: diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index d9359c8e77d0..132effef9a9f 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -30,8 +30,17 @@ try: import substrait as py_substrait except ImportError: py_substrait = None + py_substrait_proto = None + py_substrait_extended_expression_pb2 = None else: - import substrait.proto # no-cython-lint + try: + import substrait.proto as py_substrait_proto # no-cython-lint + except ImportError: + py_substrait_proto = None + try: + from substrait import extended_expression_pb2 as py_substrait_extended_expression_pb2 # no-cython-lint + except ImportError: + py_substrait_extended_expression_pb2 = None # TODO GH-37235: Fix exception handling @@ -213,9 +222,13 @@ class SubstraitSchema: def to_pysubstrait(self): """Convert the schema to a substrait-python ExtendedExpression object.""" - if py_substrait is None: - raise ImportError("The 'substrait' package is required.") - return py_substrait.proto.ExtendedExpression.FromString(self.expression) + if py_substrait_proto is not None: + return py_substrait_proto.ExtendedExpression.FromString(self.expression) + if py_substrait_extended_expression_pb2 is not None: + return py_substrait_extended_expression_pb2.ExtendedExpression.FromString( + self.expression) + raise ImportError( + "The 'substrait' package or generated protobuf modules are required.") def serialize_schema(schema): @@ -397,7 +410,7 @@ cdef class BoundExpressions(_Weakrefable): return self @classmethod - def from_substrait(cls, message): + def from_substrait(cls, message, schema=None): """ Convert a Substrait message into a BoundExpressions object @@ -405,6 +418,9 @@ cdef class BoundExpressions(_Weakrefable): ---------- message : Buffer or bytes or protobuf Message The message to convert to a BoundExpressions object + schema : Schema, optional + The input schema to use when the Substrait message contains + unresolved field names or unknown types. Returns ------- @@ -412,18 +428,19 @@ cdef class BoundExpressions(_Weakrefable): The converted expressions, their names, and the bound schema """ if isinstance(message, (bytes, memoryview)): - return deserialize_expressions(message) + return deserialize_expressions(message, schema=schema) elif isinstance(message, Buffer): - return deserialize_expressions(message) + return deserialize_expressions(message, schema=schema) else: try: - return deserialize_expressions(message.SerializeToString()) + return deserialize_expressions( + message.SerializeToString(), schema=schema) except AttributeError: raise TypeError( f"Expected 'pyarrow.Buffer' or bytes or protobuf Message, got '{type(message)}'") -def deserialize_expressions(buf): +def deserialize_expressions(buf, schema=None): """ Deserialize an ExtendedExpression Substrait message into a BoundExpressions object @@ -431,6 +448,9 @@ def deserialize_expressions(buf): ---------- buf : Buffer or bytes The message to deserialize + schema : Schema, optional + The input schema to use when the Substrait message contains + unresolved field names or unknown types. Returns ------- @@ -450,9 +470,18 @@ def deserialize_expressions(buf): raise TypeError( f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'") - with nogil: - c_res_bound_exprs = DeserializeExpressions(deref(c_buffer)) - c_bound_exprs = GetResultValue(c_res_bound_exprs) + if schema is None: + with nogil: + c_res_bound_exprs = DeserializeExpressions(deref(c_buffer)) + c_bound_exprs = GetResultValue(c_res_bound_exprs) + else: + if not isinstance(schema, Schema): + raise TypeError( + f"Expected 'pyarrow.Schema' or None, got '{type(schema)}'") + with nogil: + c_res_bound_exprs = DeserializeExpressions( + deref(c_buffer), deref(( schema).sp_schema)) + c_bound_exprs = GetResultValue(c_res_bound_exprs) return BoundExpressions.wrap(c_bound_exprs) diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 865568e2ba6f..9a28bcc0fe68 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -82,6 +82,9 @@ cdef extern from "arrow/engine/substrait/serde.h" namespace "arrow::engine" nogi CResult[CBoundExpressions] DeserializeExpressions( const CBuffer& serialized_expressions) + CResult[CBoundExpressions] DeserializeExpressions( + const CBuffer& serialized_expressions, const CSchema& schema) + CResult[shared_ptr[CBuffer]] SerializeSchema( const CSchema &schema, CExtensionSet* extension_set, const CConversionOptions& conversion_options) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index fcd1c8d48c5f..ddea0eb37157 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -22,6 +22,7 @@ import pyarrow as pa import pyarrow.compute as pc +import pyarrow.dataset as ds from pyarrow.lib import tobytes from pyarrow.lib import ArrowInvalid, ArrowNotImplementedError @@ -34,6 +35,21 @@ # Ignore these with pytest ... -m 'not substrait' pytestmark = pytest.mark.substrait +UNBOUND_NAMED_ADD = ( + b'\x12\x1b\x1a\x19\x10\x01\x1a\x13add:unknown_unknown \x01\x1a$' + b'\n\x1d\x1a\x1b\x08\x01\x1a\x03\xba\x02\x00"\x08\x1a\x06\x92\x01' + b'\x03\n\x01a"\x08\x1a\x06\x92\x01\x03\n\x01b\x1a\x03sum"\x14\n\x01a' + b'\n\x01b\x12\x0c\n\x03\xba\x02\x00\n\x03\xba\x02\x00\x18\x02:\x0b*' + b'\tsubstraitB"\x08\x01\x12\x1eextension:io.substrait:unknown' +) +UNBOUND_NAMED_GT_ONE = ( + b'\x12\x16\x1a\x14\x10\x01\x1a\x0egt:unknown_i32 \x01\x1a%' + b'\n\x1b\x1a\x19\x08\x01\x1a\x03\xba\x02\x00"\x08\x1a\x06\x92\x01\x03' + b'\n\x01a"\x06\x1a\x04\n\x02(\x01\x1a\x06filter"\x0c\n\x01a\x12\x07' + b'\n\x03\xba\x02\x00\x18\x02:\x0b*\tsubstraitB"\x08\x01\x12\x1e' + b'extension:io.substrait:unknown' +) + def mock_udf_context(batch_length=10): from pyarrow._compute import _get_udf_context @@ -47,6 +63,25 @@ def _write_dummy_data_to_disk(tmpdir, file_name, table): return path +def _project_sum(expr, schema): + table = pa.table({ + "a": pa.array([1, 2], type=pa.int32()), + "b": pa.array([10, 20], type=pa.int32()), + }, schema=schema) + return ds.dataset(table).scanner(columns={"sum": expr}).to_table() + + +def _scan_with_projection_and_filter(projection, filter_expr, schema): + table = pa.table({ + "a": pa.array([1, 2, 3], type=pa.int32()), + "b": pa.array([10, 20, 30], type=pa.int32()), + }, schema=schema) + return ds.dataset(table).scanner( + columns=projection, + filter=filter_expr, + ).to_table() + + @pytest.mark.parametrize("use_threads", [True, False]) def test_run_serialized_query(tmpdir, use_threads): substrait_query = """ @@ -1057,6 +1092,82 @@ def test_serializing_with_compute(): assert str(expr2) == str(expr_norm) +def test_deserializing_unbound_expressions_with_schema(): + schema = pa.schema([ + pa.field("a", pa.int32()), + pa.field("b", pa.int32()) + ]) + + returned = pa.substrait.deserialize_expressions( + UNBOUND_NAMED_ADD, schema=schema) + assert returned.schema == schema + assert list(returned.expressions) == ["sum"] + assert _project_sum(returned.expressions["sum"], schema) == pa.table({ + "sum": pa.array([11, 22], type=pa.int32()) + }) + + +def test_deserializing_unbound_expressions_without_schema(): + with pytest.raises(ArrowInvalid, match="unknown"): + pa.substrait.deserialize_expressions(UNBOUND_NAMED_ADD) + + +def test_deserializing_unbound_expressions_schema_mismatch(): + schema = pa.schema([ + pa.field("a", pa.int32()), + pa.field("c", pa.int32()) + ]) + + with pytest.raises(ArrowInvalid, match="base_schema"): + pa.substrait.deserialize_expressions(UNBOUND_NAMED_ADD, schema=schema) + + +def test_compute_from_substrait_with_schema(): + schema = pa.schema([ + pa.field("a", pa.int32()), + pa.field("b", pa.int32()) + ]) + + expr = pc.Expression.from_substrait(UNBOUND_NAMED_ADD, schema=schema) + assert _project_sum(expr, schema) == pa.table({ + "sum": pa.array([11, 22], type=pa.int32()) + }) + + +def test_compute_filter_from_substrait_with_schema(): + schema = pa.schema([ + pa.field("a", pa.int32()) + ]) + + expr = pc.Expression.from_substrait(UNBOUND_NAMED_GT_ONE, schema=schema) + table = pa.table({ + "a": pa.array([1, 2, 3], type=pa.int32()) + }, schema=schema) + assert ds.dataset(table).scanner(filter=expr).to_table() == pa.table({ + "a": pa.array([2, 3], type=pa.int32()) + }) + + +def test_scanner_from_unbound_substrait_projection_and_filter(): + projection_schema = pa.schema([ + pa.field("a", pa.int32()), + pa.field("b", pa.int32()) + ]) + filter_schema = pa.schema([ + pa.field("a", pa.int32()) + ]) + + projection = pa.substrait.deserialize_expressions( + UNBOUND_NAMED_ADD, schema=projection_schema) + filter_expr = pc.Expression.from_substrait( + UNBOUND_NAMED_GT_ONE, schema=filter_schema) + + assert _scan_with_projection_and_filter( + projection, filter_expr, projection_schema) == pa.table({ + "sum": pa.array([22, 33], type=pa.int32()) + }) + + def test_serializing_udfs(): # Note, UDF in this context means a function that is not # recognized by Substrait. It might still be a builtin pyarrow