Decorating workflows
The behavior of workflows can be modified in a lightweight fashion by using the built-in functools.wraps
decorator pattern, similar to using decorators to
customize task behavior. However, unlike in the case of
tasks, we need to do a little extra work to make sure that the DAG underlying the workflow executes tasks in the correct order.
Setup-teardown pattern
The main use case of decorating @fl.workflow
-decorated functions is to establish a setup-teardown pattern to execute task
before and after your main workflow logic. This is useful when integrating with other external services
like wandb or clearml, which enable you to track metrics of model training runs.
To begin, create a file called decorating_workflows
.
Import the necessary libraries:
from functools import partial, wraps
from unittest.mock import MagicMock
import flytekit as fl
from flytekit import FlyteContextManager
from flytekit.core.node_creation import create_node
Let’s define the tasks we need for setup and teardown. In this example, we use the
unittest.mock.MagicMock
class to create a fake external service that we want to initialize at the
beginning of our workflow and finish at the end.
external_service = MagicMock()
@fl.task
def setup():
print("initializing external service")
external_service.initialize(id=flytekit.current_context().execution_id)
@fl.task
def teardown():
print("finish external service")
external_service.complete(id=flytekit.current_context().execution_id)
As you can see, you can even use Flytekit’s current context to access the execution_id
of the current workflow
if you need to link Flyte with the external service so that you reference the same unique identifier in both the
external service and Flyte.
Workflow decorator
We create a decorator that we want to use to wrap our workflow function.
def setup_teardown(fn=None, *, before, after):
@wraps(fn)
def wrapper(*args, **kwargs):
# get the current flyte context to obtain access to the compilation state of the workflow DAG.
ctx = FlyteContextManager.current_context()
# defines before node
before_node = create_node(before)
# ctx.compilation_state.nodes == [before_node]
# under the hood, flytekit compiler defines and threads
# together nodes within the `my_workflow` function body
outputs = fn(*args, **kwargs)
# ctx.compilation_state.nodes == [before_node, *nodes_created_by_fn]
# defines the after node
after_node = create_node(after)
# ctx.compilation_state.nodes == [before_node, *nodes_created_by_fn, after_node]
# compile the workflow correctly by making sure `before_node`
# runs before the first workflow node and `after_node`
# runs after the last workflow node.
if ctx.compilation_state is not None:
# ctx.compilation_state.nodes is a list of nodes defined in the
# order of execution above
workflow_node0 = ctx.compilation_state.nodes[1]
workflow_node1 = ctx.compilation_state.nodes[-2]
before_node >> workflow_node0
workflow_node1 >> after_node
return outputs
if fn is None:
return partial(setup_teardown, before=before, after=after)
return wrapper
There are a few key pieces to note in the setup_teardown
decorator above:
- It takes a
before
andafter
argument, both of which need to be@fl.task
-decorated functions. These tasks will run before and after the main workflow function body. - The create_node function
to create nodes associated with the
before
andafter
tasks. - When
fn
is called, under the hood the system creates all the nodes associated with the workflow function body - The code within the
if ctx.compilation_state is not None:
conditional is executed at compile time, which is where we extract the first and last nodes associated with the workflow function body at index1
and-2
. - The
>>
right shift operator ensures thatbefore_node
executes before the first node andafter_node
executes after the last node of the main workflow function body.
Defining the DAG
We define two tasks that will constitute the workflow.
@fl.task
def t1(x: float) -> float:
return x - 1
@fl.task
def t2(x: float) -> float:
return x**2
And then create our decorated workflow:
@fl.workflow
@setup_teardown(before=setup, after=teardown)
def decorating_workflow(x: float) -> float:
return t2(x=t1(x=x))
Run the example on the Flyte cluster
To run the provided workflow on the Flyte cluster, use the following command:
pyflyte run --remote \
https://raw.githubusercontent.com/flyteorg/flytesnacks/69dbe4840031a85d79d9ded25f80397c6834752d/examples/advanced_composition/advanced_composition/decorating_workflows.py \
decorating_workflow --x 10.0