
Validating Data with Dagster Type Factories#

This guide uses Pandas dataframes and the Pandera validation library to illustrate how to integrate arbitrary data validation into Dagster. Users looking for a simpler out-of-the-box validation solution should check out Dagster's dedicated integration library dagster-pandera.

Dagster is designed to interoperate easily with other components of your stack. In this guide we'll illustrate a technique for integrating third-party data validation libraries into your Dagster jobs: the Dagster Type factory pattern.

Dagster types perform arbitrary runtime validation on data as it flows between ops. If we apply standard Python type annotations to an op compute function, Dagster will infer corresponding Dagster types. However, standard Python type annotations are of limited expressiveness. To specify arbitrary validation logic, we must define custom Dagster types. Here is a type that asserts an object is a set containing the number 1 (this is not expressible with standard Python types):

from dagster import DagsterType, check_dagster_type

set_containing_1 = DagsterType(
    name="set_containing_1",
    description="A set containing the value 1. May contain any other values.",
    type_check_fn=lambda _context, obj: isinstance(obj, set) and 1 in obj,
)

assert check_dagster_type(set_containing_1, {1, 2}).success  # => True

Sometimes, we may even need to define an entire suite of related types. Factory functions are an efficient solution to this problem. Here is a factory function that generalizes the set-membership-check pattern for arbitrary values:

def set_has_element_type_factory(x):
    return DagsterType(
        name=f"set_containing_{x}",
        description=f"A set containing the value {x}. May contain any other values.",
        type_check_fn=lambda _context, obj: isinstance(obj, set) and x in obj,
    )


set_containing_2 = set_has_element_type_factory(2)
assert check_dagster_type(set_containing_2, {1, 2}).success  # => True

Factories are particularly powerful when interfacing with external data validation libraries. Below, we'll see how a small factory function allows us to integrate the Pandera dataframe validation library into our Dagster jobs. Pandera provides an API for defining dataframe schemas. Schemas encode data type, boundaries, nullability, and other constraints at both the dataframe and individual column level. They can be used to validate dataframes from Pandas, Dask, or several other popular Python libraries.

Setup#

Suppose we are data scientists working for an e-bike rental company. We've been tasked with setting up some pipelines to analyze trip data, and have been experimenting with implementing them in Dagster. We're currently looking at the distribution of trip lengths, so we've exported a slice of trips from October into a simple CSV file with 3 columns: bike_id, start_time, and end_time.

 bike_id            start_time              end_time
       9   2021-10-20 10:39:00   2021-10-20 12:29:00
       9   2021-10-14 00:15:00   2021-10-14 00:33:00
       5   2021-10-02 12:57:00   2021-10-02 13:23:00
       1   2021-10-28 00:01:00   2021-10-28 00:34:00
       8   2021-10-23 14:22:00   2021-10-23 15:26:00
...

The first thing we want to do is visualize the distribution of trip lengths. We set up a simple Dagster job with two ops: load_trips reads a CSV with our data, and generate_plot creates a plot of the distribution using the Matplotlib plotting library.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from dagster import AssetMaterialization, OpExecutionContext, job, op


@op
def load_trips():
    return pd.read_csv(
        "./ebike_trips.csv",
        parse_dates=["start_time", "end_time"],
    )


@op
def generate_plot(context: OpExecutionContext, trips):
    minute_lengths = [x.total_seconds() / 60 for x in trips.end_time - trips.start_time]
    bin_edges = np.histogram_bin_edges(minute_lengths, 15)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set(title="Trip lengths", xlabel="Minutes", ylabel="Count")
    ax.hist(minute_lengths, bins=bin_edges)
    fig.savefig("trip_lengths.png")
    context.log_event(
        AssetMaterialization(
            asset_key="trip_dist_plot", description="Distribution of trip lengths."
        )
    )


@job
def generate_trip_distribution_plot():
    generate_plot(load_trips())

If we try running this job, our np.histogram_bin_edges call in generate_plot throws an error:

from .job_1 import generate_trip_distribution_plot

generate_trip_distribution_plot.execute_in_process()
# => ...
# => ValueError: autodetected range of [nan, nan] is not finite

Validating dataframes with Pandera#

Since our logic looks OK, the problem likely originates in an incorrect assumption about the source data. To investigate, we write a Pandera schema. The schema defines a set of dataframe columns with corresponding datatypes. Pandera columns are by default required to exist and non-nullable, so there is also an implicit expectation that columns with the specified names exist, and that all values are defined (NOTE: this illustrates only a small slice of Pandera's functionality):

import pandas as pd
import pandera as pa

MIN_DATE = pd.Timestamp("2021-10-01")

df = pd.read_csv(
    "./ebike_trips.csv",
    parse_dates=["start_time", "end_time"],
)

trips_schema = pa.DataFrameSchema(
    columns={
        "bike_id": pa.Column(
            int, checks=pa.Check.ge(0)
        ),  # ge: greater than or equal to
        "start_time": pa.Column(pd.Timestamp, checks=pa.Check.ge(MIN_DATE)),
        "end_time": pa.Column(pd.Timestamp, checks=pa.Check.ge(MIN_DATE)),
    },
)

Let's try validating our data:

from .schema import df, trips_schema

trips_schema.validate(df)
# => SchemaError: non-nullable series 'end_time' contains null values:
# => 22   NaT
# => 43   NaT

Our call to trips_schema.validate has found the culprit. Somehow, null values have crept into our end_time column. Maybe some customers never returned their bikes...
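To see how a missing end_time leads to the earlier NumPy error, consider the sketch below. It uses a hypothetical three-row sample instead of the real CSV, and the names `sample`, `null_counts`, and `histogram_error` are assumptions introduced for illustration.

```python
import numpy as np
import pandas as pd

# Hypothetical three-row sample; the second trip has no end_time.
sample = pd.DataFrame(
    {
        "bike_id": [9, 5, 1],
        "start_time": pd.to_datetime(
            ["2021-10-20 10:39:00", "2021-10-02 12:57:00", "2021-10-28 00:01:00"]
        ),
        "end_time": pd.to_datetime(
            ["2021-10-20 12:29:00", None, "2021-10-28 00:34:00"]
        ),
    }
)

# Count nulls per column: a quick check without any schema.
null_counts = sample.isna().sum()

# Subtracting with a NaT yields NaT, which becomes NaN in minutes.
minute_lengths = (sample.end_time - sample.start_time).dt.total_seconds() / 60

# A NaN makes the autodetected histogram range non-finite.
try:
    np.histogram_bin_edges(minute_lengths, 15)
    histogram_error = None
except ValueError as err:
    histogram_error = err

print(null_counts["end_time"], histogram_error)
```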

Dagster type factory#

We are now facing a common situation in data science: dirty data. This can cause puzzling bugs if the data travels a significant distance in a pipeline before causing a problem. In our example, it triggered an obscure NumPy error.

We can prevent these situations by validating our data before passing it downstream. There are several ways to solve this problem in Dagster. We could inline the schema and validation call in our load_trips op. Alternatively, we could break validation into a dedicated op that sits in between load_trips and any downstream computations. However, both of these approaches complicate our logic at either the graph or single op level. If this validation strategy were scaled up across all present and future analysis jobs, it could quickly become a headache.

Dagster Types offer a more robust and elegant solution. Every op input and output can have a Dagster type that encapsulates runtime validation checks. Dagster automatically runs these checks during job execution. So if we can wrap our trip dataframe schema in a Dagster type, then we won't need to further complicate our graphs or individual ops with explicit validation calls and error handling. Dagster's machinery will take care of this for us.

Fortunately, it's easy to write a factory function that will wrap any Pandera schema in a unique Dagster type:

import pandas as pd
import pandera as pa

from dagster import DagsterType, TypeCheck


def pandera_schema_to_dagster_type(schema, name, description):
    def type_check_fn(_context, value):
        if not isinstance(value, pd.DataFrame):
            return TypeCheck(
                success=False,
                description=f"Must be pandas.DataFrame, not {type(value).__name__}.",
            )
        try:
            # `lazy` instructs pandera to capture every (not just the first) validation error
            schema.validate(value, lazy=True)
        except pa.errors.SchemaErrors as e:
            return TypeCheck(
                success=False,
                description=str(e),
                metadata={
                    "num_violations": len(e.failure_cases),
                },
            )

        return TypeCheck(success=True)

    return DagsterType(
        type_check_fn=type_check_fn,
        name=name,
        description=description,
    )

Let's unpack the above code. Our factory function is pandera_schema_to_dagster_type. Its purpose is to construct a custom type check function (type_check_fn) and corresponding DagsterType.

Inside type_check_fn, we let pandera do the heavy lifting. After confirming value is in fact a DataFrame, we call schema.validate and tack the message of any resulting SchemaErrors instance onto our returned TypeCheck instance. TypeCheck is an optional rich return type for Dagster type check functions (it is also possible to return a boolean indicating a simple pass/fail). A TypeCheck is appropriate here because we want to capture specific information about any validation failures.

Notice also that we provide a metadata dictionary to the TypeCheck constructor when validation fails. This can hold arbitrary information that is visible in the Dagster UI and can be useful in debugging. Here we've provided only num_violations, a count of the schema violations (SchemaErrors.failure_cases is a dataframe tabulating each violation), but we could easily include richer information to annotate more complex validations.
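One possible way to build richer metadata is to summarize the failure cases per column and per check. The helper below is a sketch under assumptions: `summarize_failures` is a hypothetical name, and it assumes the failure-cases dataframe has `column` and `check` columns, so verify that against your installed Pandera version. The example input is hand-built rather than produced by Pandera.

```python
import pandas as pd


def summarize_failures(failure_cases: pd.DataFrame) -> dict:
    """Summarize a failure-cases dataframe into plain Python values.

    Assumes `column` and `check` columns are present (an assumption about
    the layout of SchemaErrors.failure_cases).
    """
    per_column = failure_cases.groupby("column").size()
    per_check = failure_cases.groupby("check").size()
    return {
        "num_violations": len(failure_cases),
        # Cast keys and values so the result contains only str and int.
        "violations_by_column": {str(k): int(v) for k, v in per_column.items()},
        "violations_by_check": {str(k): int(v) for k, v in per_check.items()},
    }


# Hand-built stand-in for SchemaErrors.failure_cases, for illustration only.
example_failures = pd.DataFrame(
    {
        "column": ["end_time", "end_time", "bike_id"],
        "check": ["not_nullable", "not_nullable", "greater_than_or_equal_to(0)"],
        "failure_case": [None, None, -3],
        "index": [22, 43, 7],
    }
)

summary = summarize_failures(example_failures)
print(summary)
```

Inside type_check_fn, a dictionary like this could be supplied as the metadata argument in place of the single count. If nested dictionaries are not accepted directly as metadata values in your Dagster version, wrapping them with MetadataValue.json is a conservative option.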

Annotate our job with Dagster types#

Now that we can turn Pandera schemas into Dagster types, we only need to tweak our job definition a little bit to integrate Pandera validation.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pandera as pa

from dagster import AssetMaterialization, Field, In, OpExecutionContext, Out, job, op

from .factory import pandera_schema_to_dagster_type

MIN_DATE = pd.Timestamp("2021-10-01")

trips_schema = pa.DataFrameSchema(
    columns={
        "bike_id": pa.Column(
            int, checks=pa.Check.ge(0)
        ),  # ge: greater than or equal to
        "start_time": pa.Column(pd.Timestamp, checks=pa.Check.ge(MIN_DATE)),
        "end_time": pa.Column(pd.Timestamp, checks=pa.Check.ge(MIN_DATE)),
    },
)

# This is a Dagster type that wraps the schema
TripsDataFrame = pandera_schema_to_dagster_type(
    trips_schema, "TripsDataFrame", "DataFrame type for e-bike trips."
)


# We've added a Dagster type for this op's output
@op(out=Out(TripsDataFrame), config_schema={"clean": Field(bool, default_value=False)})
def load_trips(context: OpExecutionContext):
    df = pd.read_csv(
        "./ebike_trips.csv",
        parse_dates=["start_time", "end_time"],
    )
    if context.op_config["clean"]:
        df = df[pd.notna(df.end_time)]
    return df


# We've added a Dagster type for this op's input
@op(ins={"trips": In(TripsDataFrame)})
def generate_plot(context: OpExecutionContext, trips):
    minute_lengths = [x.total_seconds() / 60 for x in trips.end_time - trips.start_time]
    bin_edges = np.histogram_bin_edges(minute_lengths, 15)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set(title="Trip lengths", xlabel="Minutes", ylabel="Count")
    ax.hist(minute_lengths, bins=bin_edges)
    fig.savefig("trip_lengths.png")
    context.log_event(
        AssetMaterialization(
            asset_key="trip_dist_plot", description="Distribution of trip lengths."
        )
    )


@job
def generate_trip_distribution_plot():
    generate_plot(load_trips())

The above job definition makes two modifications to the previous one. First, we've added ins/out arguments to the op decorators. We specify that load_trips must return a TripsDataFrame as output, and generate_plot must take a TripsDataFrame as input. Now when we execute the generate_trip_distribution_plot job, Dagster will automatically apply the TripsDataFrame validation to our dataframe when load_trips emits it and again before it enters generate_plot, and will terminate execution if the validation fails.

Second, we've added a clean configuration parameter to load_trips. When clean is true, load_trips strips out records with a missing end_time before outputting a dataframe.

Let's see what happens when we run without setting clean:

generate_trip_distribution_plot.execute_in_process()
# => ...
# => dagster._core.errors.DagsterTypeCheckDidNotPass: Type check failed for step output "result" - expected type "TripsDataFrame".
# => ...

Just as we expect, Dagster's runtime validation caught the error and raised a DagsterTypeCheckDidNotPass exception. Now let's try with clean set. Stripping out the offending records should allow our dataframe to pass the TripsDataFrame validation:

generate_trip_distribution_plot.execute_in_process(
    run_config={"ops": {"load_trips": {"config": {"clean": True}}}}
)
# => ...
# => 2021-11-11 19:54:26 - dagster - DEBUG - generate_trip_distribution_plot - 3e00e9e3-27f3-490e-b1bd-ec17b92e5599 - 28168 - RUN_SUCCESS - Finished execution of run for "generate_trip_distribution_plot".

Success! We also find that our histogram was successfully rendered to ./trip_lengths.png.

[Figure: trip_lengths.png, histogram of trip lengths (x-axis: Minutes, y-axis: Count)]

Conclusion#

In this guide, we saw how to use Dagster type factories to generate collections of Dagster types sharing a common validation pattern. We further saw how a factory can smoothly integrate a third-party data validation library with Dagster.

We used Pandas and Pandera as an example, but the same technique can be used for any data validation library or logic of your own. The basic pattern is to wrap validation entities in Dagster Types. Use of a factory function achieves this with little boilerplate. Inside the factory, you are free to translate validation results into TypeCheck objects however you like. In sum, Dagster type factories can imbue your jobs with arbitrary runtime data validation at the cost of very little code.