Conversation

@revans2
Collaborator

@revans2 revans2 commented Aug 27, 2025

Description

This attempts to make it possible to fall back to the CPU in a more efficient way. In the current code we fall back to the CPU at the level of a SparkPlan node. So an entire ProjectExec or HashAggregateExec would fall back to the CPU if a single expression in it could not run on the GPU.

  1. This can be very costly in terms of data movement because we have to move the entire input of the SparkPlan node to the CPU and then the entire result back to the GPU again.
  2. It is costly in terms of releasing the semaphore. When we release the GPU semaphore it can put something new on the GPU and increase memory pressure.
  3. It means that even if most of the heavy lifting, like an aggregation operation, could be done on the GPU, the expensive computation would be done on the CPU instead.
  4. When falling back for an entire SparkPlan node, only a single CPU thread is used to process the data, which does not scale well in all cases.

The CPU bridge has a thread pool that is used to execute the CPU expressions in parallel, moving only the minimal data needed. This keeps more of the processing on the GPU and minimizes data movement. Even though it does not release the semaphore while running on the CPU, it offsets this by throwing as many cores at the processing as there are configured tasks.
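The scheme described above can be sketched without any Spark or GPU dependencies: copy only the one column a CPU-only expression needs, evaluate it in chunks on a shared fixed-size pool, and hand the single result column back. Everything here (the function, pool size, chunking) is illustrative, not the plugin's actual implementation.

```scala
import java.util.concurrent.{Callable, Executors}

object CpuBridgeSketch {
  // Hypothetical stand-in for an expression the GPU cannot run
  // (Celsius to Fahrenheit, mirroring the UDF used in the benchmarks below).
  val cpuOnlyExpr: Double => Double = c => (c * 9.0 / 5) + 32

  // Shared pool sized roughly to the number of configured tasks, so CPU
  // work from many tasks can run in parallel instead of on one thread.
  private val pool = Executors.newFixedThreadPool(4)

  // Only the one input column the expression needs is copied to the CPU;
  // in the real plugin the rest of the batch never leaves the GPU.
  def evalOnCpu(tempC: Array[Double]): Array[Double] = {
    val chunkSize = math.max(1, tempC.length / 4)
    val futures = tempC.grouped(chunkSize).toSeq.map { chunk =>
      pool.submit(new Callable[Array[Double]] {
        override def call(): Array[Double] = chunk.map(cpuOnlyExpr)
      })
    }
    futures.flatMap(_.get()).toArray
  }

  def shutdown(): Unit = pool.shutdown()
}
```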

This does not work for non-deterministic expressions or aggregations.

It is currently off by default for two reasons:

  1. It is a big job to update all of the tests and test infrastructure that expect fallback to happen in a specific way so they can handle it when it does not.
  2. I would like more feedback on this before turning it on by default. I wrote this initially as a way to experiment with dynamic CPU resource monitoring, using AI to do most of the heavy lifting. But I saw such large performance gains, even in cases where most of the processing for an operation moved to the CPU, that I decided to polish it a bit and submit it anyway.

From a performance standpoint I have tested it in a few situations.

  1. A UDF as part of a much larger query
  • Gen Data ~ 50 GiB compressed, 100 GiB uncompressed
spark.time(spark.range(0, 1500000000L, 1, 100).selectExpr("id", "CAST(id as STRING) str_id", "CAST(round(rand(0) * 100) as LONG) as site", "date_add('2000-01-01', CAST(id % (10 * 365 * 24) DIV 24 as INT)) as d", "CAST(id % 24 as BYTE) as h", "round(rand(1) * 40.0, 2) as temp_c", "rand(2) as feat1", "rand(3) as feat2", "rand(4) as feat3").write.mode("overwrite").parquet("/data/tmp/PERF_TEST"))
  • Query with UDF
val c_to_f = udf((c: Double) => (c * 9.0 / 5) + 32)
spark.udf.register("c_to_f", c_to_f)
val with_temp_f = spark.read.parquet("/data/tmp/PERF_TEST").selectExpr("*","c_to_f(temp_c) as temp_f")
val avg_by_site = with_temp_f.groupBy(col("site")).agg(stddev_pop(col("temp_f")).alias("site_stddev_temp_f"),avg(col("temp_f")).alias("site_avg_temp_f"), stddev_pop(col("feat1")).alias("site_stddev_1"), avg(col("feat1")).alias("site_avg_1"), stddev_pop(col("feat3")).alias("site_stddev_3"), avg(col("feat3")).alias("site_avg_3")).withColumnRenamed("site", "site_site")
val avg_by_day_hour = with_temp_f.groupBy(col("d"), col("h")).agg(stddev_pop(col("temp_f")).alias("dh_stddev_temp_f"), avg(col("temp_f")).alias("dh_avg_temp_f")).withColumnRenamed("d", "dh_d").withColumnRenamed("h", "dh_h")
spark.time(with_temp_f.join(avg_by_site, with_temp_f("site") === avg_by_site("site_site")).join(avg_by_day_hour, col("d") === col("dh_d") && col("h") === col("dh_h")).selectExpr("site","d", "h", "temp_f", "temp_f - site_avg_temp_f as diff_site_temp_f", "temp_f - dh_avg_temp_f as diff_dh_temp_f", "site_stddev_temp_f", "dh_stddev_temp_f", "site_stddev_1", "site_avg_1", "site_stddev_3", "site_avg_3").orderBy("site", "d", "h").show())
  • With no UDF
val with_temp_f = spark.read.parquet("/data/tmp/PERF_TEST").selectExpr("*","(temp_c * 9.0 / 5) + 32 as temp_f")
val avg_by_site = with_temp_f.groupBy(col("site")).agg(stddev_pop(col("temp_f")).alias("site_stddev_temp_f"),avg(col("temp_f")).alias("site_avg_temp_f"), stddev_pop(col("feat1")).alias("site_stddev_1"), avg(col("feat1")).alias("site_avg_1"), stddev_pop(col("feat3")).alias("site_stddev_3"), avg(col("feat3")).alias("site_avg_3")).withColumnRenamed("site", "site_site")
val avg_by_day_hour = with_temp_f.groupBy(col("d"), col("h")).agg(stddev_pop(col("temp_f")).alias("dh_stddev_temp_f"), avg(col("temp_f")).alias("dh_avg_temp_f")).withColumnRenamed("d", "dh_d").withColumnRenamed("h", "dh_h")
spark.time(with_temp_f.join(avg_by_site, with_temp_f("site") === avg_by_site("site_site")).join(avg_by_day_hour, col("d") === col("dh_d") && col("h") === col("dh_h")).selectExpr("site","d", "h", "temp_f", "temp_f - site_avg_temp_f as diff_site_temp_f", "temp_f - dh_avg_temp_f as diff_dh_temp_f", "site_stddev_temp_f", "dh_stddev_temp_f", "site_stddev_1", "site_avg_1", "site_stddev_3", "site_avg_3").orderBy("site", "d", "h").show())
| Mode | Spark 3.4.2, 48 GiB GPU, wall time (ms) | Spark 3.5.0, 48 GiB GPU, wall time (ms) |
| --- | --- | --- |
| GPU Fallback | 88,396 | 84,770 |
| GPU Bridge + gen | 58,806 | 56,520 |
| GPU Bridge nogen | 351,544 | NOPE |
| GPU no UDF | 50,324 | 51,272 |
| CPU 16 cores | 143,947 | 137,395 |
| CPU 16 cores no UDF | | 136,483 |

I also ran some much simpler tests where almost the entire query is a single expression that is not on the GPU.

  • getbit simple
spark.time(spark.range(0, 10000000000L, 1, 160).selectExpr("id % 11 as a", "getbit(id, 0) as b").groupBy("a").agg(sum(col("b")).as("r")).orderBy(desc("a")).show())
  • getbit complex
spark.time(spark.range(0, 10000000000L, 1, 160).selectExpr("id % 11 as a", "getbit(getbit(id, 0), getbit(id, 1)) as b").groupBy("a").agg(sum(col("b")).as("r")).orderBy(desc("a")).show())
  • simple add
spark.conf.set("spark.rapids.sql.expression.Add", false)
spark.time(spark.range(0, 10000000000L, 1, 160).selectExpr("id % 11 as a", "id % 20 + 5 as b").groupBy("a").agg(sum(col("b")).as("r")).orderBy(desc("a")).show())
| mode | query | time 4 GiB GPU (ms) | time 8 GiB GPU (ms) | time 16 GiB GPU (ms) | time 48 GiB GPU (ms) |
| --- | --- | --- | --- | --- | --- |
| default fallback | getbit simple | OOM | OOM | 33,931 | 34,985 |
| CPU | getbit simple | SAME | SAME | 6,181 | SAME |
| bridge + gen | getbit simple | OOM | 11,681 | 9,893 | 7,475 |
| bridge | getbit simple | OOM | 23,659 | 13,714 | 11,261 |
| default fallback | getbit complex | OOM | OOM | 34,052 | 35,710 |
| CPU | getbit complex | SAME | SAME | 6,188 | SAME |
| bridge + gen | getbit complex | OOM | 11,785 | 10,450 | 10,167 |
| bridge | getbit complex | OOM | 37,705 | 22,970 | 23,379 |
| pure GPU | simple add | 2,242 | 1,726 | 1,641 | 1,645 |
| CPU | simple add | SAME | SAME | 6,163 | SAME |
| default fallback | simple add | OOM | OOM | 38,686 | 39,557 |
| bridge + gen | simple add | OOM | 14,323 | 14,171 | 11,429 |
| bridge | simple add | OOM | 11,841 | 11,860 | 11,817 |

The bridge version is more memory efficient than falling back for everything, but not quite as good as a pure GPU implementation. There is code to allow us to run the expression as an interpreted expression instead of as code gen, but the performance can vary wildly and is generally slower than the code gen version. I am happy to rip out the interpreted version, as the code gen version can fall back to an interpreted version behind the scenes in Spark; it just requires setting a separate config.

Note: A lot of this code was written with AI (specifically claude-4-sonnet through cursor)

Checklists

  • This PR has added documentation for new or modified features or behaviors.
    • Not Yet. I want to hold off a bit longer until I get feedback on this and we can decide what to do with it next.
  • This PR has added new tests or modified existing tests to cover new code paths.
  • Performance testing has been performed and its results are added in the PR description. Or, an issue has been filed with a link in the PR description.

Adds in support for a "CpuBridge" that lets us fall back to the CPU on
a per expression level instead of per-SparkPlan node level.

A lot of this code was written with AI (specifically claude-4-sonnet
through cursor)

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
@revans2
Collaborator Author

revans2 commented Aug 27, 2025

build

@sameerz sameerz added the performance A performance related task/issue label Sep 15, 2025
/**
* Converts a CPU expression to a GPU expression.
*/
def convertToGpuBase(): Expression
Collaborator

I was a little confused about convertToGpuBase vs convertToGpu (the naming), and kind of wish we didn't use Base in the name. Something like: keeping convertToGpu for, well, converting to the GPU, and using something different for the function defined in line 1453 that indicates it's going to be either or. I do not know if this is possible or adds a lot more work.

Copilot AI review requested due to automatic review settings September 22, 2025 14:22
@revans2
Collaborator Author

revans2 commented Sep 22, 2025

build

@revans2
Collaborator Author

revans2 commented Sep 22, 2025

@zpuller and @abellina could you please take another look?

Contributor

Copilot AI left a comment

Pull Request Overview

This PR introduces support for a "CpuBridge" feature that enables GPU-CPU hybrid execution at an expression level rather than falling back entire SparkPlan nodes to the CPU. The primary goal is to minimize data movement costs, utilize GPU resources more efficiently, and allow CPU expressions to run in parallel using a thread pool.

Key changes:

  • Adds CPU bridge infrastructure with thread pool for parallel CPU expression evaluation
  • Modifies expression metadata system to support bridge decision-making and optimization
  • Updates method signatures from convertToGpu() to convertToGpuImpl() and introduces new wrapper logic

Reviewed Changes

Copilot reviewed 24 out of 24 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| RapidsMeta.scala | Core infrastructure for CPU bridge support with expression analysis and optimization logic |
| RapidsConf.scala | Configuration options for enabling CPU bridge and controlling which expressions can use it |
| GpuOverrides.scala | Updates expression metadata classes to use new convertToGpuImpl() method signature |
| GpuCpuBridgeExpression.scala | Main bridge expression implementation handling GPU-to-CPU data transfer and parallel evaluation |
| GpuCpuBridgeOptimizer.scala | Optimizer for making bridge vs GPU decisions and merging adjacent bridge expressions |
| GpuCpuBridgeThreadPool.scala | Thread pool implementation with priority queuing and task context propagation |
| various shim files | Method signature updates from convertToGpu() to convertToGpuBase() or convertToGpuImpl() |
| cpu_bridge_test.py | Integration tests for CPU bridge functionality |


Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
@revans2 revans2 changed the title Adds in support for a "CpuBridge" that lets us fall back to the CPU on a per-expression level instead of per-SparkPlan node level. Adds in support for a "CpuBridge" that lets us fall back to the CPU on a per-expression level instead of per-SparkPlan node level. [databricks] Sep 22, 2025
@revans2
Collaborator Author

revans2 commented Sep 22, 2025

build

@revans2
Collaborator Author

revans2 commented Sep 22, 2025

I did run into one unexpected issue when I tried to run test_bloom_filter_join_cpu_probe with this enabled by default. I am getting errors like:

java.lang.IllegalArgumentException: requirement failed: Subquery scalar-subquery#55, [id=#193] has not finished
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.sql.execution.ScalarSubquery.eval(subquery.scala:98)
	at org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain.bloomFilter$lzycompute(BloomFilterMightContain.scala:93)
	at org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain.bloomFilter(BloomFilterMightContain.scala:92)
	at org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain.eval(BloomFilterMightContain.scala:98)
	at org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection.apply(InterpretedUnsafeProjection.scala:81)
	at com.nvidia.spark.rapids.GpuCpuBridgeExpression.$anonfun$createCodegenEvaluationFunction$3(GpuCpuBridgeExpression.scala:359)
	at com.nvidia.spark.rapids.GpuCpuBridgeExpression.$anonfun$createCodegenEvaluationFunction$3$adapted(GpuCpuBridgeExpression.scala:358)

and also NPEs like:

Caused by: java.lang.NullPointerException
	at org.apache.spark.sql.catalyst.InternalRow$.$anonfun$getAccessor$16(InternalRow.scala:156)
	at org.apache.spark.sql.catalyst.InternalRow$.$anonfun$getAccessor$16$adapted(InternalRow.scala:155)
	at org.apache.spark.sql.catalyst.expressions.BoundReference.eval(BoundAttribute.scala:40)
	at org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain.bloomFilter$lzycompute(BloomFilterMightContain.scala:93)
	at org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain.bloomFilter(BloomFilterMightContain.scala:92)
	at org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain.eval(BloomFilterMightContain.scala:98)
	at org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection.apply(InterpretedUnsafeProjection.scala:81)
	at com.nvidia.spark.rapids.GpuCpuBridgeExpression.$anonfun$createCodegenEvaluationFunction$3(GpuCpuBridgeExpression.scala:359)
	at com.nvidia.spark.rapids.GpuCpuBridgeExpression.$anonfun$createCodegenEvaluationFunction$3$adapted(GpuCpuBridgeExpression.scala:358)

I am going to spend some time to try and understand this a bit more and see if I can fix it.

@revans2
Collaborator Author

revans2 commented Sep 23, 2025

I fixed the ScalarSubquery issue so I think this is ready to go.

@revans2
Collaborator Author

revans2 commented Sep 23, 2025

build

@revans2
Collaborator Author

revans2 commented Sep 23, 2025

build

@revans2
Collaborator Author

revans2 commented Oct 8, 2025

build

Collaborator

@abellina abellina left a comment

Need to spend more time in GpuCpuBridgeExpression.scala and optimizer, but I think this is looking good.

TEST_PARALLEL_OPTS=()
elif [[ ${TEST_PARALLEL} -gt ${MAX_PARALLEL} ]]; then
TEST_PARALLEL_OPTS=("-n" "$MAX_PARALLEL")
TEST_PARALLEL_OPTS=("-n" "$MAX_PARALLEL" "--dist=load" "--maxschedchunk=0")
Collaborator

a comment would be nice describing what --dist=load and --maxschedchunk=0 do

Collaborator Author

Sorry this is something that leaked in here by accident. This is for #13235 and I was running a lot of tests so I added this to speed things up.



@allow_non_gpu("ProjectExec", "Pmod")
@allow_non_gpu("ProjectExec", "Pmod", "BoundReference", "Literal", "PromotePrecision")
Collaborator

why do we need to allow these other operators non_gpu?

Collaborator Author

I could make it so that Literal and BoundReference are ignored automatically, because they are all over the place. PromotePrecision is needed because on older versions of Spark it will insert this to track metadata when it will cast a Decimal expression to a higher precision.

Collaborator Author

I should also explain that if we fall back to the CPU using the bridge we still want to throw an exception. So we then need to choose between adding an allow list of expressions, which this has done, or allowing the fallback if the Exec it is a part of is in the allow list. I picked the former because I felt it would let us be more accurate in our testing.

Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes))
val bound = Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes))
// Inject CPU bridge metrics if provided
metrics.foreach(m => bound.foreach(GpuMetric.injectMetrics(_, m)))
Collaborator

should we pass metrics to bindGpuReferences and inject them there? Might make the code more compact.

Collaborator Author

I debated this a lot. I was worried that it would make the code for bind references more complicated, especially the part where an expression can take over how it binds references, but I agree it is kind of ugly as it is today and prone to errors if we forget a place. I will try it out and see how complex it gets.

case class PrioritizedCpuBridgeTask[T](
task: Callable[T],
taskContext: TaskContext,
batchSize: Int,
Collaborator

what is batchSize used for?
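For context on how a wrapper of this shape is typically consumed: a task carrying an ordering key such as batchSize can be fed through a PriorityBlockingQueue so the pool dequeues work in priority order. The sketch below is self-contained and uses an assumed ordering (larger batches first); the real PrioritizedCpuBridgeTask may order by something else entirely.

```scala
import java.util.concurrent.{Callable, PriorityBlockingQueue}

// Hypothetical sketch of a prioritized task: comparable by batchSize so a
// priority queue can hand out larger batches first.
case class SketchTask[T](task: Callable[T], batchSize: Int)
    extends Comparable[SketchTask[T]] {
  // Reversed comparison: a larger batchSize sorts ahead of a smaller one.
  override def compareTo(other: SketchTask[T]): Int =
    Integer.compare(other.batchSize, this.batchSize)
}

val queue = new PriorityBlockingQueue[SketchTask[String]]()
Seq(10, 500, 42).foreach { n =>
  queue.put(SketchTask(new Callable[String] {
    override def call(): String = s"batch-$n"
  }, n))
}
// Drain in priority order: 500, then 42, then 10.
val order = Seq.fill(3)(queue.take().batchSize)
```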

nullSafe = false,
releaseSemaphore = false
)
val r = new NvtxRange("evaluateOnCPU", NvtxColor.BLUE)
Collaborator

nit, withResource
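For readers unfamiliar with it, withResource is the plugin's try/finally helper for AutoCloseable resources such as NvtxRange. A minimal self-contained version of the pattern looks like this (the Recorder class is just a toy resource for illustration):

```scala
// Minimal version of the withResource pattern: run a body against an
// AutoCloseable and guarantee close() is called even if the body throws.
def withResource[R <: AutoCloseable, T](resource: R)(body: R => T): T = {
  try body(resource) finally resource.close()
}

// Toy resource standing in for something like NvtxRange.
class Recorder extends AutoCloseable {
  var closed = false
  override def close(): Unit = closed = true
}

val rec = new Recorder
val result = withResource(rec) { _ => 21 * 2 }
// rec.closed is now true and result is 42
```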

@revans2
Collaborator Author

revans2 commented Oct 13, 2025

build

@revans2
Collaborator Author

revans2 commented Oct 13, 2025

build

@revans2
Collaborator Author

revans2 commented Oct 13, 2025

I have updated the code and did a bit of performance optimization. @abellina and @zpuller if you could take another look I would appreciate it. I think I have addressed all of the review comments so far.

@revans2
Collaborator Author

revans2 commented Oct 13, 2025

build

@revans2
Collaborator Author

revans2 commented Oct 13, 2025

build

@revans2
Collaborator Author

revans2 commented Oct 13, 2025

build

Collaborator

@zpuller zpuller left a comment

Still working through it but I had a couple comments for now

* from SparkPlan nodes. Use the public API that requires metrics instead.
*/
def bindGpuReferences[A <: Expression](
def bindGpuReferencesInternal[A <: Expression](
Collaborator

nit: can we/should we make this project private?

Collaborator Author

We don't want it to be private because there are a few cases, even though they are rare, where we want to call it.

Collaborator

I meant project private as in it would be callable from (most) RAPIDS code but not from spark itself. But anyway I don't have a strong opinion

override def prettyName: String = "gpu_cpu_bridge"

override def toString: String = {
val gpuInputsStr = if (gpuInputs.nonEmpty) {
Collaborator

nit: do we need to check if it's nonEmpty? if we took an empty Seq and did .mkString(", ") wouldn't it just be empty string?

Collaborator

similar below

Collaborator Author

Good point. Cursor does not always deal with these very well.
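The reviewer's point is easy to confirm: mkString on an empty collection already produces the empty string, because with no elements there are no separators to emit, so the nonEmpty guard adds nothing.

```scala
// mkString never needs an emptiness guard: with no elements there are no
// separators, so the result is simply "".
val empty: Seq[String] = Seq.empty
val joined = empty.mkString(", ")            // ""
val nonEmpty = Seq("a", "b").mkString(", ")  // "a, b"
```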

* Takes an iterator of rows and the expected row count, produces a complete
* GpuColumnVector result.
*/
private def createEvaluationFunction(
Collaborator

@zpuller zpuller Oct 13, 2025

Why does this and the below take a cpuExpression argument, given that cpuExpressions is a field of this class? I could see reasons to do this, but where I get confused is that the threadLocalProjection seems to be based off the field cpuExpression, not the arg

@revans2
Collaborator Author

revans2 commented Oct 14, 2025

build

@revans2
Collaborator Author

revans2 commented Oct 14, 2025

build

Collaborator

@zpuller zpuller left a comment

Posted one other minor question, but looking good

override def call(): T = {
// Register this thread with RmmSpark for memory tracking if we have a task context
if (taskContext != null) {
RmmSpark.currentThreadIsDedicatedToTask(taskContext.taskAttemptId())
Collaborator

Does it make sense to use TaskRegistryTracker here, or do we not want the same behavior eg. retries?

@abellina
Collaborator

what is the state of this? Is it something we should re-review?

@pxLi
Member

pxLi commented Nov 17, 2025

NOTE: release/25.12 has been created from main. Please retarget your PR to release/25.12 if it should be included in the release.
