MLOps/Airflow

[Airflow] Custom Operator 만들기

monkeykim 2024. 11. 11. 21:04

Custom Operator가 필요한 이유

기본적으로 Airflow에는 PythonOperator, BashOperator, EmailOperator 등 자주 사용하는 작업에 대한 Operator가 준비되어 있지만, 모든 상황을 커버하지는 않습니다. 예를 들어, 외부 API 호출 후 데이터를 처리하는 작업이나 복잡한 데이터 파이프라인을 위한 연산이 필요할 때는 Custom Operator가 더욱 적합합니다.

Custom Operator를 통해 우리는 다음과 같은 이점을 얻을 수 있습니다.

  • 코드 재사용성: 특정 비즈니스 로직을 담은 Operator를 재사용하여 일관성 있게 워크플로우를 구축할 수 있습니다.
  • 코드 간소화: DAG에서 반복적인 코드를 줄이고, 명확한 역할을 가진 Operator를 만들어 가독성을 높일 수 있습니다.
  • 확장성: 필요에 맞게 기능을 확장하여 더욱 복잡한 작업을 구현할 수 있습니다.

Custom Operator 구현 단계

커스텀 Operator는 BaseOperator를 상속받아 만들 수 있습니다. 기본적인 구성 요소는 다음과 같습니다:

  1. Operator 생성: 새로운 클래스를 정의하고 BaseOperator를 상속받습니다.
  2. 파라미터 정의: Operator 실행에 필요한 파라미터를 정의합니다.
  3. execute 메서드 구현: Operator가 실제로 수행할 작업을 정의합니다.

1. Operator 생성

먼저 Custom Operator 클래스를 만듭니다. 여기서는 예시로 외부 API에서 데이터를 가져와 저장하는 ApiToDatabaseOperator를 만들어 보겠습니다.

from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
import requests

class ApiToDatabaseOperator(BaseOperator):
    
    @apply_defaults
    def __init__(self, api_url, db_connection, *args, **kwargs):
        super(ApiToDatabaseOperator, self).__init__(*args, **kwargs)
        self.api_url = api_url
        self.db_connection = db_connection

여기서 @apply_defaults 데코레이터를 통해 Operator의 초기화 메서드에서 기본 파라미터와 사용자가 입력한 파라미터를 관리할 수 있습니다

2. 파라미터 정의

__init__ 메서드에서 api_url과 db_connection이라는 파라미터를 정의하였습니다. 이 파라미터들은 Operator 실행 시 필요한 정보를 담고 있습니다.

 

3. Execute 메서드 구현

Custom Operator의 핵심은 execute 메서드입니다. 이 메서드는 DAG이 실행될 때 호출되며, Operator가 수행할 작업을 정의합니다.

def execute(self, context):
    self.log.info('Fetching data from API: %s', self.api_url)
    response = requests.get(self.api_url)
    if response.status_code != 200:
        self.log.error('Failed to fetch data')
        raise ValueError('API Request failed')
    
    data = response.json()
    
    self.log.info('Inserting data into database')
    # db_connection을 이용한 데이터베이스 저장 로직 추가
    with self.db_connection.cursor() as cursor:
        # 예시 데이터베이스 로직
        for record in data:
            cursor.execute("INSERT INTO my_table (field1, field2) VALUES (%s, %s)", (record['field1'], record['field2']))
    self.db_connection.commit()
    self.log.info('Data insertion complete')

전체 코드 예시

이제 Custom Operator의 전체 코드를 정리해보겠습니다.

from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
import requests

class ApiToDatabaseOperator(BaseOperator):

    @apply_defaults
    def __init__(self, api_url, db_connection, *args, **kwargs):
        super(ApiToDatabaseOperator, self).__init__(*args, **kwargs)
        self.api_url = api_url
        self.db_connection = db_connection

    def execute(self, context):
        self.log.info('Fetching data from API: %s', self.api_url)
        response = requests.get(self.api_url)
        if response.status_code != 200:
            self.log.error('Failed to fetch data')
            raise ValueError('API Request failed')
        
        data = response.json()
        
        self.log.info('Inserting data into database')
        with self.db_connection.cursor() as cursor:
            for record in data:
                cursor.execute("INSERT INTO my_table (field1, field2) VALUES (%s, %s)", (record['field1'], record['field2']))
        self.db_connection.commit()
        self.log.info('Data insertion complete')

DAG에서 Custom Operator 사용하기

이제 DAG에 ApiToDatabaseOperator를 추가해보겠습니다.

from airflow import DAG
from airflow.utils.dates import days_ago
from my_custom_operators import ApiToDatabaseOperator
import psycopg2

# DAG 기본 설정
default_args = {
    'owner': 'airflow',
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}

# 데이터베이스 연결 설정
db_connection = psycopg2.connect(
    host="localhost",
    database="mydatabase",
    user="myuser",
    password="mypassword"
)

with DAG(
    'api_to_database_dag',
    default_args=default_args,
    description='API에서 데이터를 받아 데이터베이스에 저장',
    schedule_interval=timedelta(days=1),
    start_date=days_ago(1),
) as dag:

    api_to_db = ApiToDatabaseOperator(
        task_id='api_to_database_task',
        api_url='https://example.com/api/data',
        db_connection=db_connection
    )

    api_to_db

이 DAG에서는 매일 예제 api에서 데이터를 가져와 데이터베이스에 저장하는 작업을 수행합니다. Custom Operator를 사용하여 코드의 가독성과 재사용성을 높였습니다.