Coding/설계 | 경험

FastAPI에서 SQLAlchemy Session 다루는 방법

Hide­ 2022. 3. 23. 19:57
반응형

개요

FastAPI는 비동기 프레임워크이다. 내부적으로는 Starlette를 Wrapping한 프레임워크인데 일반적인 동기 프레임워크인 Flask, Django등과는 다른 패러다임을 가지고 있다. 그에 따라 프로덕션 레벨에서 서비스를 운영하며 많은 트러블슈팅이 있었고 그 중 데이터베이스와 통신을 위해 사용하는 SQLAlchemy 라이브러리 또한 큰 이슈가 발생했었다.

기존에는 SQLAlchemy 1.3버전을 사용하고 있었기 때문에 비동기를 지원하지 않았지만 1.4버전부터 AsyncSession을 통한 비동기를 지원하기 시작했고 현재 사내 코드를 마이그레이션 하기위해 개인적인 테스트를 진행하고 있다.

본 포스팅에서는 비동기 프레임워크에 SQLAlchemy를 올바르게 사용하기 위한 트러블 슈팅을 포함하여 최종적으로 적용했던 과정에 대해 기술한다.

트랜잭션

본문의 예제를 보다보면 @Transactional()이라는 데코레이터가 등장한다.

class Transactional:
    def __init__(self, propagation: Propagation = Propagation.REQUIRED):
        self.propagation = propagation
    
    def __call__(self, function): 
        @wraps(function) 
        async def decorator(*args, **kwargs):
            try:
                result = await function(*args, **kwargs)
                await session.commit()
            except Exception as e:
                await session.rollback() 
                raise e 
            finally: 
                await session.remove() 
            return result
    return decorator

위 코드는 DB Persist를 하나로 통합하여 트랜잭션을 이루기위해 사용되며, @Transactional() 데코레이터를 제외한 타 코드에서 명시적인 commit()은 사용하지 않는다.

Context-local

비동기 프레임워크는 코루틴 기반으로 동작한다. 코루틴은 하나의 스레드에서 Concurrency하게 돌아가기 때문에 작업들 간 서로 제어권을 주고받으며 동작한다. 멀티 스레드 또는 프로세스와는 달리 싱글 스레드에서 동작하며 각 작업들 간 서로 다른 Context를 가지게 된다. 이 때 각 Context별로 독립적인 환경을 가지는 것을 Context-local이라고 한다. Thread-local과 비교하며 생각해보면 쉽게 이해할 수 있을 것이다. 그렇다면 Context-local이 왜 중요할까? 예를 들어 아래와 같은 상황이 존재한다고 가정해보자.

Request-1이 들어옴과 동시에 Request-2가 들어온 상황이다. logic_with_sleep3()은 3초간의 sleep후 예외가 발생하는 메소드이고 logic()의 경우 별다른 sleep없이 정상적으로 처리를 끝내는 메소드이다. 만약 이럴 때 Context-local이 보장되지 않은 공유 객체를 사용한다면 어떻게 될까?

예를 들어 logic()에서는 SQLAlchemy의 session.add()를 사용하여 특정 모델을 session에 등록한 상황이라고 생각해보자. logic()에서 정상적으로 데이터를 추가했지만 logic_with_sleep3()에서 예외가 발생했기 때문에 exception propogation으로 인해 logic()또한 정상적인 DB persist작업이 완료될 수 없다. 트랜잭션은 @Transactional()에서 이루어지는데, 코드를 보면 알겠지만 예외가 발생했을 경우 commit()이 아닌 rollback()이 이루어지게 되기 때문이다.

그렇기 때문에 각 Request는 별도의 독립적인 Context를 가져야 한다. 하나의 Request에서 진행되는 작업은 독립적으로 존재해야하며 타 Request의 작업에 영향을 줘서도, 받아서도 안된다. Context-local을 보장해야하는 이유가 바로 이것이다.

scoped_session()

