Running PySpark code in solids

You can find the code for this example on GitHub.

Passing PySpark DataFrames between solids requires a little extra care, for a couple of reasons (both behaviors are sketched after this list):

  • Spark has a lazy execution model, which means that PySpark won't process any data until an action like write or collect is called on a DataFrame.
  • PySpark DataFrames cannot be pickled, which means that IO managers like the fs_io_manager won't work for them.
In this example, we define an IOManager that knows how to store and retrieve the PySpark DataFrames that solids produce and consume.

This example assumes that all the outputs within the pipeline are PySpark DataFrames and are stored the same way. To learn how to use different IO managers for different outputs within the same pipeline, take a look at the IO Manager overview; a brief sketch of that setup follows.
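
For instance, a sketch of binding the Parquet-backed IO manager (the local_parquet_store defined below) to a single output via an io_manager_key, while other outputs fall back to the built-in fs_io_manager; the "pyspark_io_manager" key name is an illustrative choice:

from dagster import ModeDefinition, OutputDefinition, fs_io_manager, pipeline, solid


@solid(output_defs=[OutputDefinition(io_manager_key="pyspark_io_manager")])
def make_people(_):
    ...


@pipeline(
    mode_defs=[
        ModeDefinition(
            resource_defs={
                "io_manager": fs_io_manager,  # default: pickles non-Spark outputs to disk
                "pyspark_io_manager": local_parquet_store,  # handles the DataFrame output
            }
        )
    ]
)
def my_pipeline():
    ...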

This example writes DataFrames to the local file system, but it can be tweaked to write to cloud object stores like S3 by changing the write and read invocations, as sketched below.
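
For example, a minimal sketch of an S3-backed variant. The bucket name my-bucket is a placeholder, and this assumes Spark has the hadoop-aws package on its classpath and AWS credentials configured so that s3a:// paths resolve:

from dagster import IOManager
from pyspark.sql import SparkSession


class S3ParquetStore(IOManager):
    def _get_path(self, context):
        # "my-bucket" is a placeholder; swap in your own bucket.
        return "/".join(["s3a://my-bucket", context.run_id, context.step_key, context.name])

    def handle_output(self, context, obj):
        # Spark writes the Parquet files directly to S3 via the s3a filesystem.
        obj.write.parquet(self._get_path(context))

    def load_input(self, context):
        spark = SparkSession.builder.getOrCreate()
        return spark.read.parquet(self._get_path(context.upstream_output))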

repo.py
import os

from dagster import IOManager, ModeDefinition, io_manager, pipeline, solid
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


class LocalParquetStore(IOManager):
    """IOManager that stores and retrieves PySpark DataFrames as local Parquet files."""

    def _get_path(self, context):
        # Namespace each output's files by run, step, and output name so runs don't collide.
        return os.path.join(context.run_id, context.step_key, context.name)

    def handle_output(self, context, obj):
        # Writing is an action, so this is where Spark actually processes the data.
        obj.write.parquet(self._get_path(context))

    def load_input(self, context):
        # Read back the Parquet files written for the upstream solid's output.
        spark = SparkSession.builder.getOrCreate()
        return spark.read.parquet(self._get_path(context.upstream_output))


@io_manager
def local_parquet_store(_):
    return LocalParquetStore()


@solid
def make_people(_):
    # Build a small DataFrame; nothing is computed until the IOManager writes it out.
    schema = StructType([StructField("name", StringType()), StructField("age", IntegerType())])
    rows = [Row(name="Thom", age=51), Row(name="Jonny", age=48), Row(name="Nigel", age=49)]
    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame(rows, schema)


@solid
def filter_over_50(_, people):
    # `people` is loaded from Parquet by the IOManager before this solid runs.
    return people.filter(people["age"] > 50)


# Bind the custom IOManager to the default "io_manager" key so it handles every output.
@pipeline(mode_defs=[ModeDefinition(resource_defs={"io_manager": local_parquet_store})])
def my_pipeline():
    filter_over_50(make_people())
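
To try the pipeline without Dagit, you can execute it in-process; a minimal sketch, assuming the code above is saved as repo.py:

from dagster import execute_pipeline

from repo import my_pipeline

result = execute_pipeline(my_pipeline)
assert result.success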

Download

curl https://codeload.github.com/dagster-io/dagster/tar.gz/master | tar -xz --strip=2 dagster-master/examples/basic_pyspark
cd basic_pyspark