MLOps/Airflow

[Airflow] BranchPythonOperator로 Task 분기 처리하기

monkeykim 2024. 11. 8. 00:49

데이터 파이프라인을 구성할 때, 상황에 따라 특정 Task만 실행해야 하는 경우가 자주 발생합니다. Airflow에서는 BranchPythonOperator를 사용하여 Task의 분기 처리를 할 수 있고 @task.branch 데코레이터와 BaseBranchOperator를 상속하여 직접 커스터마이징하는 방법도 존재합니다.

이 글에서는 세가지 방법을 사용하여 Task를 분기하는 방법을 코드 예제와 함께 설명하겠습니다.


BranchPythonOperator로 Task 분기 처리하기

BranchPythonOperator는 특정 조건에 따라 실행할 Task의 ID를 리턴하여 분기 처리를 수행합니다. 이때, 함수의 리턴값이 분기처리의 핵심입니다. BranchPythonOperator에서 리턴된 값이 후속 Task의 ID가 되므로, 올바른 ID를 반환해야만 원하는 Task가 실행됩니다.

분기처리 리턴 값 설정

  • 단일 Task 분기: 후속 Task가 하나라면, 해당 Task의 ID를 문자열로 반환합니다.
  • 다중 Task 분기: 후속 Task가 여러 개라면, 실행할 Task ID들이 담긴 리스트를 반환합니다.

예를 들어, 선택된 값이 'A'인 경우에는 task_a를 실행하고, 'B' 또는 'C'인 경우에는 task_b와 task_c를 동시에 실행하도록 설정할 수 있습니다.

 

BranchPythonOperator 예제 코드

from airflow import DAG
import pendulum
from airflow.operators.python import BranchPythonOperator, PythonOperator

# DAG 설정
with DAG(
    dag_id="dag_branch_python_operator",
    schedule="10 9 * * *", 
    start_date=pendulum.datetime(2024, 3, 1, tz="Asia/Seoul"),
    catchup=False
) as dag:
    
    # Task 분기 처리 함수
    def select_random():
        import random

        item_list = ['A', 'B', 'C']
        selected_item = random.choice(item_list)
        if selected_item == 'A':
            return 'task_a'
        elif selected_item in ['B', 'C']:
            return ['task_b', 'task_c']

    # BranchPythonOperator 설정
    python_branch_task = BranchPythonOperator(
        task_id='python_branch_task',
        python_callable=select_random
    )

    # 공통 작업을 수행하는 함수
    def common_func(**kwargs):
        print(kwargs['selected'])

    # 각 분기 Task 정의
    task_a = PythonOperator(
        task_id='task_a',
        python_callable=common_func,
        op_kwargs={'selected': 'A'}
    )
    
    task_b = PythonOperator(
        task_id='task_b',
        python_callable=common_func,
        op_kwargs={'selected': 'B'}
    )
    
    task_c = PythonOperator(
        task_id='task_c',
        python_callable=common_func,
        op_kwargs={'selected': 'C'}
    )

    # 분기 처리 작업 연결
    python_branch_task >> [task_a, task_b, task_c]
  1. select_random() 함수: 이 함수는 ‘A’, ‘B’, ‘C’ 중 하나의 값을 랜덤으로 선택하여 조건에 맞는 Task ID를 반환합니다.
    • 'A'가 선택되면 task_a의 ID인 'task_a'를 반환하여 단일 분기 처리를 합니다.
    • 'B' 또는 'C'가 선택되면 ['task_b', 'task_c'] 리스트를 반환하여 다중 분기 처리를 합니다.
  2. PythonOperator Tasks: common_func 함수는 분기 Task에서 공통으로 사용되며, 선택된 값을 출력합니다. 각 Task는 op_kwargs를 사용해 전달된 선택된 값을 출력하도록 설정되었습니다.
  3. 분기 Task 연결: python_branch_task >> [task_a, task_b, task_c]는 python_branch_task의 결과에 따라 후속 Task들이 분기되도록 설정합니다.

@task.branch로 분기 처리하기

@task.branch 데코레이터는 Python 함수를 Task로 정의하면서 분기처리를 구현할 수 있습니다. 이를 통해 분기 로직을 한눈에 파악할 수 있고 코드의 간결성도 높일 수 있습니다.

from airflow import DAG
import pendulum
from airflow.operators.python import PythonOperator
from airflow.decorators import task