일반적으로 SQLAlchemy에서 Thread-local을 보장하고 싶을 때 scoped_session() 메소드를 사용하여 세션을 관리한다. 해당 메소드에는 조금 특이한 인자가 하나 존재하는데 바로 scopefunc라는 인자이다. 

def __init__(self, session_factory, scopefunc=None):
    """Construct a new :class:`.scoped_session`.

    :param session_factory: a factory to create new :class:`.Session`
     instances. This is usually, but not necessarily, an instance
     of :class:`.sessionmaker`.
    :param scopefunc: optional function which defines
     the current scope.   If not passed, the :class:`.scoped_session`
     object assumes "thread-local" scope, and will use
     a Python ``threading.local()`` in order to maintain the current
     :class:`.Session`.  If passed, the function should return
     a hashable token; this token will be used as the key in a
     dictionary in order to store and retrieve the current
     :class:`.Session`.

    """
    self.session_factory = session_factory

    if scopefunc:
        self.registry = ScopedRegistry(session_factory, scopefunc)
    else:
        self.registry = ThreadLocalRegistry(session_factory)

메소드의 생성자를 살펴보면 위와 같은 코드를 볼 수 있다. scopefunc를 사용한다면 ScopedRegistry 객체를 생성한 후 registry에 저장한다.

def __init__(self, createfunc, scopefunc):
    """Construct a new :class:`.ScopedRegistry`.

    :param createfunc:  A creation function that will generate
      a new value for the current scope, if none is present.

    :param scopefunc:  A function that returns a hashable
      token representing the current scope (such as, current
      thread identifier).

    """
    self.createfunc = createfunc
    self.scopefunc = scopefunc
    self.registry = {}

def __call__(self):
    key = self.scopefunc()
    try:
        return self.registry[key]
    except KeyError:
        return self.registry.setdefault(key, self.createfunc())

ScopedRegistry의 생성자를 보면 처음 scoped_session()의 scopefunc인자가 그대로 넘어와서 저장됨을 확인할 수 있다. 또한 __call__ 메소드를 보면 scopefunc로 들어온 메소드를 실행하여 키를 결정하고 해당 키를 통해 내부 딕셔너리에서 값을 가져온다.

def set(self, obj):
    """Set the value for the current scope."""

    self.registry[self.scopefunc()] = obj

마찬가지로 저장시키는 부분 또한 scopefunc를 통해 키를 결정하고 obj(세션)를 registry에 저장시킨다. 이를 통해 우리가 알 수 있는 점은 다음과 같다.

scoped_session()은 내부적으로 registry라는 딕셔너리를 통해 세션을 관리하는데 scopefunc 인자에 메소드를 넣어주면 해당 메소드를 통해 registry에 접근하기 때문에 우리가 원하는 형태 또는 방식으로 세션을 관리할 수 있다.

Context-local Session

위 내용까지 이해했으면 이제 실제로 FastAPI에 적용해보자.

from contextvars import ContextVar, Token
from typing import Union

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker, Session

from core.config import config

session_context: ContextVar[str] = ContextVar("session_context")


def get_session_context() -> str:
    return session_context.get()


def set_session_context(session_id: str) -> Token:
    return session_context.set(session_id)


def reset_session_context(context: Token) -> None:
    session_context.reset(context)


engine = create_engine(config.DB_URL, pool_recycle=3600)
session: Union[Session, scoped_session] = scoped_session(
    sessionmaker(autocommit=True, autoflush=False, bind=engine),
    scopefunc=get_session_context,
)
Base = declarative_base()

SQLAlchemy 관련 세팅 코드는 위와 같다. 하나하나 살펴보자.

session_context: ContextVar[str] = ContextVar("session_context")


def get_session_context() -> str:
    return session_context.get()


def set_session_context(session_id: str) -> Token:
    return session_context.set(session_id)


def reset_session_context(context: Token) -> None:
    session_context.reset(context)

