Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50953][PYTHON][CONNECT] Add support for non-literal paths in VariantGet #49609

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2161,8 +2161,12 @@ def is_variant_null(v: "ColumnOrName") -> Column:
is_variant_null.__doc__ = pysparkfuncs.is_variant_null.__doc__


def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
return _invoke_function("variant_get", _to_col(v), lit(path), lit(targetType))
def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column:
assert isinstance(path, (Column, str))
if isinstance(path, str):
return _invoke_function("variant_get", _to_col(v), lit(path), lit(targetType))
else:
return _invoke_function("variant_get", _to_col(v), path, lit(targetType))


variant_get.__doc__ = pysparkfuncs.variant_get.__doc__
Expand Down
49 changes: 33 additions & 16 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20385,7 +20385,7 @@ def is_variant_null(v: "ColumnOrName") -> Column:


@_try_remote_functions
def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
def variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column:
"""
Extracts a sub-variant from `v` according to `path`, and then cast the sub-variant to
`targetType`. Returns null if the path does not exist. Throws an exception if the cast fails.
Expand All @@ -20396,9 +20396,10 @@ def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
----------
v : :class:`~pyspark.sql.Column` or str
a variant column or column name
path : str
the extraction path. A valid path should start with `$` and is followed by zero or more
segments like `[123]`, `.name`, `['name']`, or `["name"]`.
path : :class:`~pyspark.sql.Column` or str
a column containing the extraction path strings or a string representing the extraction
path. A valid path should start with `$` and is followed by zero or more segments like
`[123]`, `.name`, `['name']`, or `["name"]`.
targetType : str
the target data type to cast into, in a DDL-formatted string

Expand All @@ -20409,21 +20410,29 @@ def variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:

Examples
--------
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path': '$.a'} ])
>>> df.select(variant_get(parse_json(df.json), "$.a", "int").alias("r")).collect()
[Row(r=1)]
>>> df.select(variant_get(parse_json(df.json), "$.b", "int").alias("r")).collect()
[Row(r=None)]
>>> df.select(variant_get(parse_json(df.json), df.path, "int").alias("r")).collect()
[Row(r=1)]
"""
from pyspark.sql.classic.column import _to_java_column

return _invoke_function(
"variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType)
)
assert isinstance(path, (Column, str))
if isinstance(path, str):
return _invoke_function(
"variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType)
)
else:
return _invoke_function(
"variant_get", _to_java_column(v), _to_java_column(path), _enum_to_value(targetType)
)


@_try_remote_functions
def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
def try_variant_get(v: "ColumnOrName", path: Union[Column, str], targetType: str) -> Column:
"""
Extracts a sub-variant from `v` according to `path`, and then cast the sub-variant to
`targetType`. Returns null if the path does not exist or the cast fails.
Expand All @@ -20434,9 +20443,10 @@ def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:
----------
v : :class:`~pyspark.sql.Column` or str
a variant column or column name
path : str
the extraction path. A valid path should start with `$` and is followed by zero or more
segments like `[123]`, `.name`, `['name']`, or `["name"]`.
path : :class:`~pyspark.sql.Column` or str
a column containing the extraction path strings or a string representing the extraction
path. A valid path should start with `$` and is followed by zero or more segments like
`[123]`, `.name`, `['name']`, or `["name"]`.
targetType : str
the target data type to cast into, in a DDL-formatted string

Expand All @@ -20447,19 +20457,26 @@ def try_variant_get(v: "ColumnOrName", path: str, targetType: str) -> Column:

Examples
--------
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }''', 'path': '$.a'} ])
>>> df.select(try_variant_get(parse_json(df.json), "$.a", "int").alias("r")).collect()
[Row(r=1)]
>>> df.select(try_variant_get(parse_json(df.json), "$.b", "int").alias("r")).collect()
[Row(r=None)]
>>> df.select(try_variant_get(parse_json(df.json), "$.a", "binary").alias("r")).collect()
[Row(r=None)]
>>> df.select(try_variant_get(parse_json(df.json), df.path, "int").alias("r")).collect()
[Row(r=1)]
"""
from pyspark.sql.classic.column import _to_java_column