with DAG(
    dag_id="dag_python_with_branch_decorator",
    schedule="10 9 * * *", 
    start_date=pendulum.datetime(2024, 3, 1, tz="Asia/Seoul"),
    catchup=False
) as dag:
    @task.branch(task_id='python_branch_task')
    def select_random():
        import random

        item_list = ['A', 'B', 'C']
        selected_item = random.choice(item_list)
        if selected_item == 'A':
            return 'task_a'
        elif selected_item in ['B', 'C']:
            return ['task_b', 'task_c']

    def common_func(**kwargs):
        print(kwargs['selected'])

    task_a = PythonOperator(
        task_id='task_a',
        python_callable=common_func,
        op_kwargs={'selected': 'A'}
    )
    
    task_b = PythonOperator(
        task_id='task_b',
        python_callable=common_func,
        op_kwargs={'selected': 'B'}
    )
    
    task_c = PythonOperator(
        task_id='task_c',
        python_callable=common_func,
        op_kwargs={'selected': 'C'}
    )

    select_random() >> [task_a, task_b, task_c]

이 코드에서 @task.branch를 사용하여 select_random 함수가 조건에 따라 Task를 분기합니다. 선택된 값에 따라 ‘A’일 경우 task_a만, ‘B’ 또는 ‘C’일 경우 task_b와 task_c가 동시에 실행됩니다. BranchPythonOperator와 유사하지만 데코레이터를 사용해 더 간결하게 표현할 수 있는 장점이 있습니다.


BaseBranchOperator로 분기 처리하기

분기 처리가 필요한 상황에서 BaseBranchOperator를 상속받아 커스터마이징하는 방법도 있습니다. BaseBranchOperator를 상속한 후 choose_branch 메서드를 오버라이드하여 원하는 분기 로직을 추가할 수 있습니다. 이 방법은 분기 로직이 더욱 복잡하거나 커스터마이징이 필요할 때 유용합니다.

from typing import Iterable
from airflow import DAG
from airflow.utils.context import Context
import pendulum
from airflow.operators.python import PythonOperator
from airflow.operators.branch import BaseBranchOperator

with DAG(
    dag_id="dag_base_branch_operator",
    schedule="10 9 * * *", 
    start_date=pendulum.datetime(2024, 3, 1, tz="Asia/Seoul"),
    catchup=False
) as dag:
    class CustomBranchOperator(BaseBranchOperator):
        def choose_branch(self, context: Context):  # choose_branch를 오버라이딩해야 함
            import random

            print(context)

            item_list = ['A', 'B', 'C']
            selected_item = random.choice(item_list)
            if selected_item == 'A':
                return 'task_a'
            elif selected_item in ['B', 'C']:
                return ['task_b', 'task_c']            
            
    custom_branch_operator = CustomBranchOperator(task_id='python_branch_task')

    def common_func(**kwargs):
        print(kwargs['selected'])

    task_a = PythonOperator(
        task_id='task_a',
        python_callable=common_func,
        op_kwargs={'selected': 'A'}
    )
    
    task_b = PythonOperator(
        task_id='task_b',
        python_callable=common_func,
        op_kwargs={'selected': 'B'}
    )
    
    task_c = PythonOperator(
        task_id='task_c',
        python_callable=common_func,
        op_kwargs={'selected': 'C'}
    )

    custom_branch_operator >> [task_a, task_b, task_c]

위 예제에서 CustomBranchOperator는 choose_branch 메서드를 오버라이드하여 ‘A’, ‘B’, ‘C’ 중 선택된 값에 따라 분기합니다. 이처럼 BaseBranchOperator를 상속받아 조건을 정의하고 분기하는 방식은 복잡한 분기처리가 필요한 경우에 적합합니다.


세가지 분기 처리의 사용법과 결과 값은 모두 다르지만, 분기를 실행하는 로직은 동일합니다. 아래 이미지는 예제 코드들을 graph로 표현한 그림입니다. 분기 처리 시 선택된 것은 붉게, 선택되지 않은 것은 회색 처리가 되어있는 것을 확인 할 수 있습니다. 이번 예제에서는 selected_item에 'B'와 'C' 둘 중에 하나가 나와 task_b와 task_c가 실행되어 붉게 처리 되었고, task_a는 skip이 되어 회색처리가 되었습니다.

뿐만 아니라 xcom도 동일하게 출력되는 것을 확인할 수 있습니다. xcom에서 return_value 외에도 skipmixin_key라는 key가 생긴 것을 확인할 수 있습니다. 이것은 후행으로 실행시킬 Task id를 보여줍니다.