먼저 ContextVar를 사용해 해당 Context에서만 유효한 Context-local 변수를 하나 만든다. 해당 변수는 현재의 Request Context를 식별할 수 있는 특정한 랜덤 값(uuid4)을 담아주는 용도로 사용한다. 나머지 함수들은 해당 컨텍스트 변수에서 값을 가져오고 저장하고 삭제하는 역할을 한다.

session: Union[Session, scoped_session] = scoped_session(
    sessionmaker(autocommit=True, autoflush=False, bind=engine),
    scopefunc=get_session_context,
)

이 부분을 눈여겨 살펴봐야한다. 위에서 말했던 것 처럼 scopefunc의 인자로 get_session_context 함수를 넣어주고 있다. 따라서 세션이 사용되는 시점에 해당 함수를 통해 키를 결정하고 그 키를 통해 내부 registry에서 세션 객체를 가져오게 된다.

class SQLAlchemyMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint,
    ):
        session_id = str(uuid4())
        context = set_session_context(session_id=session_id)

        try:
            response = await call_next(request)
        except Exception as e:
            session.rollback()
            raise e
        finally:
            session.remove()
            reset_session_context(context=context)

        return response

다음으로 미들웨어를 하나 생성한다. 해당 미들웨어는 Request가 들어올 때 거치게 되는데, 그 시점에 uuid4를 통해 랜덤 값을 생성한 후 컨텍스트 변수에 담는 역할을 한다. 그리고 모든 작업이 완료된 후 컨텍스트 변수에서 값을 삭제해주는 역할 또한 한다.

위 작업들을 통해 Request가 들어올 때 SQLAlchemyMiddleware에서 해당 Request를 위한 랜덤값을 컨텍스트 변수에 세팅해주고, Session을 사용하는 시점에 컨텍스트 변수에 저장된 값을 통해 내부 registry에서 사용할 세션을 결정한다. 컨텍스트 변수는 해당 컨텍스트 내에서만 유효한 변수이므로 Context-local이 보장된다. 따라서 우리가 원하는 Request Per Session, Context Local Session을 구현할 수 있다.

그림으로 나타내면 위처럼 표시할 수 있겠다. Request 시작부터 끝나는 지점까지를 하나의 Context로 취급한다고 생각하면 된다. 이제 위 내용을 토대로 실제로 정상적으로 동작하는지 확인해보자.

@home_router.get("/")
async def home(): 
    from app.models import User 
    import asyncio 
    session.query(User).first() 
    
    print(f"[*] registry: {session.registry.registry}") 
    key = session.registry.scopefunc() 
    print(f"[*] key: {key}") 
    print(f"[*] session: {session.registry.registry[key]}") 
    print()
    
    await asyncio.sleep(3) 
    return {"status": True}

위 코드를 연속으로 2번 요청하면 아래와 같은 결과가 나온다.

[*] registry: {'0085ee4d-598d-43b3-8122-403d77f25ef7': <sqlalchemy.orm.session.Session object at 0x108d09cd0>} 
[*] key: 0085ee4d-598d-43b3-8122-403d77f25ef7 
[*] session: <sqlalchemy.orm.session.Session object at 0x108d09cd0> 
[*] registry: {'0085ee4d-598d-43b3-8122-403d77f25ef7': <sqlalchemy.orm.session.Session object at 0x108d09cd0>, '4d5d1ac3-fb88-41bb-bcfb-77de7811cb25': <sqlalchemy.orm.session.Session object at 0x108d67bb0>} 
[*] key: 4d5d1ac3-fb88-41bb-bcfb-77de7811cb25 
[*] session: <sqlalchemy.orm.session.Session object at 0x108d67bb0>

최초 요청 시 registry에는 하나의 키/값(세션)만 존재했다. 최초 요청의 sleep으로 인해 아직 해당 요청의 응답이 나가기 이전에 두번 째 요청이 들어갔기 때문에 registry에 두 번째 요청의 키/값(세션)이 추가되었다. 결과적으로 첫 번째 요청과 두 번째 요청의 키가 다르며 사용되는 세션의 id값 또한 다른 모습을 확인할 수 있다.

