Version: v0.6.3

2.4 Branch Operator

The BranchOperator decides which downstream path to run based on the input data. For example, given two paths, it can send even inputs down one path and odd inputs down the other.

There are two ways to use the BranchOperator:

Build A BranchOperator With A Branch Mapping

Pass a dictionary of branch functions and task names to the BranchOperator constructor.

from dbgpt.core.awel import DAG, BranchOperator, MapOperator

def branch_even(x: int) -> bool:
    return x % 2 == 0

def branch_odd(x: int) -> bool:
    return not branch_even(x)

# Keys are branch functions, values are the names of the tasks they select.
branch_mapping = {
    branch_even: "even_task",
    branch_odd: "odd_task"
}

with DAG("awel_branch_operator") as dag:
    task = BranchOperator(branches=branch_mapping)
    even_task = MapOperator(
        task_name="even_task",
        map_function=lambda x: print(f"{x} is even")
    )
    odd_task = MapOperator(
        task_name="odd_task",
        map_function=lambda x: print(f"{x} is odd")
    )
    # Wire both child tasks downstream of the branch operator.
    task >> even_task
    task >> odd_task

In the example above, the BranchOperator has two child tasks, even_task and odd_task, and decides which of them to run based on the input data. The branch mapping passed to the constructor is a dictionary whose keys are branch functions and whose values are task names. When the branch task runs, every branch function is executed: if a function returns True, the corresponding task is executed; otherwise it is skipped.
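The selection rule described above can be modeled without AWEL at all. The following is a minimal, dependency-free sketch, not DB-GPT's implementation: `select_tasks` is a hypothetical helper that calls every branch function in the mapping and splits the task names into those that would run and those that would be skipped.

```python
from typing import Callable, Dict, List, Tuple


def branch_even(x: int) -> bool:
    return x % 2 == 0


def branch_odd(x: int) -> bool:
    return not branch_even(x)


def select_tasks(
    branches: Dict[Callable[[int], bool], str], value: int
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: evaluate every branch function against value."""
    run: List[str] = []
    skip: List[str] = []
    for predicate, task_name in branches.items():
        # Every predicate is called; True means "run", False means "skip".
        (run if predicate(value) else skip).append(task_name)
    return run, skip


branch_mapping = {branch_even: "even_task", branch_odd: "odd_task"}

print(select_tasks(branch_mapping, 5))  # (['odd_task'], ['even_task'])
print(select_tasks(branch_mapping, 6))  # (['even_task'], ['odd_task'])
```

In this model the branch functions are independent of one another, so if two of them returned True for the same input, both task names would land in the "run" list.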

Implement A Custom BranchOperator

Alternatively, subclass BranchOperator and override the branches method so that it returns a dictionary of branch functions and task names.

from dbgpt.core.awel import DAG, BranchOperator, MapOperator

def branch_even(x: int) -> bool:
    return x % 2 == 0

def branch_odd(x: int) -> bool:
    return not branch_even(x)

class MyBranchOperator(BranchOperator[int]):
    def __init__(self, even_task_name: str, odd_task_name: str, **kwargs):
        self.even_task_name = even_task_name
        self.odd_task_name = odd_task_name
        super().__init__(**kwargs)

    async def branches(self):
        # Build the mapping from the task names given at construction time.
        return {
            branch_even: self.even_task_name,
            branch_odd: self.odd_task_name
        }

with DAG("awel_branch_operator") as dag:
    task = MyBranchOperator(even_task_name="even_task", odd_task_name="odd_task")
    even_task = MapOperator(
        task_name="even_task",
        map_function=lambda x: print(f"{x} is even")
    )
    odd_task = MapOperator(
        task_name="odd_task",
        map_function=lambda x: print(f"{x} is odd")
    )
    # Wire both child tasks downstream of the branch operator.
    task >> even_task
    task >> odd_task
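One reason to override branches rather than pass a fixed dictionary is that the mapping can then depend on instance state. The sketch below illustrates that pattern in plain Python. `SketchBranchOperator`, its `route` method, and `ThresholdBranchOperator` are hypothetical stand-ins invented for this illustration; they are not DB-GPT classes and do not reflect how BranchOperator is implemented.

```python
import asyncio
from typing import Callable, Dict, List, Optional

BranchMapping = Dict[Callable[[int], bool], str]


class SketchBranchOperator:
    """Hypothetical stand-in for BranchOperator; not the DB-GPT class."""

    def __init__(self, branches: Optional[BranchMapping] = None):
        self._branches = branches

    async def branches(self) -> BranchMapping:
        # Default in this sketch: return the mapping given at construction.
        if self._branches is None:
            raise ValueError("pass a branch mapping or override branches()")
        return self._branches

    async def route(self, value: int) -> List[str]:
        # Evaluate every branch function and keep the task names that matched.
        mapping = await self.branches()
        return [name for predicate, name in mapping.items() if predicate(value)]


class ThresholdBranchOperator(SketchBranchOperator):
    """Builds its mapping from instance state instead of a fixed dictionary."""

    def __init__(self, threshold: int, low_task_name: str, high_task_name: str):
        super().__init__()
        self.threshold = threshold
        self.low_task_name = low_task_name
        self.high_task_name = high_task_name

    async def branches(self) -> BranchMapping:
        return {
            (lambda x: x < self.threshold): self.low_task_name,
            (lambda x: x >= self.threshold): self.high_task_name,
        }


operator = ThresholdBranchOperator(10, "low_task", "high_task")
print(asyncio.run(operator.route(3)))   # ['low_task']
print(asyncio.run(operator.route(10)))  # ['high_task']
```

The same subclass can be reused with different thresholds and task names, which a module-level dictionary of branch functions cannot express as directly.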

Examples

Even Or Odd

Create a new file named branch_operator_even_or_odd.py in the awel_tutorial directory and add the following code:

import asyncio
from dbgpt.core.awel import (
    DAG, BranchOperator, MapOperator, JoinOperator,
    InputOperator, SimpleCallDataInputSource,
    is_empty_data
)

def branch_even(x: int) -> bool:
    return x % 2 == 0

def branch_odd(x: int) -> bool:
    return not branch_even(x)

branch_mapping = {
    branch_even: "even_task",
    branch_odd: "odd_task"
}

def even_func(x: int) -> int:
    print(f"Branch even, {x} is even, multiply by 10")
    return x * 10

def odd_func(x: int) -> int:
    print(f"Branch odd, {x} is odd, multiply by itself")
    return x * x

def combine_function(x: int, y: int) -> int:
    print(f"Received {x} and {y}")
    # Return the first non-empty data
    return x if not is_empty_data(x) else y

with DAG("awel_branch_operator") as dag:
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    task = BranchOperator(branches=branch_mapping)
    even_task = MapOperator(task_name="even_task", map_function=even_func)
    odd_task = MapOperator(task_name="odd_task", map_function=odd_func)
    join_task = JoinOperator(combine_function=combine_function, can_skip_in_branch=False)
    input_task >> task >> even_task >> join_task
    input_task >> task >> odd_task >> join_task

print("First call, input is 5")
assert asyncio.run(join_task.call(call_data=5)) == 25
print("=" * 80)
print("Second call, input is 6")
assert asyncio.run(join_task.call(call_data=6)) == 60

Note: can_skip_in_branch controls whether the current task can be skipped when it is part of a branch. Set it to False, as join_task does here, to prevent the task from being skipped.

Run the following command to execute the code:

poetry run python awel_tutorial/branch_operator_even_or_odd.py

You will see the following output printed to the console:

First call, input is 5
Branch odd, 5 is odd, multiply by itself
Received EmptyData(SKIP_DATA) and 25
================================================================================
Second call, input is 6
Branch even, 6 is even, multiply by 10
Received 60 and EmptyData(SKIP_DATA)

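The graph of the DAG is as follows: input_task feeds the BranchOperator, which fans out to even_task and odd_task, and both of those feed join_task.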

In the example above, the BranchOperator has two child tasks, even_task and odd_task. It decides which child task to run based on the input data and the branch mapping.

We also use the JoinOperator to combine the data from both child tasks. If a path is skipped, the JoinOperator receives EmptyData(SKIP_DATA) as the input from that path, and we can use dbgpt.core.awel.is_empty_data to check whether a value is empty data.
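The "first non-empty value" join can be illustrated in plain Python. This is a minimal sketch under stated assumptions rather than AWEL code: `SKIP` is a hypothetical sentinel standing in for EmptyData(SKIP_DATA), `is_skipped` stands in for is_empty_data, and `even_path` and `odd_path` fold the branch decision and the map function into one step.

```python
from typing import Union


class _Skip:
    """Hypothetical sentinel standing in for EmptyData(SKIP_DATA)."""

    def __repr__(self) -> str:
        return "SKIP"


SKIP = _Skip()
MaybeInt = Union[int, _Skip]


def is_skipped(value: MaybeInt) -> bool:
    # Stand-in for is_empty_data: identity check against the sentinel.
    return value is SKIP


def even_path(x: int) -> MaybeInt:
    # Runs only for even inputs; otherwise the path yields the skip marker.
    return x * 10 if x % 2 == 0 else SKIP


def odd_path(x: int) -> MaybeInt:
    # Runs only for odd inputs; otherwise the path yields the skip marker.
    return x * x if x % 2 != 0 else SKIP


def combine(x: MaybeInt, y: MaybeInt) -> MaybeInt:
    # Return the first value that did not come from a skipped path.
    return x if not is_skipped(x) else y


def run(x: int) -> MaybeInt:
    return combine(even_path(x), odd_path(x))


print(run(5))  # 25
print(run(6))  # 60
```

A dedicated sentinel, rather than None or 0, keeps a skipped path distinguishable from a path that ran and legitimately produced a falsy result.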