return _invoke_function(
"try_variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType)
)
if isinstance(path, str):
return _invoke_function(
"try_variant_get", _to_java_column(v), _enum_to_value(path), _enum_to_value(targetType)
)
else:
return _invoke_function(
"try_variant_get", _to_java_column(v), _to_java_column(path), _enum_to_value(targetType)
)


@_try_remote_functions
Expand Down
8 changes: 7 additions & 1 deletion python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,9 @@ def test_parse_json(self):
self.assertEqual("""{"b":[{"c":"str2"}]}""", actual["var_lit"])

def test_variant_expressions(self):
df = self.spark.createDataFrame([Row(json="""{ "a" : 1 }"""), Row(json="""{ "b" : 2 }""")])
df = self.spark.createDataFrame(
[Row(json="""{ "a" : 1 }""", path="$.a"), Row(json="""{ "b" : 2 }""", path="$.b")]
)
v = F.parse_json(df.json)

def check(resultDf, expected):
Expand All @@ -1510,6 +1512,10 @@ def check(resultDf, expected):
check(df.select(F.variant_get(v, "$.b", "int")), [None, 2])
check(df.select(F.variant_get(v, "$.a", "double")), [1.0, None])

# non-literal variant_get
check(df.select(F.variant_get(v, df.path, "int")), [1, 2])
check(df.select(F.try_variant_get(v, df.path, "binary")), [None, None])

with self.assertRaises(SparkRuntimeException) as ex:
df.select(F.variant_get(v, "$.a", "binary")).collect()