SQLAlchemy 비동기 적용

개요에서 설명했듯이 SQLAlchemy 1.4버전부터는 비동기처리도 지원하기 시작했다. 현재 사내 코드를 마이그레이션 하기 전 테스트하는 과정에 있으며 본문부터는 Production ready 상태가 아님을 인지하고 읽기 바란다.

create_async_engine

from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine(
    "mysql+aiomysql://fastapi:fastapi@localhost:3306/fastapi", 
    pool_recycle=3600,
)

기존은 create_engine() 메소드를 사용했던 반면 비동기의 경우 ext.asyncio에 있는 create_async_engine() 메소드를 사용한다. 또한 데이터베이스 접속 라이브러리도 비동기를 지원하는 라이브러리로 교체해줘야한다. 본문에서는 MySQL을 사용하고 있으므로 pymysql에서 aiomysql로 바꿔주었다.

sessionmaker

from sqlalchemy.ext.asyncio import (
    AsyncSession,
    create_async_engine,
)
from sqlalchemy.orm import sessionmaker


engine = create_async_engine(config.DB_URL, pool_recycle=3600)
async_session_factory = sessionmaker(bind=engine, class_=AsyncSession)

sessionmaker의 경우 기존과 거의 동일하다. 차이점은 class_ 인자에 AsyncSession을 넣어준다는 점이다.

async_scoped_session

SQLAlchemy 공식문서를 보면 async_scoped_session() 관련하여 위와 같은 자료를 찾을 수 있다. 주목할점은 scopefunc인자에 asyncio의 current_task 메소드를 넣어준다는 점이다. current_task는 현재 실행중인 Task인스턴스를 반환하는 메소드이다. 따라서 세션이 사용될 때 각 컨텍스트에 맞는 Task값을 통해 세션을 가져옴으로써 Context-local을 보장할 수 있음을 나타내고 있다. 이전 동기 방식에서는 현재 Context를 구분하기위해 미들웨어에서 특정한 값을 컨텍스트 변수에 세팅해준 후 가져와서 사용하고 있었는데, 공식 문서에 나와있는 것 처럼 현재 Task를 출력할 수 있는 메소드가 이미 존재하니 해당 메소드를 사용하면 미들웨어도 걷어내고 보다 간결한 코드를 가져갈 수 있다고 생각했다. 검증해보기 위해 먼저 current_task를 단순 출력해봤다.

<Task pending name='Task-3' coro=<RequestResponseCycle.run_asgi() running at /Users/hide/.local/share/virtualenvs/fastapi-boilerplate-toX-v08U/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py:372> cb=[set.discard()]>

재차 설명하지만, 코루틴은 싱글 스레드에서 여러개의 Context가 Concurrency하게 실행되는 형태이기에 현재 Task에 맞게 세션을 사용하도록 설정해준다면 큰 문제가 없을 것이라고 생각했었다. 그런데 여기서 asyncio.gather()를 통해 동시에 코드를 실행한다면 어떻게 될까?

class UserService: 
    async def test(self): 
        import asyncio 
        print(asyncio.current_task())

먼저 test() 라는 메소드를 하나 생성하고 현재 Task를 출력하도록 만든다.

async def home(): 
    import asyncio 
    from app.services.user 
    import UserService 
    
    print(asyncio.current_task()) 
    
    await asyncio.gather( 
        UserService().test(), 
        UserService().test(), 
    )

그리고 라우터단에서 현재 Task를 한번 출력하고 gather()를 통해 위에서 생성한 test() 메소드 2개를 동시에 실행시켜줬다.

