Unit Testing Spark Jobs

I went down the rabbit hole of Spark unit testing. Since it’s largely uncharted territory, you’ll find a lot of conflicting opinions, so I decided to hack my way through it. This post summarizes my thoughts and findings, and walks through the process I followed to solve the problems I ran into.

To keep things concrete, I use the same example throughout the article. The data represents the number of flights from origin countries to destination countries, and has the following structure:

| Destination | Origin  | Count |
|-------------|---------|-------|
| Morocco     | Spain   | 3     |
| Morocco     | Egypt   | 5     |
| France      | Germany | 10    |
| ...         | ...     | ...   |

The structure is borrowed from the flight data in the code repository of Spark: The Definitive Guide. You can find the code here.
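If you prefer typed rows, the table can be modeled as a Scala case class. This is a minimal sketch: `Flight` and `SampleFlights` are my own hypothetical names, not part of the original project, and `count` is stored as a `Long`.

```scala
// Hypothetical model of one row of the flights table.
final case class Flight(destination: String, origin: String, count: Long)

object SampleFlights {
  // The three rows shown in the table above.
  val sampleFlights: Seq[Flight] = Seq(
    Flight("Morocco", "Spain", 3L),
    Flight("Morocco", "Egypt", 5L),
    Flight("France", "Germany", 10L)
  )
}
```

With `spark.implicits._` in scope, `SampleFlights.sampleFlights.toDF()` turns this into a DataFrame whose columns are named after the case class fields.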

Before getting into the tests themselves, let’s go over the tools, the project’s structure, and the reasoning behind a testable Spark job. First things first, the structure of the project:

+-- project
|   +-- build.properties
|   +-- plugins.sbt
+-- src
|   +-- main
|   +-- test
+-- .gitignore
+-- .scalafix.conf
+-- .scalafmt.conf
+-- LICENSE
+-- build.sbt
+-- README.md

Dependencies

Starting with Spark dependencies:

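// build.sbt
// sparkVersion must be defined in the build, e.g. val sparkVersion = "2.4.8" for a Scala 2.11 project
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % sparkVersion,
  "org.apache.spark" %% "spark-sql" % sparkVersion,
  "org.apache.spark" %% "spark-hive" % sparkVersion % "provided"
)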

Loading configuration

For configuration loading, you can use PureConfig. Add the following to your build.sbt:

// build.sbt
// https://github.com/pureconfig/pureconfig
libraryDependencies += "com.github.pureconfig" %% "pureconfig" % "0.14.0"

Then create a src/main/resources/reference.conf file:

# reference.conf
settings {
  database = "dev"
}

To be able to load the configuration, you need to create a Settings case class:

// src/main/scala/me/ayoublabiad/common/Settings.scala
package me.ayoublabiad.common

case class Settings(
  database: String
)

Then you can load the configuration wherever you need it:

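import me.ayoublabiad.common.Settings
import pureconfig.ConfigSource
import pureconfig.generic.auto._ // derives the ConfigReader for the Settings case class

// Read the `settings` block of reference.conf into a Settings instance
val settings: Settings = ConfigSource.default.at("settings").loadOrThrow[Settings]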

Finally, pass the settings object to your Spark job:

FlightsJob.process(spark, settings)
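In tests, you may not want to depend on whatever reference.conf happens to be on the classpath. PureConfig can also read a configuration from a string, which makes it easy to build a `Settings` instance for a test. A minimal sketch, assuming PureConfig 0.14.x with generic derivation; the `TestSettings` object and the `test` database name are illustrative:

```scala
import pureconfig.ConfigSource
import pureconfig.generic.auto._ // derives a ConfigReader for case classes

// Same shape as the Settings case class above.
final case class Settings(database: String)

object TestSettings {
  // Parse an inline HOCON string instead of reading reference.conf,
  // then descend into the `settings` block before decoding it.
  val settings: Settings =
    ConfigSource
      .string("""settings { database = "test" }""")
      .at("settings")
      .loadOrThrow[Settings]
}
```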

Testing tools

For testing, you will be using:

// build.sbt
// https://mvnrepository.com/artifact/com.fasterxml.jackson.module/jackson-module-scala
libraryDependencies += "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.13.4"

// https://mvnrepository.com/artifact/org.scalatest/scalatest-flatspec
libraryDependencies += "org.scalatest" %% "scalatest-flatspec" % "3.2.14" % "test"

// https://github.com/MrPowers/spark-fast-tests
libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "0.23.0" % "test"

ScalaTest for writing unit tests, spark-fast-tests for comparing DataFrames, and Jackson for JSON serialization.

Let’s start with ScalaTest. It’s a testing framework for Scala, which makes it a natural fit for Spark, since Spark itself is written in Scala.
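The tests in this article use the FlatSpec style from the scalatest-flatspec artifact: a subject string, `should`, a behavior string, and a body with assertions. A minimal sketch that doesn’t need Spark at all; the spec name is illustrative:

```scala
import org.scalatest.flatspec.AnyFlatSpec

// "subject" should "behavior" in { assertions }
class FlightCountSpec extends AnyFlatSpec {
  "Summing flight counts" should "add up every row" in {
    val counts = Seq(3L, 5L, 10L)
    assert(counts.sum == 18L)
  }
}
```

Running `sbt test` picks up every spec under src/test.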

Plugins

To make your life easier, you can use the following plugins:

// project/plugins.sbt

addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.9.24")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.0")

Scalafmt

Scalafmt handles code formatting. To enable it in IntelliJ IDEA, open the settings with CTRL + ALT + S, go to Editor > Code Style > Scala, and change the Formatter from IntelliJ to Scalafmt. Then add the following configuration file to the root directory of your project:

# .scalafmt.conf
version = "2.7.5"
# Check https://github.com/monix/monix/blob/series/3.x/.scalafmt.conf for full configuration

The shortcut to format your code in IntelliJ IDEA is CTRL + ALT + L.

Scalafix

Scalafix is a linting and refactoring tool. Add the following configuration to the root directory of your project:

# .scalafix.conf
# https://spotify.github.io/scio/dev/Style-Guide.html
rules = [
  RemoveUnused,
  LeakingImplicitClassVal
]

Add the following to your build.sbt:

// build.sbt
// https://scalacenter.github.io/scalafix/docs/users/installation.html

inThisBuild(
  List(
    scalaVersion := "2.11.12",
    semanticdbEnabled := true, // enable SemanticDB
    semanticdbVersion := scalafixSemanticdb.revision // use Scalafix compatible version
  )
)

scalacOptions ++= List(
  "-Ywarn-unused",
  "-Ywarn-unused-import" // on Scala 2.11, unused imports are reported by this separate flag
)

To run the rules, execute this command:

sbt scalafixAll

Project structure

This is where teams tend to diverge: each one adapts the structure to its own needs. Here is the layout used in this article:

scala
+-- me.ayoublabiad
|   +-- common
|   |   +-- SparkSessionWrapper.scala
|   +-- io
|   |   +-- Read.scala
|   |   +-- Write.scala
|   +-- job
|   |   +-- flights
|   |   |   +-- FlightsTransformation.scala
|   |   |   +-- FlightsJob.scala
|   |   +-- Job.scala
|   +-- Main.scala

We will spend the rest of the article on the FlightsJob and FlightsTransformation classes.
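To give an idea of how the two classes can split the work, here is a minimal sketch: FlightsJob only does I/O, while FlightsTransformation holds pure `DataFrame => DataFrame` functions that can be tested with in-memory data. The `flights` and `destinations_total` table names are hypothetical, the I/O is inlined here although in the real project it would presumably live in Read and Write, and `Settings` mirrors the case class from the configuration section:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.sum

final case class Settings(database: String)

// Pure logic: no I/O, so it can be unit tested with an in-memory DataFrame.
object FlightsTransformation {
  def getDestinationsWithTotalCount(flights: DataFrame): DataFrame =
    flights
      .groupBy("destination")
      .agg(sum("count").alias("total_count"))
}

// Thin orchestration: read, transform, write. Table names are hypothetical.
object FlightsJob {
  def process(spark: SparkSession, settings: Settings): Unit = {
    val flights: DataFrame = spark.table(s"${settings.database}.flights")
    val totals: DataFrame = FlightsTransformation.getDestinationsWithTotalCount(flights)
    totals.write.mode("overwrite").saveAsTable(s"${settings.database}.destinations_total")
  }
}
```

Because `process` only wires I/O around the transformation, the tests can target `getDestinationsWithTotalCount` directly.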

A case for testing Spark jobs

A Spark job is composed of three phases:

The three phases of a Spark job

You read your data, transform it, and write it. You aren’t trying to test Spark itself; you’re testing your transformations, to answer one question: “Does my transformation work as expected?”

Let’s make this concrete. Based on the flights table described before, you want the total number of flights per destination. You can write the following transformation:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.sum

def getDestinationsWithTotalCount(flights: DataFrame): DataFrame =
  flights
    .groupBy("destination")
    .agg(sum("count").alias("total_count"))
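For a table this small, you can work out the expected result by hand: morocco gets 3 + 5 = 8 and france gets 10. The same aggregation written with plain Scala collections gives an independent check on the Spark version. A sketch, with the sample rows repeated as tuples; `ExpectedTotals` is an illustrative name:

```scala
object ExpectedTotals {
  // (destination, origin, count) rows, as in the sample table.
  val flights: Seq[(String, String, Long)] = Seq(
    ("morocco", "spain", 3L),
    ("morocco", "egypt", 5L),
    ("france", "germany", 10L)
  )

  // Same aggregation as getDestinationsWithTotalCount, without Spark:
  // group the rows by destination and sum their counts.
  val totalsByDestination: Map[String, Long] =
    flights
      .groupBy { case (destination, _, _) => destination }
      .map { case (destination, rows) => destination -> rows.map(_._3).sum }
}
```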

If the origin column must also appear in the result, you can use a window instead of collapsing the rows:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

def addTotalCountColumn(flights: DataFrame): DataFrame = {
  // One window per destination: every row keeps its origin and gets the destination total
  val winExpDestination = Window.partitionBy("destination")

  flights
    .withColumn("total_count", sum("count").over(winExpDestination))
}

Is this function worth testing? Probably not, if it stays this simple and is only used in one job.

While this is a simple example, real transformations are often more complex, and they are often reused in other jobs. In both cases, you want to be sure they work as expected. Raymond Hettinger’s talk The Mental Game of Python is relevant here: we can only keep a limited amount of information in our heads, so we need to be able to trust our code.

Although the window implementation is correct, it may not be the most efficient. You can rewrite it with groupBy and sum, keeping the origins as a list:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{collect_list, sum}

def getDestinationsWithTotalCount(flights: DataFrame): DataFrame =
  flights
    .groupBy("destination")
    .agg(
      collect_list("origin").alias("origins"),
      sum("count").alias("total_count"))

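Or is it? Whether it actually runs faster depends on your data, but you get the idea: with a test in place, you can try both versions and refactor with confidence.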
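This is where a test pays off: whichever version you keep, the total per destination must not change. A sketch of such a check, comparing a window version (summing `count` over the window) with the groupBy version on the sample rows; `RefactoringCheck` and its methods are illustrative names:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_list, sum}

object RefactoringCheck {
  lazy val spark: SparkSession =
    SparkSession.builder().master("local[1]").appName("RefactoringCheck").getOrCreate()

  // Window version: one output row per input row.
  def withWindow(flights: DataFrame): DataFrame =
    flights.withColumn("total_count", sum("count").over(Window.partitionBy("destination")))

  // groupBy version: one output row per destination, origins collected into a list.
  def withGroupBy(flights: DataFrame): DataFrame =
    flights
      .groupBy("destination")
      .agg(collect_list("origin").alias("origins"), sum("count").alias("total_count"))

  // The property both versions must share: the same (destination, total_count) pairs.
  def totals(df: DataFrame): Set[(String, Long)] =
    df.select("destination", "total_count")
      .distinct()
      .collect()
      .map(row => (row.getString(0), row.getLong(1)))
      .toSet

  def sample(): DataFrame = {
    import spark.implicits._
    Seq(("morocco", "spain", 3L), ("morocco", "egypt", 5L), ("france", "germany", 10L))
      .toDF("destination", "origin", "count")
  }
}
```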

Another example is reading a single value, such as a threshold, from a column of a DataFrame. The tests for this function could look like this:

"getValueFromDataFrame" should "return the value if it exists" in {
  val input: DataFrame = createDataFrame(
    Map(
      "minFlightsThreshold" -> 4
    )
  )

  val threshold: Long = getValueFromDataFrame(input, "minFlightsThreshold")

  assert(threshold == 4)
}

"getValueFromDataFrame" should "throw an Exception if the threshold doesn't exist" in {
  val input: DataFrame = Seq.empty[Long].toDF("minFlightsThreshold")

  val caught =
    intercept[Exception] {
      getValueFromDataFrame(input, "minFlightsThreshold")
    }

  assert(caught.getMessage == "Threshold not found.")
}

The implementation looks something like this:

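import org.apache.spark.sql.DataFrame
import scala.util.Try

// Any failure (missing column, empty DataFrame, wrong type) is reported as a missing threshold
def getValueFromDataFrame(dataFrame: DataFrame, columnName: String): Long =
  Try(dataFrame.select(columnName).first().getLong(0))
    .getOrElse(throw new Exception("Threshold not found."))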

This function can go wrong in several ways: you can forget to handle a missing column, an empty DataFrame, or a column that isn’t of type Long.
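Each of these failure modes can get its own test case. A sketch in the same FlatSpec style, with the implementation copied in so the block compiles on its own; the spec name and the `otherColumn` column are illustrative:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.scalatest.flatspec.AnyFlatSpec

import scala.util.Try

class GetValueFromDataFrameSpec extends AnyFlatSpec {
  lazy val spark: SparkSession =
    SparkSession.builder().master("local[1]").appName("GetValueSpec").getOrCreate()

  import spark.implicits._

  // Same logic as getValueFromDataFrame above.
  def getValueFromDataFrame(dataFrame: DataFrame, columnName: String): Long =
    Try(dataFrame.select(columnName).first().getLong(0))
      .getOrElse(throw new Exception("Threshold not found."))

  "getValueFromDataFrame" should "fail when the column doesn't exist" in {
    val input = Seq(4L).toDF("otherColumn")
    val caught = intercept[Exception](getValueFromDataFrame(input, "minFlightsThreshold"))
    assert(caught.getMessage == "Threshold not found.")
  }

  it should "fail when the column isn't a Long" in {
    val input = Seq("four").toDF("minFlightsThreshold")
    val caught = intercept[Exception](getValueFromDataFrame(input, "minFlightsThreshold"))
    assert(caught.getMessage == "Threshold not found.")
  }
}
```

Because the implementation wraps everything in Try, every failure surfaces with the same message; telling them apart would be a change to the function, not just to the tests.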

If we agree that Spark jobs should be tested, how do we go about it?

How to test Spark jobs?

First, let’s create a Spark session:

// https://github.com/MrPowers/spark-fast-tests/blob/master/src/test/scala/com/github/mrpowers/spark/fast/tests/SparkSessionTestWrapper.scala
import org.apache.spark.sql.SparkSession

trait SparkSessionTestWrapper {
  lazy val spark: SparkSession =
    SparkSession
      .builder()
      .master("local[1]") // run locally with a single worker thread
      .appName("SparkTestingSession")
      .config("spark.sql.shuffle.partitions", "1") // one partition per shuffle instead of the default 200
      .getOrCreate()
}

To create a DataFrame in a test, you can use one of the following methods:

Using .toDF

After importing spark.implicits._, you can call .toDF on a Seq of tuples:

val input: DataFrame = Seq(
      ("morocco", "spain", 3),
      ("morocco", "egypt", 5),
      ("france", "germany", 10)
    ).toDF("destination", "origin", "count")

This method is convenient and fast, but it has limitations. Null values require Option types, struct columns require case classes, and array columns require Seq fields, so the test data quickly stops looking like the data it represents. It also loses readability as soon as you have more than three columns.

We need a more flexible and readable way to create DataFrames. What’s the most intuitive way of representing a key-value pair? A Map! Let’s create a Map and convert it to a DataFrame:

From JSON to DataFrame

The idea is to represent each row as a Map from column names to values. It’s verbose, but readable. The following function serializes the Maps to a JSON array and lets Spark infer the schema from it:

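import org.apache.spark.sql.DataFrame
import org.json4s.DefaultFormats
import org.json4s.jackson.Json
import spark.implicits._ // assumes a SparkSession named spark is in scope, e.g. from SparkSessionTestWrapper

// Serialize all rows into one JSON array; Spark reads each element as a row
def createDataFrame(data: Map[String, Any]*): DataFrame =
  spark.read.json(Seq(Json(DefaultFormats).write(data)).toDS)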

Then to create a DataFrame:

val input: DataFrame = createDataFrame(
      Map(
        "destination" -> "morocco",
        "origin" -> "spain",
        "count" -> 3
      ),
      Map(
        "destination" -> "morocco",
        "origin" -> "egypt",
        "count" -> 5
      ),
      Map(
        "destination" -> "france",
        "origin" -> "germany",
        "count" -> 10
      )
    )

This technique isn’t free: serializing to JSON and inferring the schema makes it slower than .toDF. But it’s more flexible and readable.
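To see the flexibility in practice, here is the same helper fed with values that are awkward with .toDF: a nested Map, which Spark reads as a struct column, a Seq, which becomes an array column, and a null. A sketch; the `route` and `stops` fields are hypothetical extensions of the flights example. Note that Spark’s JSON schema inference reads whole numbers as LongType, which is worth keeping in mind when you build expected DataFrames:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.json4s.DefaultFormats
import org.json4s.jackson.Json

object FlexibleRows {
  lazy val spark: SparkSession =
    SparkSession.builder().master("local[1]").appName("FlexibleRows").getOrCreate()

  // Same helper as above: serialize the rows to one JSON array, let Spark infer the schema.
  def createDataFrame(data: Map[String, Any]*): DataFrame = {
    import spark.implicits._
    spark.read.json(Seq(Json(DefaultFormats).write(data)).toDS)
  }

  // route -> struct, stops -> array of strings, count -> nullable long.
  lazy val input: DataFrame = createDataFrame(
    Map(
      "route" -> Map("origin" -> "spain", "destination" -> "morocco"),
      "stops" -> Seq.empty[String],
      "count" -> 3
    ),
    Map(
      "route" -> Map("origin" -> "egypt", "destination" -> "morocco"),
      "stops" -> Seq("tunisia"),
      "count" -> null
    )
  )
}
```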

Loading a CSV file

Loading a CSV file in each test isn’t a good idea: it slows the tests down and hides the test data from whoever reads the test. If you do need a CSV fixture, load it once and reuse it across tests.
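A sketch of the load-once idea: read the CSV lazily and cache it, so every test that needs the fixture reuses the same DataFrame. The `FlightsFixture` class is hypothetical, and passing the path in is my own choice; in a real project it would typically point at a file under src/test/resources. For the reuse to hold across spec classes, share a single instance, for example from an object every spec refers to:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical shared fixture: the CSV is read and cached on first access only.
class FlightsFixture(spark: SparkSession, path: String) {
  lazy val flights: DataFrame =
    spark.read
      .option("header", "true")      // the first line holds the column names
      .option("inferSchema", "true") // costs an extra pass over the file
      .csv(path)
      .cache()
}
```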

A full test

Now a test should look something like this:


val input: DataFrame = createDataFrame(
  Map(
    "destination" -> "morocco",
    "origin" -> "spain",
    "count" -> 3
  ),
  Map(
    "destination" -> "morocco",
    "origin" -> "egypt",
    "count" -> 5
  ),
  Map(
    "destination" -> "france",
    "origin" -> "germany",
    "count" -> 10
  )
)

val actual: DataFrame = getDestinationsWithTotalCount(input)
val expected: DataFrame = createDataFrame(
  Map(
    "destination" -> "france",
    "total_count" -> 10
  ),
  Map(
    "destination" -> "morocco",
    "total_count" -> 8
  )
)

// orderedComparison = false: the row order after a groupBy isn't guaranteed
assertSmallDataFrameEquality(actual, expected, ignoreNullable = true, orderedComparison = false)

The assertSmallDataFrameEquality function comes from the spark-fast-tests library, through its DataFrameComparer trait. Check the repository for more examples.
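The test body leaves implicit where `spark`, `createDataFrame`, and `assertSmallDataFrameEquality` come from. One way to wire it up is to mix SparkSessionTestWrapper and DataFrameComparer into the spec. A self-contained sketch; the class name is illustrative, and the helpers are repeated from earlier sections so the block compiles on its own:

```scala
import com.github.mrpowers.spark.fast.tests.DataFrameComparer
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.sum
import org.json4s.DefaultFormats
import org.json4s.jackson.Json
import org.scalatest.flatspec.AnyFlatSpec

trait SparkSessionTestWrapper {
  lazy val spark: SparkSession =
    SparkSession.builder()
      .master("local[1]")
      .appName("SparkTestingSession")
      .config("spark.sql.shuffle.partitions", "1")
      .getOrCreate()
}

class FlightsTransformationSpec extends AnyFlatSpec with SparkSessionTestWrapper with DataFrameComparer {

  // Rows as Maps, serialized to JSON so Spark infers the schema.
  def createDataFrame(data: Map[String, Any]*): DataFrame = {
    import spark.implicits._
    spark.read.json(Seq(Json(DefaultFormats).write(data)).toDS)
  }

  // Transformation under test.
  def getDestinationsWithTotalCount(flights: DataFrame): DataFrame =
    flights.groupBy("destination").agg(sum("count").alias("total_count"))

  "getDestinationsWithTotalCount" should "sum the flights per destination" in {
    val input = createDataFrame(
      Map("destination" -> "morocco", "origin" -> "spain", "count" -> 3),
      Map("destination" -> "morocco", "origin" -> "egypt", "count" -> 5),
      Map("destination" -> "france", "origin" -> "germany", "count" -> 10)
    )
    val expected = createDataFrame(
      Map("destination" -> "france", "total_count" -> 10),
      Map("destination" -> "morocco", "total_count" -> 8)
    )

    assertSmallDataFrameEquality(
      getDestinationsWithTotalCount(input),
      expected,
      ignoreNullable = true,
      orderedComparison = false
    )
  }
}
```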

Conclusion

And that’s it. I hope this post was helpful. If you have any suggestions, feel free to open an issue in the repository.