2.4 Branch Operator
The BranchOperator is used to decide which path to run based on the input data.
For example, if a DAG has two downstream paths, the BranchOperator chooses which of them runs for a given input.
There are two ways to use the BranchOperator:
Build A BranchOperator With A Branch Mapping
Pass a dictionary that maps branch functions to task names to the BranchOperator constructor.
from dbgpt.core.awel import DAG, BranchOperator, MapOperator

def branch_even(x: int) -> bool:
    return x % 2 == 0

def branch_odd(x: int) -> bool:
    return not branch_even(x)

# Keys are branch functions, values are the names of the tasks they select.
branch_mapping = {
    branch_even: "even_task",
    branch_odd: "odd_task"
}

with DAG("awel_branch_operator") as dag:
    task = BranchOperator(branches=branch_mapping)
    even_task = MapOperator(
        task_name="even_task",
        map_function=lambda x: print(f"{x} is even")
    )
    odd_task = MapOperator(
        task_name="odd_task",
        map_function=lambda x: print(f"{x} is odd")
    )
    # Attach both tasks as children of the branch task.
    task >> even_task
    task >> odd_task
In the example above, the BranchOperator has two child tasks, even_task and odd_task, and decides which of them to run based on the input data.
The dictionary passed to the BranchOperator constructor defines the branch mapping: each key is a branch function and each value is a task name. When the branch task runs, every branch function is executed. If a branch function returns True, the corresponding task is executed; otherwise, it is skipped.
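The selection rule described above can be sketched in plain Python, without the library. This is only an illustration of the semantics, not how dbgpt implements them; select_tasks is a hypothetical helper name introduced for this sketch.

```python
from typing import Callable, Dict, List


def branch_even(x: int) -> bool:
    return x % 2 == 0


def branch_odd(x: int) -> bool:
    return not branch_even(x)


def select_tasks(branches: Dict[Callable[[int], bool], str], data: int) -> List[str]:
    """Hypothetical helper: run every branch function on the input and
    keep the names of the tasks whose function returned True."""
    return [name for func, name in branches.items() if func(data)]


branch_mapping = {branch_even: "even_task", branch_odd: "odd_task"}

print(select_tasks(branch_mapping, 5))  # ['odd_task']
print(select_tasks(branch_mapping, 6))  # ['even_task']
```

Because every branch function is evaluated, more than one task can be selected if several functions return True for the same input.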
Implement A Custom BranchOperator
Override the branches method to return a dictionary that maps branch functions to task names.
from dbgpt.core.awel import DAG, BranchOperator, MapOperator

def branch_even(x: int) -> bool:
    return x % 2 == 0

def branch_odd(x: int) -> bool:
    return not branch_even(x)

class MyBranchOperator(BranchOperator[int]):
    def __init__(self, even_task_name: str, odd_task_name: str, **kwargs):
        self.even_task_name = even_task_name
        self.odd_task_name = odd_task_name
        super().__init__(**kwargs)

    async def branches(self):
        # Build the branch mapping from the names given to the constructor.
        return {
            branch_even: self.even_task_name,
            branch_odd: self.odd_task_name
        }

with DAG("awel_branch_operator") as dag:
    task = MyBranchOperator(even_task_name="even_task", odd_task_name="odd_task")
    even_task = MapOperator(
        task_name="even_task",
        map_function=lambda x: print(f"{x} is even")
    )
    odd_task = MapOperator(
        task_name="odd_task",
        map_function=lambda x: print(f"{x} is odd")
    )
    # Attach both tasks as children of the branch task.
    task >> even_task
    task >> odd_task
Examples
Even Or Odd
Create a new file named branch_operator_even_or_odd.py in the awel_tutorial directory and add the following code:
import asyncio
from dbgpt.core.awel import (
    DAG, BranchOperator, MapOperator, JoinOperator,
    InputOperator, SimpleCallDataInputSource,
    is_empty_data
)

def branch_even(x: int) -> bool:
    return x % 2 == 0

def branch_odd(x: int) -> bool:
    return not branch_even(x)

branch_mapping = {
    branch_even: "even_task",
    branch_odd: "odd_task"
}

def even_func(x: int) -> int:
    print(f"Branch even, {x} is even, multiply by 10")
    return x * 10

def odd_func(x: int) -> int:
    print(f"Branch odd, {x} is odd, multiply by itself")
    return x * x

def combine_function(x: int, y: int) -> int:
    print(f"Received {x} and {y}")
    # Return the first non-empty data
    return x if not is_empty_data(x) else y

with DAG("awel_branch_operator") as dag:
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    task = BranchOperator(branches=branch_mapping)
    even_task = MapOperator(task_name="even_task", map_function=even_func)
    odd_task = MapOperator(task_name="odd_task", map_function=odd_func)
    join_task = JoinOperator(combine_function=combine_function, can_skip_in_branch=False)
    input_task >> task >> even_task >> join_task
    input_task >> task >> odd_task >> join_task

print("First call, input is 5")
assert asyncio.run(join_task.call(call_data=5)) == 25
print("=" * 80)
print("Second call, input is 6")
assert asyncio.run(join_task.call(call_data=6)) == 60
Note: can_skip_in_branch controls whether the current task can be skipped in a branch. Set it to False to prevent the task from being skipped.
Run the following command to execute the code:
poetry run python awel_tutorial/branch_operator_even_or_odd.py
You will see the following output printed to the console:
First call, input is 5
Branch odd, 5 is odd, multiply by itself
Received EmptyData(SKIP_DATA) and 25
================================================================================
Second call, input is 6
Branch even, 6 is even, multiply by 10
Received 60 and EmptyData(SKIP_DATA)
The graph of the DAG is like this:
In the example above, the BranchOperator has two child tasks, even_task and odd_task, and decides which one to run based on the input data and the branch mapping.
We also use the JoinOperator to combine the data from both child tasks. If a path is skipped, the JoinOperator receives EmptyData(SKIP_DATA) as the input from that path, and we can use dbgpt.core.awel.is_empty_data to check whether a value is empty data.
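The branch-then-join flow can be mimicked in plain Python to show why the combine function has to check for skipped data. This is a minimal sketch under stated assumptions: SKIP, is_skipped, run_branches, and run are hypothetical stand-ins written for this illustration, not the library's EmptyData or is_empty_data.

```python
from typing import Tuple, Union


class _Skip:
    """Stand-in sentinel for the output of a skipped path (hypothetical)."""

    def __repr__(self) -> str:
        return "SKIP"


SKIP = _Skip()
Value = Union[int, _Skip]


def is_skipped(value: Value) -> bool:
    """Stand-in for an emptiness check on a path's output."""
    return value is SKIP


def even_func(x: int) -> int:
    return x * 10


def odd_func(x: int) -> int:
    return x * x


def run_branches(x: int) -> Tuple[Value, Value]:
    """Run only the matching path; the other path yields the sentinel."""
    even_out = even_func(x) if x % 2 == 0 else SKIP
    odd_out = odd_func(x) if x % 2 != 0 else SKIP
    return even_out, odd_out


def combine_function(x: Value, y: Value) -> Value:
    # Return the first value that did not come from a skipped path.
    return x if not is_skipped(x) else y


def run(x: int) -> Value:
    return combine_function(*run_branches(x))


print(run(5))  # 25
print(run(6))  # 60
```

The join always receives one value per upstream path, so the combine function, not the join itself, decides how to treat the skipped one.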