<Task pending name='Task-3' coro=<RequestResponseCycle.run_asgi() running at /Users/hide/.local/share/virtualenvs/fastapi-boilerplate-toX-v08U/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py:372> cb=[set.discard()]>
<Task pending name='Task-4' coro=<UserService.test() running at /Users/hide/fastapi-boilerplate/app/services/user.py:66> cb=[gather.<locals>._done_callback() at /Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/asyncio/tasks.py:769]>
<Task pending name='Task-5' coro=<UserService.test() running at /Users/hide/fastapi-boilerplate/app/services/user.py:66> cb=[gather.<locals>._done_callback() at /Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/asyncio/tasks.py:769]>

첫 번째 라인(Task-3)이 라우터에서 실행된 Task이고 나머지 라인(Task-4, Task-5)이 test() 메소드에서 출력한 현재 Task이다. 결과값을 보면 알겠지만 모두 동일 Task가 아닌 새로운 Task가 생성되어 돌아가게된다. 이렇게 되면 문제가 발생한다. 위 async_scoped_session의 scopefunc에 current_task를 통해 현재 Task객체를 통해 사용할 세션을 결정하도록 만들어놨는데, 위처럼 새로운 Task가 생기는 경우 Context-local이 보장되지 않기 때문이다.

문제점

또한 현재 코드 상황에서도 문제가 발생할 여지가 있었는데, 본 포스팅의 두 번째 주제를 확인해보면 @Transactional() 데코레이터를 통해 여러개의 DB Persist작업을 하나의 트랜잭션으로 묶는 행위를 한다고 했다. 여기서도 문제가 발생했다. 예를 들어 아래와 같은 코드가 있다고 생각해보자.

@Transactional() 
async def exception_add(self): 
    import asyncio 
    print(asyncio.current_task()) 
    
    user = User(email="email", password="password1", nickname="nickname") 
    session.add(user) 
    await asyncio.sleep(3) 
    raise DuplicateEmailOrNicknameException 
 
 
 @Transactional() 
 async def add(self): 
     import asyncio 
     print(asyncio.current_task()) 
     
     user = User(email="email", password="password1", nickname="nickname") 
     session.add(user)

exception_add()의 경우 세션에 모델을 추가하고 3초 후 예외를 발생시킨다. add()는 단순히 세션에 모델을 추가하는 행위만 한다. 두 메소드 모두 @Transactional() 데코레이터로 묶여있다.

async def home(): 
    import asyncio 
    from app.services.user import UserService 
    
    print(asyncio.current_task()) 
    
    await asyncio.gather( 
        UserService().exception_add(), 
        UserService().add(), 
    ) 
    return {"status": True}

위에서 만든 두개의 메소드를 gather()를 통해 실행시킨 결과는 다음과 같다.

<Task pending name='Task-3' coro=<RequestResponseCycle.run_asgi() running at /Users/hide/.local/share/virtualenvs/fastapi-boilerplate-toX-v08U/lib/python3.8/site-packages/uvicorn/protocols/http/httptools_impl.py:372> cb=[set.discard()]>
<Task pending name='Task-4' coro=<UserService.exception_add() running at /Users/hide/fastapi-boilerplate/core/db/transactional.py:20> cb=[gather.<locals>._done_callback() at /Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/asyncio/tasks.py:769]>
<Task pending name='Task-5' coro=<UserService.add() running at /Users/hide/fastapi-boilerplate/core/db/transactional.py:20> cb=[gather.<locals>._done_callback() at /Applications/Xcode.app/Contents/Developer/Library/Frameworks/Python3.framework/Versions/3.8/lib/python3.8/asyncio/tasks.py:769]>
INFO: 127.0.0.1:59173 - "GET / HTTP/1.1" 400 Bad Request

모두 각기 다른 Task를 통해 실행이 되고 이때 데이터베이스를 한번 확인해보면,