Expand Down
38 changes: 36 additions & 2 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7115,7 +7115,7 @@ object functions {
def is_variant_null(v: Column): Column = Column.fn("is_variant_null", v)

/**
* Extracts a sub-variant from `v` according to `path`, and then cast the sub-variant to
* Extracts a sub-variant from `v` according to `path` string, and then cast the sub-variant to
* `targetType`. Returns null if the path does not exist. Throws an exception if the cast fails.
*
* @param v
Expand All @@ -7132,7 +7132,24 @@ object functions {
Column.fn("variant_get", v, lit(path), lit(targetType))

/**
 * Extracts a sub-variant from `v` according to the `path` column and casts the result to
 * `targetType`. Returns null if the path does not exist. Throws an exception if the cast fails.
 *
 * @param v
 *   a variant column.
 * @param path
 *   the column containing the extraction path strings. A valid path string should start with `$`
 *   and is followed by zero or more segments like `[123]`, `.name`, `['name']`, or `["name"]`.
 * @param targetType
 *   the target data type to cast into, in a DDL-formatted string.
 * @group variant_funcs
 * @since 4.0.0
 */
def variant_get(v: Column, path: Column, targetType: String): Column = {
  Column.fn("variant_get", v, path, lit(targetType))
}

/**
* Extracts a sub-variant from `v` according to `path` string, and then cast the sub-variant to
* `targetType`. Returns null if the path does not exist or the cast fails.
*
* @param v
Expand All @@ -7148,6 +7165,23 @@ object functions {
def try_variant_get(v: Column, path: String, targetType: String): Column =
Column.fn("try_variant_get", v, lit(path), lit(targetType))

/**
 * Extracts a sub-variant from `v` according to the `path` column and casts the result to
 * `targetType`. Returns null if the path does not exist or the cast fails.
 *
 * @param v
 *   a variant column.
 * @param path
 *   the column containing the extraction path strings. A valid path string should start with `$`
 *   and is followed by zero or more segments like `[123]`, `.name`, `['name']`, or `["name"]`.
 * @param targetType
 *   the target data type to cast into, in a DDL-formatted string.
 * @group variant_funcs
 * @since 4.0.0
 */
def try_variant_get(v: Column, path: Column, targetType: String): Column =
  // `path` is already a Column, so pass it through directly instead of re-wrapping it with
  // `lit` (which is a no-op on Columns) — consistent with the variant_get Column overload.
  Column.fn("try_variant_get", v, path, lit(targetType))

/**
* Returns schema in the SQL format of a variant.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,6 @@ case class VariantGet(
val check = super.checkInputDataTypes()
if (check.isFailure) {
check
} else if (!path.foldable) {
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("path"),
"inputType" -> toSQLType(path.dataType),
"inputExpr" -> toSQLExpr(path)))
} else if (!VariantGet.checkDataType(targetType)) {
DataTypeMismatch(
errorSubClass = "CAST_WITHOUT_SUGGESTION",
Expand All @@ -265,10 +258,12 @@ case class VariantGet(

override lazy val dataType: DataType = targetType.asNullable

@transient private lazy val parsedPath = {
val pathValue = path.eval().toString
VariantPathParser.parse(pathValue).getOrElse {
throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
// When the path expression is foldable (a literal), parse it once here at planning time and
// reuse the segments for every row; otherwise keep None and parse per row in eval/codegen.
@transient private lazy val parsedPath: Option[Array[VariantPathSegment]] = {
  if (path.foldable) {
    // getParsedPath throws an invalid-path error if the literal path string is malformed.
    val pathValue = path.eval().toString
    Some(VariantGet.getParsedPath(pathValue, prettyName))
  } else {
    None
  }
}

Expand All @@ -287,31 +282,65 @@ case class VariantGet(
timeZoneId,
zoneId)

protected override def nullSafeEval(input: Any, path: Any): Any = {
VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath, dataType, castArgs)
// Literal path: use the segments pre-parsed in `parsedPath`. Non-literal path: parse the
// per-row path string via the UTF8String overload, which also validates it for each row.
protected override def nullSafeEval(input: Any, path: Any): Any = parsedPath match {
  case Some(pp) =>
    VariantGet.variantGet(input.asInstanceOf[VariantVal], pp, dataType, castArgs)
  case _ =>
    VariantGet.variantGet(input.asInstanceOf[VariantVal], path.asInstanceOf[UTF8String], dataType,
      castArgs, prettyName)
}

/**
 * Generates Java code for the variant extraction.
 *
 * Two cases, mirroring `nullSafeEval`:
 *  - the path was foldable: the pre-parsed segments are embedded as a reference object, so
 *    parsing happens once per query rather than once per row;
 *  - the path is non-literal: the path string produced by the path expression's code is
 *    passed to the UTF8String `variantGet` overload, which parses (and validates) it per row.
 */
protected override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = parsedPath match {
  case Some(pp) =>
    val childCode = child.genCode(ctx)
    val tmp = ctx.freshVariable("tmp", classOf[Object])
    val parsedPathArg = ctx.addReferenceObj("parsedPath", pp)
    val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
    val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
    val code = code"""
      ${childCode.code}
      boolean ${ev.isNull} = ${childCode.isNull};
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
      if (!${ev.isNull}) {
        Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
          ${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg);
        if ($tmp == null) {
          ${ev.isNull} = true;
        } else {
          ${ev.value} = (${CodeGenerator.boxedType(dataType)})$tmp;
        }
      }
    """
    ev.copy(code = code)
  case None =>
    val childCode = child.genCode(ctx)
    val pathCode = path.genCode(ctx)
    val tmp = ctx.freshVariable("tmp", classOf[Object])
    val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
    val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
    val UTF8Type = CodeGenerator.typeName(classOf[UTF8String])
    // NOTE(review): the previous version also emitted a `getParsedPath` call whose result was
    // never used — the UTF8String `variantGet` overload parses the path itself — so every row
    // parsed the path twice. The redundant parse is dropped; error behavior is unchanged
    // because both parses raise the same invalid-path error.
    val code = code"""
      ${childCode.code}
      ${pathCode.code}
      boolean ${ev.isNull} = ${childCode.isNull} || ${pathCode.isNull};
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
      if (!${ev.isNull}) {
        Object $tmp = org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
          ${childCode.value}, ($UTF8Type) ${pathCode.value}, $dataTypeArg, $castArgsArg,
          "$prettyName");
        if ($tmp == null) {
          ${ev.isNull} = true;
        } else {
          ${ev.value} = (${CodeGenerator.boxedType(dataType)})$tmp;
        }
      }
    """
    ev.copy(code = code)
}

override def left: Expression = child
Expand Down Expand Up @@ -350,6 +379,15 @@ case object VariantGet {
case _ => false
}

/**
 * Parses `pathValue` into an array of variant path segments, throwing an invalid-variant-path
 * error (tagged with `prettyName`) when the string is not a valid extraction path.
 */
def getParsedPath(pathValue: String, prettyName: String): Array[VariantPathSegment] = {
  VariantPathParser.parse(pathValue) match {
    case Some(segments) => segments
    case None => throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
  }
}

/** The actual implementation of the `VariantGet` expression. */
def variantGet(
input: VariantVal,
Expand All @@ -368,6 +406,20 @@ case object VariantGet {
VariantGet.cast(v, dataType, castArgs)
}

/**
 * `VariantGet` implementation for a per-row (non-literal) path: parses the UTF8String path and
 * delegates to the pre-parsed-segments overload.
 */
def variantGet(
    input: VariantVal,
    path: UTF8String,
    dataType: DataType,
    castArgs: VariantCastArgs,
    prettyName: String): Any =
  variantGet(input, getParsedPath(path.toString, prettyName), dataType, castArgs)

/**
* A simple wrapper of the `cast` function that takes `Variant` rather than `VariantVal`. The
* `Cast` expression uses it and makes the implementation simpler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ object RequestedVariantField {
def fullVariant: RequestedVariantField =
RequestedVariantField(VariantMetadata("$", failOnError = true, "UTC"), VariantType)

// Builds the requested-field descriptor for a variant_get expression. Only foldable paths are
// eligible (collectRequestedFields filters on `path.foldable` before calling this).
def apply(v: VariantGet): RequestedVariantField = {
  assert(v.path.foldable)
  val metadata = VariantMetadata(v.path.eval().toString, v.failOnError, v.timeZoneId.get)
  RequestedVariantField(metadata, v.dataType)
}

def apply(c: Cast): RequestedVariantField =
RequestedVariantField(
Expand Down Expand Up @@ -212,7 +214,7 @@ class VariantInRelation {
// fields, which also changes the struct type containing it, and it is difficult to reconstruct
// the original struct value. This is not a big loss, because we need the full variant anyway.
def collectRequestedFields(expr: Expression): Unit = expr match {
case v@VariantGet(StructPathToVariant(fields), _, _, _, _) =>
case v@VariantGet(StructPathToVariant(fields), path, _, _, _) if path.foldable =>
addField(fields, RequestedVariantField(v))
case c@Cast(StructPathToVariant(fields), _, _, _) => addField(fields, RequestedVariantField(c))
case IsNotNull(StructPath(_, _)) | IsNull(StructPath(_, _)) =>
Expand Down Expand Up @@ -240,7 +242,7 @@ class VariantInRelation {

// Rewrite patterns should be consistent with visit patterns in `collectRequestedFields`.
expr.transformDown {
case g@VariantGet(v@StructPathToVariant(fields), _, _, _, _) =>
case g@VariantGet(v@StructPathToVariant(fields), path, _, _, _) if path.foldable =>
// Rewrite the attribute in advance, rather than depending on the last branch to rewrite it.
// We need to avoid the `v@StructPathToVariant(fields)` branch to rewrite the child again.
GetStructField(rewriteAttribute(v), fields(RequestedVariantField(g)))
Expand Down
Loading
Loading