
Commit c459d71

jackywang-db authored and cloud-fan committed
[SPARK-53421][SPARK-53377][SDP] Propagate Logical Plan ID in SDP Analysis
### What changes were proposed in this pull request?

Propagate `LogicalPlan.PLAN_ID_TAG` to the resolved logical plan during SDP analysis so that when the whole plan is sent to Spark for analysis, it contains the correct plan id.

### Why are the changes needed?

Spark Connect attaches a plan id to each logical plan. In SDP, we take part of the logical plan and analyze it independently to resolve table references correctly. When this happens, the logical plan id is lost, which causes resolution errors when the plan is sent to Spark for complete analysis. For example, groupBy and rollup operations would fail with:

`sql.AnalysisException: [CANNOT_RESOLVE_DATAFRAME_COLUMN] Cannot resolve dataframe column "id". It's probably because of illegal references like df1.select(df2.col("a"))`

```python3
from pyspark.sql.functions import col, sum, count

@dp.materialized_view
def groupby_result():
    return spark.read.table("src").groupBy("id").count()
```

This happens because we take the unresolved logical plan below:

```
'Aggregate ['id], ['id, 'count(1) AS count#7]
+- 'UnresolvedRelation [src], [], false
```

and perform independent analysis on the `UnresolvedRelation` part to identify the table. During this analysis the plan id is lost:

```
'Aggregate ['id], ['id, 'count(1) AS count#7]
+- SubqueryAlias spark_catalog.default.src
   +- Relation spark_catalog.default.src[id#9L] parquet
```

So when this partially resolved logical plan is sent to Spark for analysis, Spark tries to resolve the `id` attribute in the aggregate operation against the `SubqueryAlias` subtree, and fails because that subtree no longer carries the same plan id.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Tests

### Was this patch authored or co-authored using generative AI tooling?

Closes #52121 from JiaqiWang18/SPARK-53377-sdp-groupBy-rollup-tests.

Authored-by: Jacky Wang <jacky.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 871fe3d commit c459d71
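
Before the diffs, a minimal, self-contained sketch of the tag-loss problem described above and how the new `mergeTagsFrom` addresses it. This is illustrative only: the tag below is a stand-in for the Spark Connect plan id tag (`LogicalPlan.PLAN_ID_TAG`, which is internal to Spark), and `LocalRelation` stands in for whatever the independent analysis resolves the table reference to.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.types.LongType

// Illustrative stand-in for the Spark Connect plan id tag.
val planIdTag = TreeNodeTag[Long]("plan_id")

// What SDP receives: an unresolved table reference carrying a plan id.
val unresolved = UnresolvedRelation(Seq("src"))
unresolved.setTagValue(planIdTag, 7L)

// What independent analysis produces: a resolved subtree with no tags, so a
// parent operator such as the Aggregate above can no longer match its plan id.
val resolved: LogicalPlan = LocalRelation(AttributeReference("id", LongType)())
assert(resolved.getTagValue(planIdTag).isEmpty)

// The fix: carry the original node's tags onto the resolved subtree.
resolved.mergeTagsFrom(unresolved)
assert(resolved.getTagValue(planIdTag).contains(7L))
```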

File tree

4 files changed: +117 / -5 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 11 additions & 0 deletions

```diff
@@ -181,6 +181,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
     }
   }
 
+  def mergeTagsFrom(other: BaseType): Unit = {
+    if (!other.isTagsEmpty) {
+      // Merge all tags from the other node into this node.
+      // Unlike copyTagsFrom which only copies when this node has no tags,
+      // mergeTagsFrom will always merge tags regardless of existing state.
+      // If both nodes have the same tag with different values, the value
+      // from the other node will overwrite the existing value in this node.
+      tags ++= other.tags
+    }
+  }
+
   def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
     tags(tag) = value
   }
```
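
A small usage sketch of the new method, contrasting it with the existing `copyTagsFrom` (which, as the comment above notes, only copies when the destination node has no tags of its own). Tag names here are illustrative; the tag that motivated this change is the Spark Connect plan id.

```scala
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

// Illustrative tags; the commit itself cares about LogicalPlan.PLAN_ID_TAG.
val planIdTag = TreeNodeTag[Long]("plan_id")
val hintTag = TreeNodeTag[String]("hint")

val source = OneRowRelation()
source.setTagValue(planIdTag, 42L)

val target = OneRowRelation()
target.setTagValue(hintTag, "already-tagged")

// copyTagsFrom is a no-op here because `target` already has a tag.
target.copyTagsFrom(source)
assert(target.getTagValue(planIdTag).isEmpty)

// mergeTagsFrom merges regardless, keeping existing tags and adding new ones;
// on a key collision the value from `source` would win.
target.mergeTagsFrom(source)
assert(target.getTagValue(planIdTag).contains(42L))
assert(target.getTagValue(hintTag).contains("already-tagged"))
```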

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala

Lines changed: 61 additions & 1 deletion

```diff
@@ -30,7 +30,7 @@ import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.connect.service.SparkConnectService
-import org.apache.spark.sql.pipelines.graph.DataflowGraph
+import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
 import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin}
 
 /**
@@ -434,6 +434,66 @@ class PythonPipelineSuite
       .map(_.identifier) == Seq(graphIdentifier("a"), graphIdentifier("something")))
   }
 
+  test("groupby and rollup works with internal datasets, referencing with (col, str)") {
+    val graph = buildGraph("""
+      from pyspark.sql.functions import col, sum, count
+
+      @dp.materialized_view
+      def src():
+          return spark.range(3)
+
+      @dp.materialized_view
+      def groupby_with_col_result():
+          return spark.read.table("src").groupBy(col("id")).agg(
+              sum("id").alias("sum_id"),
+              count("*").alias("cnt")
+          )
+
+      @dp.materialized_view
+      def groupby_with_str_result():
+          return spark.read.table("src").groupBy("id").agg(
+              sum("id").alias("sum_id"),
+              count("*").alias("cnt")
+          )
+
+      @dp.materialized_view
+      def rollup_with_col_result():
+          return spark.read.table("src").rollup(col("id")).agg(
+              sum("id").alias("sum_id"),
+              count("*").alias("cnt")
+          )
+
+      @dp.materialized_view
+      def rollup_with_str_result():
+          return spark.read.table("src").rollup("id").agg(
+              sum("id").alias("sum_id"),
+              count("*").alias("cnt")
+          )
+      """)
+
+    val updateContext = new PipelineUpdateContextImpl(graph, _ => ())
+    updateContext.pipelineExecution.runPipeline()
+    updateContext.pipelineExecution.awaitCompletion()
+
+    val groupbyDfs =
+      Seq(spark.table("groupby_with_col_result"), spark.table("groupby_with_str_result"))
+
+    val rollupDfs =
+      Seq(spark.table("rollup_with_col_result"), spark.table("rollup_with_str_result"))
+
+    // groupBy: each variant should have exactly one row per id [0,1,2]
+    groupbyDfs.foreach { df =>
+      assert(df.select("id").collect().map(_.getLong(0)).toSet == Set(0L, 1L, 2L))
+    }
+
+    // rollup: each variant should have groupBy rows + one total row
+    rollupDfs.foreach { df =>
+      assert(df.count() == 3 + 1) // 3 ids + 1 total
+      val totalRow = df.filter("id IS NULL").collect().head
+      assert(totalRow.getLong(1) == 3L && totalRow.getLong(2) == 3L)
+    }
+  }
+
   test("create pipeline without table will throw RUN_EMPTY_PIPELINE exception") {
     checkError(
       exception = intercept[AnalysisException] {
```

sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala

Lines changed: 10 additions & 4 deletions

```diff
@@ -112,23 +112,29 @@ object FlowAnalysis {
       // - SELECT ... FROM STREAM(t1)
       // - SELECT ... FROM STREAM t1
       case u: UnresolvedRelation if u.isStreaming =>
-        readStreamInput(
+        val resolved = readStreamInput(
           context,
           name = IdentifierHelper.toQuotedString(u.multipartIdentifier),
           spark.readStream,
           streamingReadOptions = StreamingReadOptions()
         ).queryExecution.analyzed
-
+        // Spark Connect requires the PLAN_ID_TAG to be propagated to the resolved plan
+        // to allow correct analysis of the parent plan that contains this subquery
+        resolved.mergeTagsFrom(u)
+        resolved
       // Batch read on another dataset in the pipeline
       case u: UnresolvedRelation =>
-        readBatchInput(
+        val resolved = readBatchInput(
           context,
           name = IdentifierHelper.toQuotedString(u.multipartIdentifier),
           batchReadOptions = BatchReadOptions()
         ).queryExecution.analyzed
+        // Spark Connect requires the PLAN_ID_TAG to be propagated to the resolved plan
+        // to allow correct analysis of the parent plan that contains this subquery
+        resolved.mergeTagsFrom(u)
+        resolved
     }
     Dataset.ofRows(spark, resolvedPlan)
-
   }
 
   /**
```
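
Both branches above follow the same resolve-then-merge shape. A hypothetical helper (not part of this commit) that names the pattern:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical helper, not in this commit: after a relation has been analyzed
// independently, carry the original node's tags (including the Spark Connect
// plan id) onto the resolved subtree so the surrounding plan still resolves.
def resolvedWithTagsFrom(original: UnresolvedRelation, resolved: LogicalPlan): LogicalPlan = {
  resolved.mergeTagsFrom(original)
  resolved
}
```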

sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala

Lines changed: 35 additions & 0 deletions

```diff
@@ -743,6 +743,41 @@ class SqlPipelineSuite extends PipelineTest with SharedSparkSession {
     )
   }
 
+  test("groupby and rollup works with internal datasets") {
+    val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+      sqlText = s"""
+        |CREATE MATERIALIZED VIEW src AS
+        |  SELECT id
+        |  FROM range(3);
+        |
+        |CREATE MATERIALIZED VIEW groupby_result AS
+        |  SELECT id, SUM(id) AS sum_id, COUNT(*) AS cnt
+        |  FROM src
+        |  GROUP BY id;
+        |
+        |CREATE MATERIALIZED VIEW rollup_result AS
+        |  SELECT id, SUM(id) AS sum_id, COUNT(*) AS cnt
+        |  FROM src
+        |  GROUP BY ROLLUP(id);
+        |""".stripMargin
+    )
+
+    startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+    val groupbyDf = spark.table(fullyQualifiedIdentifier("groupby_result"))
+    val rollupDf = spark.table(fullyQualifiedIdentifier("rollup_result"))
+
+    // groupBy should have exactly one row per id [0,1,2]
+    assert(groupbyDf.select("id").collect().map(_.getLong(0)).toSet == Set(0L, 1L, 2L))
+
+    // rollup should have all groupBy rows + one extra (the total row)
+    assert(rollupDf.count() == groupbyDf.count() + 1)
+
+    // verify the rollup total row: id IS NULL, sum_id=3, cnt=3
+    val totalRow = rollupDf.filter("id IS NULL").collect().head
+    assert(totalRow.getLong(1) == 3L && totalRow.getLong(2) == 3L)
+  }
+
   test("Empty streaming table definition is disallowed") {
     val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
       sqlText = "CREATE STREAMING TABLE st;"
```

0 commit comments
