Basic Schema Example

Once you have a Union account, install union:

pip install union

Export the following environment variable to build and push images to your own container registry:

# replace with your registry name
export IMAGE_SPEC_REGISTRY="<your-container-registry>"

Then clone the examples repository and run the workflow:

git clone https://github.com/unionai/unionai-examples
cd unionai-examples
# replace with the path to this example's script in the repository
union run --remote <path-to-example-script> process_data

The source code for this tutorial can be found on GitHub.

In this example we’ll show you how to use pandera.DataFrameModel to annotate dataframe inputs and outputs in your Flyte tasks.

import typing

import flytekitplugins.pandera  # noqa : F401
import pandas as pd
import pandera as pa
from flytekit import ImageSpec, task, workflow
from pandera.typing import DataFrame, Series

# Custom image with the Pandera plugin and the example's other dependencies;
# swap the registry for your own if you are building and pushing images yourself.
custom_image = ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=["flytekitplugins-pandera", "scikit-learn", "pyarrow"],
)

A Simple Data Processing Pipeline

Let’s first define a simple data processing pipeline in pure Python.

def total_pay(df):
    return df.assign(total_pay=df.hourly_pay * df.hours_worked)


def add_id(df, worker_id):
    return df.assign(worker_id=worker_id)


def process_data(df, worker_id):
    return add_id(df=total_pay(df=df), worker_id=worker_id)

As you can see, the process_data function is composed of two simpler functions: one that computes total_pay and another that simply adds a worker_id column to a pandas dataframe.
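
To get a feel for what this pipeline does, here is a quick sanity check. It is not part of the original example, and the sample values are made up for illustration:

import pandas as pd

# Two workers with hourly pay and hours worked.
df = pd.DataFrame({"hourly_pay": [12.0, 13.5], "hours_worked": [30.5, 40.0]})

# total_pay becomes 12.0 * 30.5 = 366.0 and 13.5 * 40.0 = 540.0,
# and a worker_id column is appended.
print(process_data(df=df, worker_id=["a", "b"]))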

Defining DataFrame Schemas

Next we define the schemas that provide type and statistical annotations for the raw, intermediate, and final outputs of our pipeline.

class InSchema(pa.DataFrameModel):
    hourly_pay: Series[float] = pa.Field(ge=7)
    hours_worked: Series[float] = pa.Field(ge=10)

    @pa.check("hourly_pay", "hours_worked")
    def check_numbers_are_positive(cls, series: Series) -> Series[bool]:
        """Defines a column-level custom check."""
        return series > 0

    class Config:
        coerce = True


class IntermediateSchema(InSchema):
    total_pay: Series[float]

    @pa.dataframe_check
    def check_total_pay(cls, df: DataFrame) -> Series[bool]:
        """Defines a dataframe-level custom check."""
        return df["total_pay"] == df["hourly_pay"] * df["hours_worked"]


class OutSchema(IntermediateSchema):
    worker_id: Series[str] = pa.Field()

Columns are specified as class attributes using type-hinting syntax, and you can place additional statistical constraints on the values of each column with pa.Field. You can also define custom validation functions by decorating methods with pa.check (column-level checks) or pa.dataframe_check (dataframe-level checks); these decorators automatically turn the methods into classmethods.

Because schema models follow standard inheritance semantics, DataFrameModel subclasses contain all of the columns and custom check methods of their base class, and you can override column attributes or check methods in subclasses. This has the nice effect of providing an explicit graph of type dependencies as data flows through the various tasks in your workflow.
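
To see what these schemas enforce on their own, independent of Flyte, you can validate dataframes directly with pandera. This snippet is not part of the original example, and the sample values are made up:

import pandas as pd
import pandera as pa

valid = pd.DataFrame({"hourly_pay": [12.0, 15.0], "hours_worked": [30.0, 40.0]})
InSchema.validate(valid)  # passes; values are coerced per the Config

invalid = pd.DataFrame({"hourly_pay": [5.0], "hours_worked": [30.0]})
try:
    InSchema.validate(invalid)  # hourly_pay < 7 violates the Field constraint
except pa.errors.SchemaError as exc:
    print(exc)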

Type Annotating Tasks and Workflows

Finally, we can turn our data processing pipeline into a Flyte workflow by decorating our functions with the flytekit.task and flytekit.workflow decorators and annotating the inputs and outputs of those functions with the pandera schemas:

@task(container_image=custom_image)
def dict_to_dataframe(data: dict) -> DataFrame[InSchema]:
    """Helper task to convert a dictionary input to a dataframe."""
    return pd.DataFrame(data)


@task(container_image=custom_image)
def total_pay(df: DataFrame[InSchema]) -> DataFrame[IntermediateSchema]:  # noqa : F811
    return df.assign(total_pay=df.hourly_pay * df.hours_worked)


@task(container_image=custom_image)
def add_ids(df: DataFrame[IntermediateSchema], worker_ids: typing.List[str]) -> DataFrame[OutSchema]:
    return df.assign(worker_id=worker_ids)


@workflow
def process_data(  # noqa : F811
    data: dict = {
        "hourly_pay": [12.0, 13.5, 10.1],
        "hours_worked": [30.5, 40.0, 41.75],
    },
    worker_ids: typing.List[str] = ["a", "b", "c"],
) -> DataFrame[OutSchema]:
    return add_ids(df=total_pay(df=dict_to_dataframe(data=data)), worker_ids=worker_ids)


if __name__ == "__main__":
    print(f"Running {__file__} main...")
    result = process_data(
        data={"hourly_pay": [12.0, 13.5, 10.1], "hours_worked": [30.5, 40.0, 41.75]},
        worker_ids=["a", "b", "c"],
    )
    print(f"Running wf(), returns dataframe\n{result}\n{result.dtypes}")

Now your workflows and tasks are guarded against unexpected data at runtime!
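
As a rough illustration, here is a separate snippet (not part of the original example) showing what happens when the input data violates InSchema during local execution; the exact exception type surfaced by flytekit may differ, so it is caught broadly here:

# hourly_pay below the minimum of 7 violates InSchema, so the dataframe
# passed from dict_to_dataframe to total_pay should fail validation.
bad_data = {"hourly_pay": [5.0], "hours_worked": [30.0]}
try:
    process_data(data=bad_data, worker_ids=["a"])
except Exception as exc:  # flytekit surfaces the underlying pandera error
    print(f"Validation failed: {exc}")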