33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class DB:
    """
    Lazily-initialised holder for the database engines and session factories.

    Everything is stored on the class itself (``_engine_async``,
    ``_engine_sync``, ``_async_session_maker``, ``_sync_session_maker``) and
    is created on first use by ``set_async_db`` / ``set_sync_db``.
    """

    @classmethod
    def engine_async(cls):
        """
        Return the async engine, creating it on first access.
        """
        try:
            return cls._engine_async
        except AttributeError:
            # Not configured yet: build the engine, then hand it out.
            cls.set_async_db()
            return cls._engine_async

    @classmethod
    def engine_sync(cls):
        """
        Return the sync engine, creating it on first access.
        """
        try:
            return cls._engine_sync
        except AttributeError:
            # Not configured yet: build the engine, then hand it out.
            cls.set_sync_db()
            return cls._engine_sync

    @classmethod
    def set_async_db(cls):
        """
        Create the async engine and its session maker from the settings.

        Sets ``cls._engine_async`` and ``cls._async_session_maker``. Calling
        it again replaces both attributes.
        """
        settings = Inject(get_settings)
        settings.check_db()

        if settings.DB_ENGINE == "sqlite":
            logger.warning(SQLITE_WARNING_MESSAGE)
            # Set some sqlite-specific options
            engine_kwargs_async = dict(poolclass=StaticPool)
        else:
            engine_kwargs_async = {
                "pool_pre_ping": True,
            }

        cls._engine_async = create_async_engine(
            settings.DATABASE_ASYNC_URL,
            echo=settings.DB_ECHO,
            future=True,
            **engine_kwargs_async,
        )
        cls._async_session_maker = sessionmaker(
            cls._engine_async,
            class_=AsyncSession,
            expire_on_commit=False,
            future=True,
        )

    @classmethod
    def set_sync_db(cls):
        """
        Create the sync engine and its session maker from the settings.

        Sets ``cls._engine_sync`` and ``cls._sync_session_maker``, and
        registers a ``connect`` listener on the new engine that switches
        sqlite connections to WAL journal mode.
        """
        settings = Inject(get_settings)
        settings.check_db()

        if settings.DB_ENGINE == "sqlite":
            logger.warning(SQLITE_WARNING_MESSAGE)
            # Set some sqlite-specific options
            engine_kwargs_sync = dict(
                poolclass=StaticPool,
                connect_args={"check_same_thread": False},
            )
        else:
            engine_kwargs_sync = {}

        cls._engine_sync = create_engine(
            settings.DATABASE_SYNC_URL,
            echo=settings.DB_ECHO,
            future=True,
            **engine_kwargs_sync,
        )
        cls._sync_session_maker = sessionmaker(
            bind=cls._engine_sync,
            autocommit=False,
            autoflush=False,
            future=True,
        )

        @event.listens_for(cls._engine_sync, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            # `connection_record` is unused here but is part of the listener
            # signature.
            if settings.DB_ENGINE == "sqlite":
                cursor = dbapi_connection.cursor()
                cursor.execute("PRAGMA journal_mode=WAL")
                cursor.close()

    @classmethod
    async def get_async_db(cls) -> AsyncGenerator[AsyncSession, None]:
        """
        Get async database session

        Yields a single session; it is closed when the generator is resumed
        or finalised.
        """
        # Guard only the attribute lookup. An AttributeError raised while
        # *building* the session must propagate instead of being mistaken for
        # "not initialised yet" and silently re-creating the engine.
        try:
            make_session = cls._async_session_maker
        except AttributeError:
            cls.set_async_db()
            make_session = cls._async_session_maker

        async with make_session() as async_session:
            yield async_session

    @classmethod
    def get_sync_db(cls) -> Generator[DBSyncSession, None, None]:
        """
        Get sync database session

        Yields a single session; it is closed when the generator is resumed
        or finalised.
        """
        # Guard only the attribute lookup (see get_async_db for rationale).
        try:
            make_session = cls._sync_session_maker
        except AttributeError:
            cls.set_sync_db()
            make_session = cls._sync_session_maker

        with make_session() as sync_session:
            yield sync_session
|