[SPARK-50979][CONNECT] Remove .expr/.typedExpr implicits #49657

Open · wants to merge 1 commit into base: master
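This PR mechanically replaces the `.expr` / `.typedExpr` extension methods (previously pulled in per class via `import sparkSession.RichColumn`) with direct calls to the `ColumnNodeToProtoConverter` helpers `toExpr`, `toLiteral`, and `toTypedExpr`. A minimal sketch of the call-site pattern, with `builder`, `fillNaBuilder`, `condition`, `col`, `replaceValue`, and `encoder` standing in for the surrounding code (placeholder names for illustration, not identifiers from this diff):

```scala
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}

// Before: conversions went through the RichColumn implicits.
//   builder.setJoinCondition(condition.expr)
//   builder.addExpressions(col.typedExpr(encoder))
//   fillNaBuilder.addValues(functions.lit(replaceValue).expr.getLiteral)

// After: the Column-to-proto conversion is spelled out at each call site.
builder.setJoinCondition(toExpr(condition))                 // Column => proto.Expression
builder.addExpressions(toTypedExpr(col, encoder))           // Column + Encoder => typed proto.Expression
fillNaBuilder.addValues(toLiteral(replaceValue).getLiteral) // literal value => proto literal, no Column wrapper needed
```

Dropping the implicits also removes the per-class `import sparkSession.RichColumn` lines, as seen throughout the diff below.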
DataFrameNaFunctions.scala
@@ -23,8 +23,8 @@ import org.apache.spark.connect.proto.{NAReplace, Relation}
import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
import org.apache.spark.connect.proto.NAReplace.Replacement
import org.apache.spark.sql
+ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.ConnectConversions._
- import org.apache.spark.sql.functions

/**
* Functionality for working with missing data in `DataFrame`s.
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions
*/
final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation)
extends sql.DataFrameNaFunctions {
- import sparkSession.RichColumn

override protected def drop(minNonNulls: Option[Int]): DataFrame =
buildDropDataFrame(None, minNonNulls)
@@ -103,7 +102,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
sparkSession.newDataFrame { builder =>
val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
values.map { case (colName, replaceValue) =>
- fillNaBuilder.addCols(colName).addValues(functions.lit(replaceValue).expr.getLiteral)
+ fillNaBuilder.addCols(colName).addValues(toLiteral(replaceValue).getLiteral)
}
}
}
@@ -143,8 +142,8 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root:
replacementMap.map { case (oldValue, newValue) =>
Replacement
.newBuilder()
- .setOldValue(functions.lit(oldValue).expr.getLiteral)
- .setNewValue(functions.lit(newValue).expr.getLiteral)
+ .setOldValue(toLiteral(oldValue).getLiteral)
+ .setNewValue(toLiteral(newValue).getLiteral)
.build()
}
}
DataFrameStatFunctions.scala
@@ -23,9 +23,9 @@ import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
+ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.DataFrameStatFunctions.approxQuantileResultEncoder
- import org.apache.spark.sql.functions.lit

/**
* Statistic functions for `DataFrame`s.
@@ -120,20 +120,19 @@ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)

/** @inheritdoc */
def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = {
- import sparkSession.RichColumn
require(
fractions.values.forall(p => p >= 0.0 && p <= 1.0),
s"Fractions must be in [0, 1], but got $fractions.")
sparkSession.newDataFrame { builder =>
val sampleByBuilder = builder.getSampleByBuilder
.setInput(root)
- .setCol(col.expr)
+ .setCol(toExpr(col))
.setSeed(seed)
fractions.foreach { case (k, v) =>
sampleByBuilder.addFractions(
StatSampleBy.Fraction
.newBuilder()
- .setStratum(lit(k).expr.getLiteral)
+ .setStratum(toLiteral(k).getLiteral)
.setFraction(v))
}
}
DataFrameWriterV2.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.connect.proto
import org.apache.spark.sql
import org.apache.spark.sql.Column
+ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr

/**
* Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
@@ -33,7 +34,6 @@ import org.apache.spark.sql.Column
@Experimental
final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
extends sql.DataFrameWriterV2[T] {
- import ds.sparkSession.RichColumn

private val builder = proto.WriteOperationV2
.newBuilder()
@@ -73,7 +73,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
/** @inheritdoc */
@scala.annotation.varargs
override def partitionedBy(column: Column, columns: Column*): this.type = {
- builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava)
+ builder.addAllPartitioningColumns((column +: columns).map(toExpr).asJava)
this
}

@@ -106,7 +106,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])

/** @inheritdoc */
def overwrite(condition: Column): Unit = {
- builder.setOverwriteCondition(condition.expr)
+ builder.setOverwriteCondition(toExpr(condition))
executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
}

Dataset.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
+ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toLiteral, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
@@ -140,7 +141,6 @@ class Dataset[T] private[sql] (
@DeveloperApi val plan: proto.Plan,
val encoder: Encoder[T])
extends sql.Dataset[T] {
- import sparkSession.RichColumn

// Make sure we don't forget to set plan id.
assert(plan.getRoot.getCommon.hasPlanId)
@@ -336,7 +336,7 @@ class Dataset[T] private[sql] (
buildJoin(right, Seq(joinExprs)) { builder =>
builder
.setJoinType(toJoinType(joinType))
- .setJoinCondition(joinExprs.expr)
+ .setJoinCondition(toExpr(joinExprs))
}
}

@@ -375,7 +375,7 @@ class Dataset[T] private[sql] (
.setLeft(plan.getRoot)
.setRight(other.plan.getRoot)
.setJoinType(joinTypeValue)
- .setJoinCondition(condition.expr)
+ .setJoinCondition(toExpr(condition))
.setJoinDataType(joinBuilder.getJoinDataTypeBuilder
.setIsLeftStruct(this.agnosticEncoder.isStruct)
.setIsRightStruct(other.agnosticEncoder.isStruct))
@@ -396,7 +396,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
val lateralJoinBuilder = builder.getLateralJoinBuilder
lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
- joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
+ joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(toExpr(c)))
lateralJoinBuilder.setJoinType(joinTypeValue)
}
}
@@ -440,7 +440,7 @@ class Dataset[T] private[sql] (
builder.getHintBuilder
.setInput(plan.getRoot)
.setName(name)
- .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
+ .addAllParameters(parameters.map(p => toLiteral(p)).asJava)
}

private def getPlanId: Option[Long] =
@@ -486,7 +486,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
- .addExpressions(col.typedExpr(this.encoder))
+ .addExpressions(toTypedExpr(col, this.encoder))
}
}

@@ -504,14 +504,14 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(encoder, cols) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
- .addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
+ .addAllExpressions(cols.map(c => toTypedExpr(c, this.encoder)).asJava)
}
}

/** @inheritdoc */
def filter(condition: Column): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
- builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+ builder.getFilterBuilder.setInput(plan.getRoot).setCondition(toExpr(condition))
}
}

@@ -523,12 +523,12 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { builder =>
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
- .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
+ .addAllIds(ids.toImmutableArraySeq.map(toExpr).asJava)
.setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName)
valuesOption.foreach { values =>
unpivot.getValuesBuilder
- .addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
+ .addAllValues(values.toImmutableArraySeq.map(toExpr).asJava)
}
}
}
@@ -537,7 +537,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(indices) { builder =>
val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
indices.foreach { indexColumn =>
- transpose.addIndexColumns(indexColumn.expr)
+ transpose.addIndexColumns(toExpr(indexColumn))
}
}

@@ -553,7 +553,7 @@ class Dataset[T] private[sql] (
function = func,
inputEncoders = agnosticEncoder :: agnosticEncoder :: Nil,
outputEncoder = agnosticEncoder)
- val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
+ val reduceExpr = toExpr(Column.fn("reduce", udf.apply(col("*"), col("*"))))

val result = sparkSession
.newDataset(agnosticEncoder) { builder =>
@@ -590,7 +590,7 @@ class Dataset[T] private[sql] (
val groupingSetMsgs = groupingSets.map { groupingSet =>
val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
for (groupCol <- groupingSet) {
- groupingSetMsg.addGroupingSet(groupCol.expr)
+ groupingSetMsg.addGroupingSet(toExpr(groupCol))
}
groupingSetMsg.build()
}
@@ -779,7 +779,7 @@ class Dataset[T] private[sql] (
s"The size of column names: ${names.size} isn't equal to " +
s"the size of columns: ${values.size}")
val aliases = values.zip(names).map { case (value, name) =>
- value.name(name).expr.getAlias
+ toExpr(value.name(name)).getAlias
}
sparkSession.newDataFrame(values) { builder =>
builder.getWithColumnsBuilder
@@ -812,7 +812,7 @@ class Dataset[T] private[sql] (
def withMetadata(columnName: String, metadata: Metadata): DataFrame = {
val newAlias = proto.Expression.Alias
.newBuilder()
- .setExpr(col(columnName).expr)
+ .setExpr(toExpr(col(columnName)))
.addName(columnName)
.setMetadata(metadata.json)
sparkSession.newDataFrame { builder =>
@@ -845,7 +845,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataFrame(cols) { builder =>
builder.getDropBuilder
.setInput(plan.getRoot)
- .addAllColumns(cols.map(_.expr).asJava)
+ .addAllColumns(cols.map(toExpr).asJava)
}
}

@@ -915,7 +915,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset[T](agnosticEncoder) { builder =>
builder.getFilterBuilder
.setInput(plan.getRoot)
- .setCondition(udf.apply(col("*")).expr)
+ .setCondition(toExpr(udf.apply(col("*"))))
}
}

@@ -944,7 +944,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(outputEncoder) { builder =>
builder.getMapPartitionsBuilder
.setInput(plan.getRoot)
- .setFunc(udf.apply(col("*")).expr.getCommonInlineUserDefinedFunction)
+ .setFunc(toExpr(udf.apply(col("*"))).getCommonInlineUserDefinedFunction)
}
}

@@ -1020,7 +1020,7 @@ class Dataset[T] private[sql] (
sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
- .addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
+ .addAllPartitionExprs(partitionExprs.map(toExpr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
}
}
@@ -1036,7 +1036,7 @@ class Dataset[T] private[sql] (
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
- val sortOrders = partitionExprs.filter(_.expr.hasSortOrder)
+ val sortOrders = partitionExprs.filter(e => toExpr(e).hasSortOrder)
if (sortOrders.nonEmpty) {
throw new IllegalArgumentException(
s"Invalid partitionExprs specified: $sortOrders\n" +
@@ -1050,7 +1050,7 @@ class Dataset[T] private[sql] (
partitionExprs: Seq[Column]): Dataset[T] = {
require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
val sortExprs = partitionExprs.map {
- case e if e.expr.hasSortOrder => e
+ case e if toExpr(e).hasSortOrder => e
case e => e.asc
}
buildRepartitionByExpression(numPartitions, sortExprs)
@@ -1158,7 +1158,7 @@ class Dataset[T] private[sql] (
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
- .addAllMetrics((expr +: exprs).map(_.expr).asJava)
+ .addAllMetrics((expr +: exprs).map(toExpr).asJava)
}
}

KeyValueGroupedDataset.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql
import org.apache.spark.sql.{Column, Encoder, TypedColumn}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder}
- import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toExpr
+ import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils}
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
@@ -394,7 +394,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
private val valueMapFunc: Option[IV => V],
private val keysFunc: () => Dataset[IK])
extends KeyValueGroupedDataset[K, V] {
- import sparkSession.RichColumn

override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
new KeyValueGroupedDatasetImpl[L, V, IK, IV](
@@ -436,7 +435,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
sparkSession.newDataset[U](outputEncoder) { builder =>
builder.getGroupMapBuilder
.setInput(plan.getRoot)
- .addAllSortingExpressions(sortExprs.map(e => e.expr).asJava)
+ .addAllSortingExpressions(sortExprs.map(toExpr).asJava)
.addAllGroupingExpressions(groupingExprs)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder))
}
@@ -453,10 +452,10 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
builder.getCoGroupMapBuilder
.setInput(plan.getRoot)
.addAllInputGroupingExpressions(groupingExprs)
- .addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava)
+ .addAllInputSortingExpressions(thisSortExprs.map(toExpr).asJava)
.setOther(otherImpl.plan.getRoot)
.addAllOtherGroupingExpressions(otherImpl.groupingExprs)
- .addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava)
+ .addAllOtherSortingExpressions(otherSortExprs.map(toExpr).asJava)
.setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder))
}
}
@@ -469,7 +468,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
.setInput(plan.getRoot)
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
.addAllGroupingExpressions(groupingExprs)
- .addAllAggregateExpressions(columns.map(_.typedExpr(vEncoder)).asJava)
+ .addAllAggregateExpressions(columns.map(c => toTypedExpr(c, vEncoder)).asJava)
}
}

@@ -534,7 +533,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
- udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
+ toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}

private def getUdf[U: Encoder, S: Encoder](
Expand All @@ -549,7 +548,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
function = nf,
inputEncoders = inputEncoders,
outputEncoder = outputEncoder)
- udf.apply(inputEncoders.map(_ => col("*")): _*).expr.getCommonInlineUserDefinedFunction
+ toExpr(udf.apply(inputEncoders.map(_ => col("*")): _*)).getCommonInlineUserDefinedFunction
}

/**