Accelerating Spark with GPUs: NVIDIA RAPIDS Code Walkthrough

Intro

The RAPIDS ecosystem is an NVIDIA-backed collection of open-source software that builds on NVIDIA’s CUDA platform to provide higher-level, GPU-accelerated abstractions. A key example is spark-rapids, a Spark extension that enables GPU-based optimization during Spark query execution. In this article, we will reference the spark-rapids source code as we step through a sample plan optimization and zoom into the implementations of a basic filter operation on and off the GPU.
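
Before diving into the code, it may help to see what enabling the plugin looks like from a user's perspective. The sketch below is illustrative only; it assumes the rapids-4-spark jar is already on the driver and executor classpath, and exact configuration keys and versions vary by release, so consult the spark-rapids documentation for your deployment.

import org.apache.spark.sql.SparkSession

// Illustrative only: assumes the rapids-4-spark jar (and a GPU with CUDA drivers)
// is available to both the driver and the executors.
val spark = SparkSession.builder()
  .appName("rapids-demo")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")  // load the spark-rapids plugin
  .config("spark.rapids.sql.enabled", "true")             // enable GPU acceleration of SQL/DataFrame work
  .getOrCreate()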

Context

Beyond its full support for traditional data analysis via the SQL and DataFrame APIs, Spark has shipped with a machine learning library, MLlib, since its early releases. Today, users rely on this library to train models, often on hosted Spark services like Databricks and Amazon EMR. MLlib faces stiff competition from services oriented directly toward machine learning, like AWS SageMaker, and from data warehouse providers like Snowflake that are extending their offerings with machine learning primitives. Given the GPU-intensive nature of ML workloads, spark-rapids fills a critical need in the Spark ecosystem.

Development Environment

Working with the spark-rapids repo requires CUDA drivers and a compatible GPU. We can meet these requirements by allocating a g4dn.xlarge AWS EC2 instance running the free “Amazon Linux 2 AMI with NVIDIA TESLA GPU Driver” image. Then, we can issue the following commands to set up our environment for a VSCode Remote SSH session:

sudo yum install git
git clone https://github.com/NVIDIA/spark-rapids.git

# Note - important to use this version of Maven.
wget https://archive.apache.org/dist/maven/maven-3/3.6.3/binaries/apache-maven-3.6.3-bin.tar.gz
tar xvf apache-maven-3.6.3-bin.tar.gz
echo "export PATH=$PATH:~/apache-maven-3.6.3/bin" >> ~/.bashrc
source ~/.bashrc

# Note - important to use Java 11 with VSCode.
# See: https://github.com/NVIDIA/spark-rapids/issues/9542.
sudo yum install java-11-amazon-corretto.x86_64
sudo yum install jq

cd spark-rapids
git checkout branch-23.04
mvn -U clean install -DskipTests -Dmaven.javadoc.skip

# Note - important to use spark-rapids's generated bloop config instead of the VSCode plugin's bloop.
# See: https://github.com/NVIDIA/spark-rapids/issues/9542.
./build/buildall --generate-bloop --profile=311
ln -s .bloop-spark311 .bloop

Once VSCode Remote SSH is configured, finalize the setup by installing the Scala Metals plugin. Note that it is important to decline the dialog prompting us to import the Maven build - this forces Metals to use the bloop configuration we generated above.

Plan Modification

Spark exposes SparkSessionExtensions as the entry point for integrating new rules. Callers provide a function that accepts a SparkSessionExtensions instance and calls its inject* methods to register custom rules during session initialization. spark-rapids calls several SparkSessionExtensions methods in its SQLExecPlugin:

class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
  ...
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectColumnar(columnarOverrides)
    extensions.injectQueryStagePrepRule(queryStagePrepOverrides)
    extensions.injectPlannerStrategy(_ => strategyRules)
  }
  ...
}

Each of these inject calls represents a different category of rules. The rule categories used by Spark are:

  1. Columnar Rules - Rules for replacing operators in the physical plan with columnar versions, and for replacing the row-to-columnar and columnar-to-row transition operators with custom versions.
  2. Query Stage Prep Rules - Rules applied during the adaptive query execution (AQE) preparation phase.
  3. Strategy Rules - Rules informing Spark how to translate the logical plan into a physical plan.

The bulk of spark-rapids’s plan modification logic lives in the columnar rules, so we will focus on this rule category.
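
Before stepping into spark-rapids's columnar rules, it may help to see the injection mechanism in isolation. The following is a minimal, hypothetical extension (not part of spark-rapids) that registers a do-nothing ColumnarRule through the same injectColumnar hook, using the SparkSession builder's withExtensions API:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}

// A trivial ColumnarRule that only logs the plan at both hook points.
case class LoggingColumnarRule() extends ColumnarRule {
  override def preColumnarTransitions: Rule[SparkPlan] = new Rule[SparkPlan] {
    override def apply(plan: SparkPlan): SparkPlan = {
      println(s"plan before transition insertion:\n$plan")
      plan
    }
  }
  override def postColumnarTransitions: Rule[SparkPlan] = new Rule[SparkPlan] {
    override def apply(plan: SparkPlan): SparkPlan = {
      println(s"plan after transition insertion:\n$plan")
      plan
    }
  }
}

// Registration mirrors what SQLExecPlugin does, here via the builder API.
val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions(ext => ext.injectColumnar(_ => LoggingColumnarRule()))
  .getOrCreate()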

Stepping into the columnar override registration code, we see that the ColumnarOverrideRules class delegates to two rule suites: GpuOverrides and GpuTransitionOverrides.

case class ColumnarOverrideRules() extends ColumnarRule with Logging {
  lazy val overrides: Rule[SparkPlan] = GpuOverrides()
  lazy val overrideTransitions: Rule[SparkPlan] = new GpuTransitionOverrides()

  override def preColumnarTransitions : Rule[SparkPlan] = overrides

  override def postColumnarTransitions: Rule[SparkPlan] = overrideTransitions
}

Each of these override classes exposes an apply method that the Spark framework calls during the corresponding phase of plan preparation. We can see the exact mechanism used to apply these columnar rules to the query plan by looking at the relevant section of the Spark codebase (the ApplyColumnarRulesAndInsertTransitions rule):

def apply(plan: SparkPlan): SparkPlan = {
    var preInsertPlan: SparkPlan = plan
    columnarRules.foreach((r : ColumnarRule) =>
      preInsertPlan = r.preColumnarTransitions(preInsertPlan))
    var postInsertPlan = insertTransitions(preInsertPlan)
    columnarRules.reverse.foreach((r : ColumnarRule) =>
      postInsertPlan = r.postColumnarTransitions(postInsertPlan))
    postInsertPlan
  }

We see that each rule's preColumnarTransitions is applied first. Then, in insertTransitions, the framework automatically inserts row-to-columnar and columnar-to-row transition operators wherever the formats required by adjacent operators differ. Finally, each rule gets the opportunity to apply its postColumnarTransitions to the plan with transitions in place. In this article, we will only step into the code used for preColumnarTransitions. If you are interested in the exact mechanism spark-rapids uses in postColumnarTransitions, see the source code here.

Stepping into GpuOverrides, we see that the bulk of the logic is applied in applyOverrides:

private def applyOverrides(plan: SparkPlan, conf: RapidsConf): SparkPlan = {
    val wrap = GpuOverrides.wrapAndTagPlan(plan, conf)
    val detectDeltaCheckpoint = conf.isDetectDeltaCheckpointQueries
    if (conf.isDetectDeltaLogQueries && isDeltaLakeMetadataQuery(plan, detectDeltaCheckpoint)) {
      wrap.entirePlanWillNotWork("Delta Lake metadata queries are not efficient on GPU")
    }
    val reasonsToNotReplaceEntirePlan = wrap.getReasonsNotToReplaceEntirePlan
    if (conf.allowDisableEntirePlan && reasonsToNotReplaceEntirePlan.nonEmpty) {
      if (conf.shouldExplain) {
        logWarning("Can't replace any part of this plan due to: " +
            s"${reasonsToNotReplaceEntirePlan.mkString(",")}")
      }
      plan
    } else {
      val optimizations = GpuOverrides.getOptimizations(wrap, conf)
      wrap.runAfterTagRules()
      if (conf.shouldExplain) {
        wrap.tagForExplain()
        val explain = wrap.explain(conf.shouldExplainAll)
        if (explain.nonEmpty) {
          logWarning(s"\n$explain")
          if (conf.optimizerShouldExplainAll && optimizations.nonEmpty) {
            logWarning(s"Cost-based optimizations applied:\n${optimizations.mkString("\n")}")
          }
        }
      }
      GpuOverrides.doConvertPlan(wrap, conf, optimizations)
    }
  }

These three functions drive the conversion:

  • wrapAndTagPlan - Takes a raw SparkPlan and converts it into a “wrapped” version containing GPU-related metadata for each node. The main metadata classes live in the RapidsMeta file.
  • getOptimizations - Uses a cost-based optimizer to avoid situations where dispatching to the GPU would result in worse overall query performance.
  • doConvertPlan - Converts the wrapped plan back into a SparkPlan, now containing GPU operators where applicable.
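
The shouldExplain checks in applyOverrides above are driven by a user-facing setting. Here is a rough sketch of turning the tagging output on, assuming the spark.rapids.sql.explain config described in the project documentation:

// Rough sketch: surface spark-rapids's tagging decisions in the driver log.
// Per the project docs, spark.rapids.sql.explain accepts NONE, NOT_ON_GPU, or ALL.
spark.conf.set("spark.rapids.sql.explain", "ALL")

// Subsequent queries will log, for each operator and expression, whether it will
// run on the GPU and, if not, the reason why.
spark.range(100).filter("id > 90").collect()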

Now let’s take a look at an entire plan as it moves through the generation process. We’ll use one of the test cases as an example. The FilterExprSuite provides some basic plans to examine. For instance, this test applies a few simple filters to a DataFrame:

testSparkResultsAreEqual("filter with decimal columns", mixedDf(_), repart = 0) { df =>
    df.filter(col("ints") > 90)
      .filter(col("decimals").isNotNull)
      .select("ints", "strings", "decimals")
  }

Inspecting the raw Spark plan, we see that this query consists of a scan (load), a filter operation, then a project (choosing our columns):

ProjectExec@15413 "Project [ints#118, strings#121, decimals#122]
+- Filter ((isnotnull(ints#118) AND (ints#118 > 90)) AND isnotnull(decimals#122))
   +- Scan ExistingRDD[ints#118,longs#119L,doubles#120,strings#121,decimals#122]
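
The plan snapshots in this walkthrough were captured in a debugger, but a similar view is available through Spark's public APIs. Here is a minimal sketch, using a hypothetical DataFrame df in place of the test suite's mixedDf:

import org.apache.spark.sql.functions.col

// Hypothetical stand-in for the suite's mixedDf; any DataFrame with these columns works.
val result = df.filter(col("ints") > 90)
  .filter(col("decimals").isNotNull)
  .select("ints", "strings", "decimals")

println(result.queryExecution.sparkPlan)    // physical plan before the columnar rules run
println(result.queryExecution.executedPlan) // final plan, with GPU operators when the plugin is active
result.explain(true)                        // parsed, analyzed, optimized, and physical plans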

After this plan goes through the wrapAndTagPlan phase of spark-rapids’s modifications, we see that the plan is enriched with additional information indicating which operations will run on the GPU:

GpuProjectExecMeta@15444 "*Exec <ProjectExec> will run on GPU
  *Exec <FilterExec> will run on GPU
    *Expression <And> ((isnotnull(ints#118) AND (ints#118 > 90)) AND isnotnull(decimals#122)) will run on GPU
      *Expression <And> (isnotnull(ints#118) AND (ints#118 > 90)) will run on GPU
        *Expression <IsNotNull> isnotnull(ints#118) will run on GPU
        *Expression <GreaterThan> (ints#118 > 90) will run on GPU
      *Expression <IsNotNull> isnotnull(decimals#122) will run on GPU
    ! <RDDScanExec> cannot run on GPU because GPU does not currently support the operator class org.apache.spark.sql.execution.RDDScanExec

The call to doConvertPlan then turns the wrapped plan back into a plain SparkPlan, with the operators replaced by their GPU equivalents:

GpuProjectExec@15543 "GpuProject [ints#118, strings#121, decimals#122], true
+- GpuFilter ((gpuisnotnull(ints#118) AND (ints#118 > 90)) AND gpuisnotnull(decimals#122)), true
   +- Scan ExistingRDD[ints#118,longs#119L,doubles#120,strings#121,decimals#122]

During Spark’s native transition-insertion phase (insertTransitions), RowToColumnar and ColumnarToRow operators are added automatically so that the data reaching our GpuFilter is columnar and the final output is returned in row format:

ColumnarToRowExec@15573 "ColumnarToRow
+- GpuProject [ints#118, strings#121, decimals#122], true
   +- GpuFilter ((gpuisnotnull(ints#118) AND (ints#118 > 90)) AND gpuisnotnull(decimals#122)), true
      +- RowToColumnar
         +- Scan ExistingRDD[ints#118,longs#119L,doubles#120,strings#121,decimals#122]

Finally, the postColumnarTransitions rules replace the Spark-native transition operators with optimized spark-rapids versions, and insert a GpuCoalesceBatches operator to control the size of the batches flowing between GPU operators:

GpuColumnarToRowExec@15599 "GpuColumnarToRow false
+- GpuProject [ints#118, strings#121, decimals#122], true
   +- GpuCoalesceBatches targetsize(2147483647)
      +- GpuFilter ((gpuisnotnull(ints#118) AND (ints#118 > 90)) AND gpuisnotnull(decimals#122)), true
         +- GpuRowToColumnar targetsize(2147483647)
            +- Scan ExistingRDD[ints#118,longs#119L,doubles#120,strings#121,decimals#122]

In the end, spark-rapids has constructed a plan that is logically equivalent to the original raw Spark plan, but with GPU-optimized operators substituted wherever possible for better performance. Now, let’s dive into the implementation of the GPU filter operator to understand how it differs from the Spark-native version.

Filter Implementation - Native vs. GPU

The native Spark implementation of the filter operator lives in FilterExec. This operator follows Spark’s default pattern of whole-stage code generation, producing tightly pipelined execution with minimal virtual function call overhead. In practice, this means FilterExec implements the CodegenSupport interface by providing a doConsume function that returns the generated Java code for the filter operation:

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    val numOutput = metricTerm(ctx, "numOutputRows")

    /**
     * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
     */
    def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
      val bound = BindReferences.bindReference(c, attrs)
      val evaluated = evaluateRequiredVariables(child.output, in, c.references)

      // Generate the code for the predicate.
      val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
      val nullCheck = if (bound.nullable) {
        s"${ev.isNull} || "
      } else {
        s""
      }

      s"""
         |$evaluated
         |${ev.code}
         |if (${nullCheck}!${ev.value}) continue;
       """.stripMargin
    }

    // To generate the predicates we will follow this algorithm.
    // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
    // as necessary. For each of both attributes, if there is an IsNotNull predicate we will
    // generate that check *before* the predicate. After all of these predicates, we will generate
    // the remaining IsNotNull checks that were not part of other predicates.
    // This has the property of not doing redundant IsNotNull checks and taking better advantage of
    // short-circuiting, not loading attributes until they are needed.
    // This is very perf sensitive.
    // TODO: revisit this. We can consider reordering predicates as well.
    val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
    val extraIsNotNullAttrs = mutable.Set[Attribute]()
    val generated = otherPreds.map { c =>
      val nullChecks = c.references.map { r =>
        val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
        if (idx != -1 && !generatedIsNotNullChecks(idx)) {
          generatedIsNotNullChecks(idx) = true
          // Use the child's output. The nullability is what the child produced.
          genPredicate(notNullPreds(idx), input, child.output)
        } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
          extraIsNotNullAttrs += r
          genPredicate(IsNotNull(r), input, child.output)
        } else {
          ""
        }
      }.mkString("\n").trim

      // Here we use *this* operator's output with this output's nullability since we already
      // enforced them with the IsNotNull checks above.
      s"""
         |$nullChecks
         |${genPredicate(c, input, output)}
       """.stripMargin.trim
    }.mkString("\n")

    val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
      if (!generatedIsNotNullChecks(idx)) {
        genPredicate(c, input, child.output)
      } else {
        ""
      }
    }.mkString("\n")

    // Reset the isNull to false for the not-null columns, then the followed operators could
    // generate better code (remove dead branches).
    val resultVars = input.zipWithIndex.map { case (ev, i) =>
      if (notNullAttributes.contains(child.output(i).exprId)) {
        ev.isNull = FalseLiteral
      }
      ev
    }

    // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
    s"""
       |do {
       |  $generated
       |  $nullChecks
       |  $numOutput.add(1);
       |  ${consume(ctx, resultVars)}
       |} while(false);
     """.stripMargin
  }

In contrast, GpuFilterExec simply applies a GPU filter operation to each columnar batch it processes:

override def internalDoExecuteColumnar(): RDD[ColumnarBatch] = {
    val numOutputRows = gpuLongMetric(NUM_OUTPUT_ROWS)
    val numOutputBatches = gpuLongMetric(NUM_OUTPUT_BATCHES)
    val opTime = gpuLongMetric(OP_TIME)
    val boundCondition = GpuBindReferences.bindReference(condition, child.output)
    val rdd = child.executeColumnar()
    rdd.map { batch =>
      GpuFilter.filterAndClose(batch, boundCondition, numOutputRows, numOutputBatches, opTime)
    }
  }

Stepping in a bit further, we can see that the filter operation is implemented using primitives from ai.rapids.cudf:

private def doFilter(checkedFilterMask: Option[cudf.ColumnVector],
      cb: ColumnarBatch): ColumnarBatch = {
    checkedFilterMask.map { checkedFilterMask =>
      withResource(checkedFilterMask) { checkedFilterMask =>
        val colTypes = GpuColumnVector.extractTypes(cb)
        withResource(GpuColumnVector.from(cb)) { tbl =>
          withResource(tbl.filter(checkedFilterMask)) { filteredData =>
            GpuColumnVector.from(filteredData, colTypes)
          }
        }
      }
    }.getOrElse {
      // Nothing to filter so it is a NOOP
      GpuColumnVector.incRefCounts(cb)
    }
  }
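
To make the cuDF layer more concrete, here is a small standalone sketch (not from spark-rapids) that filters a cuDF Table with a boolean mask, mirroring the tbl.filter call in doFilter above. It assumes a CUDA-capable GPU and the cudf native libraries are available on the machine:

import ai.rapids.cudf.{ColumnVector, Table}

// Build a tiny table and a boolean mask, then keep only the rows where the mask is true.
val ints = ColumnVector.fromInts(85, 92, 101)
val mask = ColumnVector.fromBooleans(false, true, true)
try {
  val table = new Table(ints)
  try {
    val filtered = table.filter(mask)
    try {
      println(filtered.getRowCount) // 2
    } finally {
      filtered.close()
    }
  } finally {
    table.close()
  }
} finally {
  ints.close()
  mask.close()
}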

Conclusion

The basic mechanisms behind spark-rapids’s other plan transformations are similar to the example walked through above, but the intricacy of their implementations and application cannot be overstated. spark-rapids is a roughly 200,000-line Scala repository supporting a massive set of GPU transformations for Spark, and it has been under active development since 2019. This article provides a basic introduction to some of the mechanisms; truly grasping the codebase would require a great deal of careful study and contribution.