mysql> select * from users;
+-----+-----------+-------+----------+----------+---------------------+---------------------+
| id  | password  | email | nickname | is_admin | created_at          | updated_at          |
+-----+-----------+-------+----------+----------+---------------------+---------------------+
| 123 | password1 | email | nickname |        0 | 2022-03-21 16:12:22 | 2022-03-21 16:12:22 |
+-----+-----------+-------+----------+----------+---------------------+---------------------+
1 row in set (0.00 sec)

하나의 로우가 추가되었다. Task-4(exception_add())에서 예외가 발생하여 전체적으로 서버의 응답은 400 Bad Request로 나갔지만 Task-5(add())는 정상적으로 실행이 되었기 때문이다. @Transactional() 데코레이터는 세션을 통한 DB Persist 작업을 하나의 트랜잭션으로 묶기 위한 것인데 이렇게 된다면 해당 데코레이터가 의도적으로 동작할 수 없다. 처음에는 gather()를 통해 동시에 여러개의 작업을 실행해도 어느 한곳에서 예외가 발생하면 exception propogation으로 인해 다른 작업들도 취소될 것이라 생각했다. 하지만 그렇게 동작하지 않았고 공식문서에서 답을 찾을 수 있었다.

현재까지 구현한 내용을 통해 생각하고 바랬던 내용을 그림으로 표현하자면 아래와 같다.

하나의 Request가 들어온 후 내부 비즈니스 로직을 거치고 기타 작업들을 수행하는 모든 과정을 하나의 Context로 담고 싶었고 그렇게 동작한다고 생각했다.

하지만 결과적으로 위와 같이 동작했다. gather()를 통해 여러개의 작업을 한번에 실행시키면 내부적으로 새로운 Task를 생성하여 실행되기 때문에 Task-3, Task-4, Task-5 각기 다른 Context를 가진다. 따라서 공식문서에 나와있는 것 처럼 asyncio.current_task()를 scopefunc에 넣어주는 것만으로는 Context-local을 완벽하게 보장할 수 없다.

async def home():
    import asyncio
    print(asyncio.current_task())
    from app.services.user import UserService
    task1 = asyncio.create_task(UserService().exception_add())
    task2 = asyncio.create_task(UserService().add())
    await task1
    await task2
    return {"status": True}

참고로 위처럼 create_task() 메소드를 통해 작업해도 동일한 문제가 발생한다.

해결 방법

본문 초반에 설명한 방식처럼 미들웨어를 통해 직접 컨텍스트 변수를 세팅하고 사용하는 방식으로 세션을 다룸으로써 해결하였다.

생각해볼 점

이번 문제는 @Transactional() 데코레이터를 통해 하나의 트랜잭션으로 묶는 작업이 있기에 발생하는 문제이다. 글을 작성하고 나니 하나의 Request를 하나의 Context로 여기는게 맞을까라는 생각도 든다. 각각 @Transactional() 데코레이터가 붙은 메소드를 동시에 실행할 때 하나가 실패하면 다른 하나도 실패하는것이 과연 맞을까? 그리고 그런 작업이 과연 많을까? 라는 생각이 든다. 내부적으로 Domain Model Pattern이 아닌 Transaction Script Pattern을 사용하고 있어 서비스 레이어에 로직이 모여있는 경향이 있기에 이런 의문점이 드는 것 같기도 하다. 이 부분은 개발을 진행하며 추가적인 고찰이 필요한 것으로 보인다.

Ref

본 포스팅은 기존에 작성했던 모든 포스팅을 종합하여 나름 최종적인 결과물로 도출한 것이다. 아래의 포스팅을 통해 그간의 과정과 고민들을 같이 살펴보면 도움이 될 것 같다.

FastAPI SQLAlchemy Session 객체 연동하기

FastAPI SQLAlchemy 연동하며 발생한 문제 정리

SQLAlchemy AsyncSession으로 비동기 적용하기

SQLAlchemy AsyncSession으로 비동기 적용하며 생긴 문제점

 

모든 소스 코드는 https://github.com/teamhide/fastapi-boilerplate 에서 확인할 수 있다.