Unit Testing Spark Jobs

I recently went down the rabbit hole of Spark unit testing. It is somewhat uncharted territory, so you find a plethora of conflicting opinions. I decided to hack my way through it. This article is a summary of my thoughts and findings, and of the process I went through to resolve the problems I faced along my Spark journey.

To keep the theory to a minimum, I will be using concrete examples throughout the article. The data represent flight counts from origin countries to destination countries, with the following structure:

Destination  Origin   Count
Morocco      Spain    3
Morocco      Egypt    5
France       Germany  10

The structure is shamelessly inspired by the Spark: The Definitive Guide’s code repository. You can find my code here.

Before getting into the testing section, I want to lay down the tools I am using, the basics of my project structure, and what a testable Spark job looks like. First things first, this is the structure of my project:

+-- project
|   +-- build.properties
|   +-- plugins.sbt
+-- src
|   +-- main
|   +-- test
+-- .gitignore
+-- .scalafix.conf
+-- .scalafmt.conf
+-- LICENSE
+-- build.sbt
+-- README.md

Dependencies

Let’s start with Spark dependencies:

// build.sbt
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % sparkVersion,
  "org.apache.spark" %% "spark-sql" % sparkVersion,
  "org.apache.spark" %% "spark-hive" % sparkVersion % provided
)
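
Here sparkVersion is a val defined near the top of build.sbt; the value below is only an example, use whatever matches your cluster:

// build.sbt
// Example only: pick the Spark version your cluster runs.
val sparkVersion = "3.0.1"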

Loading configuration

For configuration loading, I am using Typesafe Config. I am curious about using it along with Guice, but for now:

// build.sbt
// https://mvnrepository.com/artifact/com.typesafe/config
libraryDependencies += "com.typesafe" % "config" % "1.4.1"

To add a configuration file to your project, create /src/main/resources/reference.conf with the following content:

# reference.conf
settings {
  database = "dev"
}

Or any configuration you need. Then use Config to load it as follows:

import com.typesafe.config.{Config, ConfigFactory}

val config: Config = ConfigFactory.load()
val db: String = config.getString("settings.database")

And you are ready to go! Check the docs to find out more.
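
One trick that pays off later in tests: Typesafe Config lets you override values from reference.conf in code, with everything else falling back to the file. A small sketch:

import com.typesafe.config.{Config, ConfigFactory}

// Override the database for tests; unset keys fall back to reference.conf.
val testConfig: Config = ConfigFactory
  .parseString("""settings { database = "test" }""")
  .withFallback(ConfigFactory.load())

val db: String = testConfig.getString("settings.database") // "test"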

Testing tool

For testing, we will be using:

// build.sbt
// https://mvnrepository.com/artifact/me.vican.jorge/dijon
libraryDependencies += "me.vican.jorge" %% "dijon" % "0.4.0" % "test"

// https://mvnrepository.com/artifact/org.scalatest/scalatest-flatspec
libraryDependencies += "org.scalatest" %% "scalatest-flatspec" % "3.3.0-SNAP3" % "test"

// https://github.com/MrPowers/spark-fast-tests
libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "0.21.3" % "test"

ScalaTest for writing unit tests, spark-fast-tests for comparing DataFrames, and Dijon for creating dynamically typed JSON. ScalaTest is a flexible testing tool and simple to get up and running. An example from its front page:

import collection.mutable.Stack
import org.scalatest._
import flatspec._
import matchers._

class ExampleSpec extends AnyFlatSpec with should.Matchers {

  "A Stack" should "pop values in last-in-first-out order" in {
    val stack = new Stack[Int]
    stack.push(1)
    stack.push(2)
    stack.pop() should be (2)
    stack.pop() should be (1)
  }

  it should "throw NoSuchElementException if an empty stack is popped" in {
    val emptyStack = new Stack[Int]
    a [NoSuchElementException] should be thrownBy {
      emptyStack.pop()
    } 
  }
}

Plugins

For the plugins, I am using these two:

// project/plugins.sbt

addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.9.24")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.0")

Scalafmt

Scalafmt is for code formatting. You can enable it in IntelliJ IDEA under CTRL + ALT + S > Editor > Code Style > Scala > Formatter by changing the formatter from IntelliJ to Scalafmt. Then add the following configuration to the root directory of your project:

# .scalafmt.conf
version = "2.7.5"
# Check https://github.com/monix/monix/blob/series/3.x/.scalafmt.conf for full configuration

Scalafix

Scalafix is a linting and refactoring tool. It needs some configuration in the build.sbt file:

// build.sbt
// https://scalacenter.github.io/scalafix/docs/users/installation.html

inThisBuild(
  List(
    scalaVersion := "2.12.12",
    semanticdbEnabled := true, // enable SemanticDB
    semanticdbVersion := scalafixSemanticdb.revision // use Scalafix compatible version
  )
)

scalacOptions ++= List(
  "-Ywarn-unused"
)

Then:

# .scalafix.conf
# https://spotify.github.io/scio/dev/Style-Guide.html
rules = [
  RemoveUnused,
  LeakingImplicitClassVal
]
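
With both plugins in place, formatting and linting can be run from sbt; scalafmtAll and scalafixAll are the tasks the plugins add, as far as I know:

sbt scalafmtAll   # format main and test sources
sbt scalafixAll   # apply the Scalafix rules configured above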

Now that we know our tools, let’s explore our structure.

Project structure

This is the point of divergence. Each team will have its own structure adapted to its specific needs. Here are my suggestions:

scala
+-- me.ayoublabiad
|   +-- common
|   |   +-- SparkSessionWrapper.scala
|   +-- io
|   |   +-- Read.scala
|   |   +-- Write.scala
|   +-- job
|   |   +-- flights
|   |   |   +-- FlightsTransformation.scala
|   |   |   +-- FlightsJob.scala
|   |   +-- Job.scala
|   +-- Main.scala

The io package holds the read and write helpers. For the input:

  def readFromCsvFileWithSchema(location: String, schema: String): DataFrame =
    spark.read
      .option("header", "true")
      .schema(schema)
      .csv(location)

where val schema = "destination STRING, origin STRING, count INT". And for the output:

  def writeToHive(spark: SparkSession, dataFrame: DataFrame, database: String, tableName: String): Unit = {
    // Global temp views live in the reserved global_temp database, hence the qualified name below.
    dataFrame.createGlobalTempView(s"${tableName}_view")
    spark.sql(s"DROP TABLE IF EXISTS $database.$tableName")
    spark.sql(s"CREATE TABLE $database.$tableName AS SELECT * FROM global_temp.${tableName}_view")
  }
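
SparkSessionWrapper, from the common package, could be as minimal as the sketch below (the Hive support is my assumption, since writeToHive creates Hive tables):

import org.apache.spark.sql.SparkSession

// Sketch of common/SparkSessionWrapper.scala: gives any job a lazily built session.
trait SparkSessionWrapper {
  lazy val spark: SparkSession =
    SparkSession
      .builder()
      .appName("flights")
      .enableHiveSupport() // assumed, because writeToHive creates Hive tables
      .getOrCreate()
}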

Inside the job package is your playground. Now that we’re all on the same page, let’s discuss the ideas behind testing Spark jobs.

A case for testing Spark jobs

I like to think about Spark jobs as three simple phases:

[Figure: The three phases of a Spark job]

You read your data, build your DAG, and then trigger the whole thing with a powerful, almighty action. Keep in mind that our goal is not to test Spark! You are testing your own business logic, which is condensed mainly in the transformations step. You are trying to answer the question: does this code reflect what I am thinking? I don’t see much value in testing the IO of your job; that’s on Spark. So we will stick to figuring out whether we need to test those transformations.

Let’s make those ideas concrete: based on the data that we have, what is the total count of flights for each destination?

def getDestinationsWithTotalCount(flights: DataFrame): DataFrame =
    flights
      .groupBy("destination")
      .agg(sum("count").alias("total_count"))

That’s simple enough. Let’s assume you need the origin country column to stay, and you are not a big fan of self-joins. The natural solution would be using a Window function:

  def addTotalCountColumn(flights: DataFrame): DataFrame = {
    val winExpDestination = Window.partitionBy("destination")

    flights
      .withColumn("total_count", count("count").over(winExpDestination))
  }

Is this function worthy of testing? Probably not if:

While this is a simple example to illustrate an idea, you can scale it up to more complex use cases. I think Raymond Hettinger’s talk The Mental Game of Python is relevant in this context. Unfortunately, we can only hold so much in our seven-slot brains.

Although I lied before to prove a point, the natural solution should be:

def getDestinationsWithTotalCount(flights: DataFrame): DataFrame =
    flights
      .groupBy("destination")
      .agg(
        collect_list("origin").alias("origins"),
        sum("count").alias("total_count"))

Or is it? Either way, you get the idea.

Another example I want to look at is having to retrieve a parameter stored in a table. Let’s look at the tests first:

  "getParamFromTable" should "turn the param if it exists" in {
    val input: DataFrame = createDataFrame(
      obj(
        "minFlightsThreshold" -> 4
      )
    )

    val threshold: Long = getParamFromTable(spark, input, "minFlightsThreshold")

    assert(threshold == 4)
  }

"getParamFromTable" should "throw an Exception if the threshold doesn't exist" in {
    val input: DataFrame = Seq.empty[Long].toDF("minFlightsThreshold")

    val caught =
      intercept[Exception] {
        getParamFromTable(spark, input, "minFlightsThreshold")
      }

    assert(caught.getMessage == "Threshold not found.")
  }

So our implementation looks something like this:

def getParamFromTable(spark: SparkSession, dataFrame: DataFrame, paramColumnName: String): Long = {
    import spark.implicits._
    val params: Array[Long] = dataFrame.select(paramColumnName).as[Long].collect
    if (params.isEmpty) {
      throw new Exception("Threshold not found.")
    } else {
      params(0)
    }
  }

This function represents one of those edge cases that you need to deal with rigorously to avoid unnecessary bugs. Hence the tests.
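
As a side note, the retrieved parameter typically ends up feeding a filter in the job. A sketch of how I would use it (filterFlightsBelowThreshold is an illustrative name, not something from the repo):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

// Keep only flights whose count meets the minimum threshold read from the params table.
def filterFlightsBelowThreshold(flights: DataFrame, minFlightsThreshold: Long): DataFrame =
  flights.filter(col("count") >= minFlightsThreshold)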

If we agree that we might need to test Spark jobs, how do we go about it?

How to test Spark jobs?

We will spend most of our testing time on the functions inside the transformation classes. To do so, we need to create a Spark session first:

// https://github.com/MrPowers/spark-fast-tests/blob/master/src/test/scala/com/github/mrpowers/spark/fast/tests/SparkSessionTestWrapper.scala

trait SparkSessionTestWrapper {
  lazy val spark: SparkSession =
    SparkSession
      .builder()
      .master("local[1]") // One partition One core
      .appName("SparkTestingSession")
      .config("spark.sql.shuffle.partitions", "1")
      .getOrCreate()
}
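
A test class then simply mixes this trait in to get a spark handle (the class name here is illustrative):

import org.scalatest.flatspec.AnyFlatSpec

// Illustrative spec: SparkSessionTestWrapper supplies `spark` to every test.
class FlightsTransformationTest extends AnyFlatSpec with SparkSessionTestWrapper {
  import spark.implicits._ // needed for the .toDF calls shown below
  // tests go here
}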

Now we create the input DataFrames. There are three ways to do so:

Using .toDF

After importing spark.implicits._, you can do the following:

val input: DataFrame = Seq(
      ("morocco", "spain", 3),
      ("morocco", "egypt", 5),
      ("france", "germany", 10)
    ).toDF("destination", "origin", "count")

It’s convenient and fast. But for me, it lacks readability as soon as we go beyond 6 or 7 columns; past that point, you have to start counting tuple positions whenever you want to change a value. This approach is also limited to 22 columns, the maximum size of a tuple in Scala.

To overcome these challenges, I was thinking about the most readable way of representing a DataFrame. We have column names and values, basically keys and values. Thus the following implementation:

From JSON to DataFrame

My idea was to represent a DataFrame as key-value pairs, where the key is the name of the column and the value is a value in a row. Each row would be a map. That way, I can go to each row/map, look for a specific column/key, and change the value. This solves both the readability issue and the 22-column limitation. The challenge now is how to transform this into a DataFrame with as little work as possible:

Map(
  "columnName1" -> value1,
  "columnName2" -> value2
)

I came up with the following implementation using Dijon JSON library:

import dijon.{arr, obj, SomeJson}
import org.apache.spark.sql.DataFrame

def createDataFrame(data: SomeJson): DataFrame = {
    import spark.implicits._
    spark.read.json(Seq(data.toString()).toDS)
}

And for multiple rows:

def createDataFrame(data: SomeJson*): DataFrame = {
    import spark.implicits._
    spark.read.json(Seq(arr(data: _*).toString()).toDS)
}

Then to create a DataFrame, we simply:

val input: DataFrame = createDataFrame(
      obj( // First row
        "destination" -> "morocco",
        "origin" -> "spain",
        "count" -> 3
      ),
      obj( // Second row
        "destination" -> "morocco",
        "origin" -> "egypt",
        "count" -> 5
      ),
      obj(
        "destination" -> "france",
        "origin" -> "germany",
        "count" -> 10
      )
    )

This technique is not free: the parsing is hundreds of milliseconds slower, and the schema inference is not great, so consider providing your own schema to the function.
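
For what it’s worth, “providing your own schema” could look like the sketch below; this variant is mine, not from the repo, and it reuses the DDL string format from the CSV reader earlier:

import dijon.{arr, SomeJson}
import org.apache.spark.sql.DataFrame

// Same idea, but with an explicit DDL schema so Spark skips inference.
def createDataFrame(schema: String, data: SomeJson*): DataFrame = {
  import spark.implicits._
  spark.read
    .schema(schema) // e.g. "destination STRING, origin STRING, count INT"
    .json(Seq(arr(data: _*).toString()).toDS)
}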

Loading a CSV file

I am not crazy about loading a file for each table every time I want to run my tests; I want more control over the values directly in my code. There is also no need to add a hard-drive bottleneck to worry about during development!

A full test

Now a test should look something like this:


    val input: DataFrame = createDataFrame(
      obj(
        "destination" -> "morocco",
        "origin" -> "spain",
        "count" -> 3
      ),
      obj(
        "destination" -> "morocco",
        "origin" -> "egypt",
        "count" -> 5
      ),
      obj(
        "destination" -> "france",
        "origin" -> "germany",
        "count" -> 10
      )
    )

    val actual: DataFrame = getDestinationsWithTotalCount(input)
    val expected: DataFrame = createDataFrame(
        obj(
          "destination" -> "france",
          "total_count" -> 10
        ),
        obj(
          "destination" -> "morocco",
          "total_count" -> 8
        )
    )

    assertSmallDataFrameEquality(actual, expected, ignoreNullable = true, orderedComparison = false)

The assertSmallDataFrameEquality function comes from the spark-fast-tests library.
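
As far as I remember, the spec also needs to mix in the library’s DataFrameComparer trait to get that assertion, so the test class from earlier becomes:

import com.github.mrpowers.spark.fast.tests.DataFrameComparer
import org.scalatest.flatspec.AnyFlatSpec

// Illustrative: the session wrapper provides `spark`, the comparer provides the assertions.
class FlightsTransformationTest
    extends AnyFlatSpec
    with SparkSessionTestWrapper
    with DataFrameComparer {
  // the full test above lives here
}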

Conclusion

We have barely scratched the surface of what is possible when testing Spark. If you have any remarks or suggestions, please don’t hesitate to contact me. I am eager to discuss these ideas and build on them.