"""Engine and session helpers for fleetmanager.data_access.db_engine."""

import os
from contextlib import contextmanager

import pandas as pd
import sqlalchemy
from dotenv import load_dotenv
from importlib_resources import files
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from .dbschema import (
    Base,
    FuelTypes,
    LeasingTypes,
    VehicleTypes,
    default_fuel_types,
    default_leasing_types,
    default_vehicle_types,
)

# Optional integration: when both ``datasources`` and ``pgrace`` can be
# imported, the data source name is read from the pgrace shared property
# "DATA_SOURCE". engine_creator() uses it when no DB_* settings are given.
# If either package is missing, GRACE_DATA_SOURCE is "" and that path is off.
try:
    import datasources
    from pgrace import properties
except ImportError:
    GRACE_DATA_SOURCE = ""
else:
    GRACE_DATA_SOURCE = properties.get_shared_property("DATA_SOURCE")

# Load variables from a .env file into the process environment (python-dotenv)
# so engine_creator() can pick up the DB_* settings through os.getenv.
load_dotenv()


def engine_creator(
    db_name=None,
    db_password=None,
    db_user=None,
    db_url=None,
    db_server=None,
) -> sqlalchemy.engine.Engine:
    """
    Generic db engine creator.

    Settings that are not passed explicitly are read from the environment
    variables ``DB_NAME``, ``DB_PASSWORD``, ``DB_USER``, ``DB_URL`` and
    ``DB_SERVER`` (which may have been loaded from a ``.env`` file); they can
    also be passed in directly, e.g. from a click command. The engine is
    chosen as follows:

    1. If any of the five settings is non-empty, an engine for
       ``{db_server}://{db_user}:{db_password}@{db_url}/{db_name}``.
    2. Otherwise, if ``GRACE_DATA_SOURCE`` is set, the object returned by
       ``datasources.connect(GRACE_DATA_SOURCE)``.
    3. Otherwise, a shared in-memory SQLite database (``file:fleetdb``) that
       is seeded from ``dummy_data.sql`` in the ``fleetmanager`` package.

    Before returning, the tables declared in dbschema are created if missing
    and the default rows are inserted via ``create_defaults``.

    Parameters
    ----------
    db_name : str, optional
        Database name, placed after the final ``/`` of the URL.
    db_password : str, optional
        Password, inserted into the URL as-is (not URL-escaped).
    db_user : str, optional
        User name.
    db_url : str, optional
        The part of the URL between ``@`` and ``/`` (host, optionally with
        port).
    db_server : str, optional
        URL scheme, i.e. the SQLAlchemy dialect (and optional driver) name.

    Returns
    -------
    sqlalchemy.engine.Engine
    """
    if db_name is None:
        db_name = os.getenv("DB_NAME")
    if db_password is None:
        db_password = os.getenv("DB_PASSWORD")
    if db_user is None:
        db_user = os.getenv("DB_USER")
    if db_url is None:
        db_url = os.getenv("DB_URL")
    if db_server is None:
        db_server = os.getenv("DB_SERVER")

    if any((db_name, db_password, db_user, db_url, db_server)):
        # NOTE(review): a setting that is still None is rendered as the literal
        # text "None", and special characters in the password (e.g. "@", "/")
        # are not escaped -- confirm deployments always provide all five values.
        db_engine = create_engine(
            f"{db_server}://{db_user}:{db_password}@{db_url}/{db_name}",
            encoding="latin-1",
        )
    elif GRACE_DATA_SOURCE:
        db_engine = datasources.connect(GRACE_DATA_SOURCE)
    else:
        from sqlite3 import OperationalError

        # Named in-memory database in shared-cache mode; StaticPool reuses a
        # single connection for every checkout, so the data persists between
        # checkouts instead of vanishing with each connection.
        db_engine = create_engine(
            "sqlite:///file:fleetdb?mode=memory&cache=shared&uri=true",
            connect_args={"check_same_thread": False},
            poolclass=StaticPool,
            encoding="latin-1",
        )
        dummy_sql = (
            files("fleetmanager")
            .joinpath("dummy_data.sql")
            .read_text(encoding="utf-8")
        )
        raw_conn = db_engine.raw_connection()
        try:
            # executescript() is a sqlite3 driver method, so it is called on
            # the underlying DB-API connection rather than the pool proxy.
            raw_conn.connection.executescript(dummy_sql)
        except OperationalError:
            # Best-effort seeding. NOTE(review): presumably this fires when the
            # shared database was already seeded (objects exist) -- confirm.
            pass
        finally:
            # Hand the connection back to the pool now (the pool's reset rolls
            # back anything the script left uncommitted) instead of leaving it
            # to the garbage collector.
            raw_conn.close()

    Base.metadata.create_all(db_engine)
    create_defaults(db_engine)
    return db_engine
def create_defaults(engine_):
    """
    Insert the default rows defined in dbschema that are not yet present.

    For each of ``VehicleTypes``, ``LeasingTypes`` and ``FuelTypes`` the
    current rows are read, and every entry of the matching ``default_*`` list
    whose ``"id"`` does not already exist is added. Each table's additions
    are committed in their own transaction; existing rows are not modified.

    Parameters
    ----------
    engine_ : sqlalchemy.engine.Engine
        Engine the defaults are read from and written to.
    """
    # The session factory only binds to the engine, so one is enough for
    # every table (the original rebuilt it on each loop iteration).
    managed_session = session_factory(engine_)
    for form, defaults in (
        (VehicleTypes, default_vehicle_types),
        (LeasingTypes, default_leasing_types),
        (FuelTypes, default_fuel_types),
    ):
        # Build the id set once so each membership test is O(1).
        existing_ids = set(pd.read_sql(select([form]), engine_)["id"])
        adds = [
            form(**default)
            for default in defaults
            if default["id"] not in existing_ids
        ]
        if adds:
            with managed_session.begin() as sess:
                sess.add_all(adds)
def session_factory(db_engine):
    """
    Build the session provider used throughout the project.

    Parameters
    ----------
    db_engine : sqlalchemy.engine.Engine
        Engine the sessions are bound to.

    Returns
    -------
    ManagedSession
        Object offering two context managers:

        * ``provider()`` yields a session and never commits on its own; if the
          ``with`` body raises, the session is rolled back and the exception
          is re-raised.
        * ``provider.begin()`` yields a session and commits it when the body
          finishes without error; if the body or the commit raises, the
          session is rolled back and the exception is re-raised.

        In both cases the session is closed on exit.
    """
    make_session = sessionmaker(bind=db_engine)

    class ManagedSession:
        @contextmanager
        def _scope(self, commit):
            # Shared lifecycle for both entry points: optional commit,
            # rollback on error, unconditional close.
            current = make_session()
            try:
                yield current
                if commit:
                    current.commit()
            except Exception:
                current.rollback()
                raise
            finally:
                current.close()

        def __call__(self):
            return self._scope(commit=False)

        def begin(self):
            return self._scope(commit=True)

    return ManagedSession()