SQLAlchemy

Multithreading

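A SQLAlchemy `Session` is not thread-safe, so a single session must not be shared between threads. This matters if your API/web server runs multiple worker threads. The engine and its connection pool, on the other hand, can be shared. We can therefore support multithreading by giving every unit of work its own short-lived session, as follows: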

## This is a global util file

import os
from contextlib import contextmanager
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from api.config import db_username, db_password, db_private_ip, db_database_name

# "user:password@host:port/database" portion of the connection URL.
# The user name and password are percent-encoded so characters such as "@",
# ":", "/" or "%" inside them cannot corrupt the URL (safe="" makes quote()
# encode "/" as well).
DB_CREDENTIALS = (
    f"{quote(db_username, safe='')}:{quote(db_password, safe='')}"
    f"@{db_private_ip}:3306/{db_database_name}"
)

def create_engine_and_session():
    """Create the database engine and a session factory bound to it.

    The connection URL uses the ``mysql+pymysql`` dialect, so the ``pymysql``
    driver must be installed.

    Returns:
        A ``(engine, session_factory)`` pair. ``session_factory`` is a
        ``sessionmaker`` bound to ``engine`` with ``autocommit`` and
        ``autoflush`` disabled; call it to open a new session.
    """
    url = "mysql+pymysql://" + DB_CREDENTIALS  # requires the pymysql driver
    engine = create_engine(url, echo=False)
    factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    return engine, factory

# Built once when this module is imported; session_scope() below opens each of
# its sessions from this ``Session`` factory.
current_engine, Session = create_engine_and_session()

@contextmanager
def session_scope(session_factory=None):
    """Provide a transactional scope around a series of operations.

    Opens a new session and yields it. If the ``with`` body finishes without
    error, the session is committed. If the body or the commit raises, the
    session is rolled back and the exception is re-raised. The session is
    always closed.

    Args:
        session_factory: Zero-argument callable returning a new session.
            Defaults to the module-level ``Session`` factory.

    Yields:
        The newly opened session.

    Raises:
        Whatever the ``with`` body (or the commit) raised, after rollback.
    """
    factory = Session if session_factory is None else session_factory
    session = factory()
    try:
        yield session
        session.commit()
    except BaseException:
        # Undo the partial transaction, then re-raise so the caller learns that
        # its work was discarded instead of failing silently.
        session.rollback()
        raise
    finally:
        session.close()
## Example file that manipulates the database in a multithreaded environment
## Keep every session separate: open a new one with session_scope() for each unit of work, and never share a session between threads or pass it around to other functions.
from api.util import session_scope # importing function from file above

def example_function():
    """Fetch ``DatabaseModel.field`` from one row, using a dedicated session.

    Returns:
        The ``field`` value of the first row the query returns (there is no
        ORDER BY, so which row that is is unspecified), or ``None`` if the
        query returns no rows.
    """
    with session_scope() as session:
        # first() must be called; it returns a row tuple, or None when empty.
        row = session.query(DatabaseModel.field).first()
        learner_name = row[0] if row is not None else None
    return learner_name

Enums

from sqlalchemy import Enum
from sqlalchemy.ext.declarative import declarative_base

# Declarative base class that the models in this snippet inherit from.
Base = declarative_base()


class BaseEnum(enum.Enum):
    """Defines utility methods for accessing enum fields"""

    @classmethod
    def list_enums(cls) -> list:
        """Return the names of the enum's members, in definition order.

        Aliases (names bound to an already-used value) are included because
        ``__members__`` lists them.

        Example: ``Teams.list_enums()`` returns ``["TEAM_1", "TEAM_2", "TEAM_3"]``.
        """
        # __members__ is a read-only mapping, not a list; convert it so the
        # result matches the ``-> list`` annotation.
        return list(cls.__members__)

class Teams(BaseEnum):
    """Enums for teams.

    Used as the set of allowed values for ``SomeTable.team``
    (``Column(Enum(Teams))``).
    """

    TEAM_1 = 1
    TEAM_2 = 2
    TEAM_3 = 3


class SomeTable(Base):
    """Example model with a column restricted to the ``Teams`` enum."""

    __tablename__ = "some_table"

    id = Column(Integer, primary_key=True)
    # sqlalchemy.Enum is given the Python enum class directly.
    team = Column(Enum(Teams))

Base Class

When every SQLAlchemy model/table class needs the same methods, we can define them once in the Base class that all the tables inherit from.

Here are some methods I found to be useful:

from sqlalchemy.ext.declarative import declarative_base

class Base:
    """Parent class for all SQLA models"""

    def generate_key_value(self) -> Iterator[tuple]:
        """Yield ``(attribute_name, value)`` pairs for this instance's data.

        Only attributes in the instance ``__dict__`` are considered. Names that
        start with ``_`` are skipped. So are values exposing ``_sa_adapter`` or
        ``_sa_instance_state``, presumably SQLAlchemy-managed collections and
        related model instances.
        """
        exclude = ("_sa_adapter", "_sa_instance_state")
        for key, value in vars(self).items():
            if not key.startswith("_") and not any(
                hasattr(value, attr) for attr in exclude
            ):
                yield key, value

    def __repr__(self) -> str:
        """Render ``ClassName(`` with one ``key=value`` per line, then ``)``."""
        params = ", ".join(
            f"\n\t{key}={value}" for key, value in self.generate_key_value()
        )
        return f"{self.__class__.__name__}({params}\n)"

    def get_unique_constraint_columns(self) -> List[str]:
        """Return the column names of the first entry of ``__table_args__``.

        The unique constraint is expected to be the first element of a tuple
        ``__table_args__``. An empty list is returned when ``__table_args__``
        is missing, ``None``, empty, or a plain dict of table options, or when
        its first element has no ``columns``.
        """
        table_args = getattr(self, "__table_args__", None)
        # A dict-form __table_args__ holds only keyword options, no constraints.
        if not table_args or isinstance(table_args, dict):
            return []
        columns = getattr(table_args[0], "columns", None)
        if columns is None:
            return []
        return [column.name for column in columns]

    def get_columns(self) -> Dict[str, object]:
        """Return the instance ``__dict__`` without ``_sa_instance_state``.

        Unlike ``generate_key_value()``, other underscore-prefixed attributes
        and related objects are kept.
        """
        return {
            column: value
            for column, value in vars(self).items()
            if column != "_sa_instance_state"
        }

# Rebind ``Base`` to a declarative base whose generated class inherits the
# helper methods defined on the plain ``Base`` class above.
Base = declarative_base(cls=Base)

class SomeTable(Base):
    """Example model whose ``something_unique`` column has a unique constraint."""

    __tablename__ = "some_table"

    id = Column(Integer, primary_key=True)
    something_unique = Column(String(255), nullable=False)

    # get_unique_constraint_columns() reads the first entry of __table_args__,
    # so the UniqueConstraint is listed first.
    __table_args__ = (UniqueConstraint("something_unique"),) # also works with multiple columns

Foreign Keys

class ParentTable(Base):
    """Example model whose rows each reference one ``ChildTable`` row."""

    __tablename__ = "parent_table"

    id = Column(Integer, primary_key=True)
    # ForeignKey targets "<table name>.<column>": the __tablename__
    # ("child_table"), not the Python class name ("ChildTable").
    child = Column(ForeignKey("child_table.id"), index=True)
    # relationship() takes the mapped class name and loads the ChildTable row
    # referenced by ``child``.
    child_table = relationship("ChildTable")


class ChildTable(Base):
    """Example model referenced by ``ParentTable.child``."""

    __tablename__ = "child_table"

    id = Column(Integer, primary_key=True)

Datetime/Timestamp Column

from datetime import datetime
from sqlalchemy.dialects.mysql import DATETIME


def _utc_now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Equivalent to ``datetime.utcnow()``, which is deprecated since Python 3.12:
    the time is taken in UTC and the tzinfo is then dropped.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class SomeTable(Base):
    __tablename__ = "some_table"

    id = Column(Integer, primary_key=True)
    # PASS IN THE FUNCTION itself (no parentheses) so SQLAlchemy calls it for
    # each INSERT that supplies no value; `_utc_now()` would freeze a single
    # import-time timestamp.
    timestamp = Column(DATETIME(fsp=6), default=_utc_now)

Useful SQLAlchemy Utility Functions

"""SQLAlchemy functions that are commonly used among the bitbucket classes"""
from typing import Union

from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm.session import Session

from logger import logger

# <some_model_class_object> = e.g. SomeTable(..., ...)
def get_unique_constraint_filters(entry) -> "tuple[DeclarativeMeta, dict]":
    """Gets the table and filters needed in order to query for an entry.

    Args:
        entry: A model instance (e.g. ``SomeTable(...)``) whose class provides
            ``get_unique_constraint_columns()``.

    Returns:
        ``(table, filters)``. ``table`` is ``entry``'s class. ``filters`` maps
        each unique-constraint attribute of that class to ``entry``'s value
        for it, and is empty when there are no unique-constraint columns.
    """
    table = entry.__class__
    filters = {}
    for unique_col in entry.get_unique_constraint_columns():
        filters[getattr(table, unique_col)] = getattr(entry, unique_col)
    return table, filters


def get_entry_id(entry, session: Session) -> Union[None, int]:
    """Gets ID of a given entry in MySQL.

    Looks up the row whose unique-constraint columns match ``entry``.

    Args:
        entry: Model instance (e.g. ``SomeTable(...)``) to look up.
        session: Open session used to run the query.

    Returns:
        The matching row's ``id``. Returns ``None`` when no row matches or the
        model declares no unique-constraint columns.

    Raises:
        sqlalchemy.orm.exc.MultipleResultsFound: more than one row matches.
    """
    table, filters = get_unique_constraint_filters(entry)
    if not filters:
        # With no filters the query below would match every row in the table.
        logger.debug("No unique constraint columns for table %s", table)
        return None

    query = session.query(table)
    for column, value in filters.items():
        query = query.filter(column == value)

    # one_or_none() keeps scalar()'s contract (None for no rows, an error for
    # several) but fetches the row once instead of re-querying with first().
    existing = query.one_or_none()
    if existing is None:
        logger.debug("ID was not found in table %s", table)
        return None
    return existing.id


def upsert_entry_to_mysql(entry, session: Session) -> None:
    """Upserts entry to MySQL.

    If a row with the same unique-constraint values as ``entry`` exists, it is
    updated with ``entry.get_columns()``. Otherwise ``entry`` is added as a new
    row; this also happens when the model declares no unique-constraint
    columns, since an existing row cannot be identified. The session is
    committed in both cases.

    Args:
        entry: Model instance (e.g. ``SomeTable(...)``) to insert or update.
        session: Open session used for the query, update/insert and commit.

    Raises:
        sqlalchemy.orm.exc.MultipleResultsFound: more than one existing row
            matches the unique-constraint values.
    """
    table, filters = get_unique_constraint_filters(entry)

    existing = None
    if filters:
        entry_query = session.query(table)
        for column, value in filters.items():
            entry_query = entry_query.filter(column == value)
        # Single fetch; raises if several rows match, like scalar() did.
        existing = entry_query.one_or_none()

    if existing is None:
        session.add(entry)
    else:
        entry_values = entry.get_columns()
        if entry_values:
            # One UPDATE for all columns instead of one statement per column.
            session.query(table).filter(table.id == existing.id).update(entry_values)

    session.commit()

Last updated