File indexing completed on 2026-04-09 07:58:19
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 """
0012 session decorator
0013
0014 Borrowed from:
0015 https://github.com/rucio/rucio/blob/master/lib/rucio/db/sqla/session.py
0016 """
0017
0018 import logging
0019 import sys
0020
0021 from functools import wraps
0022 from inspect import isgeneratorfunction
0023 from retrying import retry
0024 from threading import Lock
0025 from os.path import basename
0026
0027 from sqlalchemy import create_engine, event, inspect, select
0028 from sqlalchemy.exc import DatabaseError, DisconnectionError, OperationalError, TimeoutError
0029 from sqlalchemy.ext.declarative import declarative_base
0030 from sqlalchemy.orm import sessionmaker, scoped_session
0031
0032 from idds.common.config import config_get, config_has_option
0033 from idds.common.exceptions import IDDSException, DatabaseException
0034
0035
LOG = logging.getLogger(__name__)

# Name of the configuration section holding all database options.
DATABASE_SECTION = 'database'

# Declarative base shared by every ORM model of this package.
BASE = declarative_base()

# Optional schema that qualifies every table attached to BASE.metadata.
DEFAULT_SCHEMA_NAME = (config_get(DATABASE_SECTION, 'schema')
                       if config_has_option(DATABASE_SECTION, 'schema')
                       else None)
if DEFAULT_SCHEMA_NAME:
    BASE.metadata.schema = DEFAULT_SCHEMA_NAME

# Lazily created sessionmaker and engine, plus the lock get_session() uses
# while creating them.
_MAKER = None
_ENGINE = None
_LOCK = Lock()
0049
0050
0051 def _fk_pragma_on_connect(dbapi_con, con_record):
0052
0053 try:
0054 dbapi_con.execute('pragma foreign_keys=ON')
0055 except AttributeError:
0056 pass
0057
0058
def mysql_ping_listener(dbapi_conn, connection_rec, connection_proxy):
    """
    Ensures that MySQL connections checked out of the pool are alive.

    Runs ``select 1`` on the raw connection. If that fails with the
    driver's OperationalError and the error code is one of the client
    connection-loss codes below (e.g. 2006 "server has gone away", 2013
    "lost connection"), a DisconnectionError is raised so the pool discards
    the connection. Any other error propagates unchanged.

    Borrowed from:
    http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f
    :param dbapi_conn: DBAPI connection
    :param connection_rec: connection record
    :param connection_proxy: connection proxy
    """
    gone_away_codes = (2006, 2013, 2014, 2045, 2055)
    try:
        dbapi_conn.cursor().execute('select 1')
    except dbapi_conn.OperationalError as ex:
        if ex.args[0] not in gone_away_codes:
            raise
        raise DisconnectionError('Got mysql server has gone away: %s' % ex)
0078
0079
def mysql_convert_decimal_to_float(dbapi_conn, connection_rec):
    """
    Make the MySQLdb driver return DECIMAL/NEWDECIMAL columns as ``float``.

    The default datatype returned by mysql-python for numerics is
    decimal.Decimal, which cannot be serialised to JSON. MySQLdb uses two
    decimal field types, so both are overridden. The module-wide conversion
    table is copied first; the connection gets its own modified copy.

    :param dbapi_conn: DBAPI connection
    :param connection_rec: connection record
    :raises IDDSException: if MySQLdb cannot be imported.
    """
    try:
        import MySQLdb.converters
        from MySQLdb.constants import FIELD_TYPE
    except ImportError as err:
        # Only a missing driver is translated; KeyboardInterrupt/SystemExit
        # and other unexpected errors are no longer swallowed here.
        raise IDDSException('Trying to use MySQL without mysql-python installed!') from err
    conv = MySQLdb.converters.conversions.copy()
    conv[FIELD_TYPE.DECIMAL] = float
    conv[FIELD_TYPE.NEWDECIMAL] = float
    dbapi_conn.converter = conv
0099
0100
def psql_convert_decimal_to_float(dbapi_conn, connection_rec):
    """
    Make psycopg2 return NUMERIC/DECIMAL values as ``float``.

    The default datatype returned by psycopg2 for numerics is
    decimal.Decimal, which cannot be serialised to JSON. NULLs stay None.

    Note: ``register_type`` is called without a scope, so the caster is
    registered globally for psycopg2, not only for ``dbapi_conn``.

    :param dbapi_conn: DBAPI connection (not used directly)
    :param connection_rec: connection record
    :raises IDDSException: if psycopg2 cannot be imported.
    """
    try:
        import psycopg2.extensions
    except ImportError as err:
        # Only a missing driver is translated; other errors propagate.
        raise IDDSException('Trying to use PostgreSQL without psycopg2 installed!') from err

    dec2float = psycopg2.extensions.new_type(psycopg2.extensions.DECIMAL.values,
                                             'DEC2FLOAT',
                                             lambda value, curs: float(value) if value is not None else None)
    psycopg2.extensions.register_type(dec2float)
0118
0119
def my_on_connect(dbapi_con, connection_record):
    """ Tag a new Oracle connection with the name of the running program.

    The program name (basename of ``sys.argv[0]``) is written to the
    connection's clientinfo, client_identifier, module and action
    attributes, so that performance and resource usage can be tracked per
    module; Oracle records this information in the V$SESSION and V$SQLAREA
    views.
    """
    program = basename(sys.argv[0])
    for attribute in ('clientinfo', 'client_identifier', 'module', 'action'):
        setattr(dbapi_con, attribute, program)
0129
0130
def get_engine(echo=True):
    """ Creates (on first call) and returns the process-wide SQLAlchemy engine.

    The connection URL is the ``default`` option of the database config
    section. Optional pool/echo options from the same section are passed to
    create_engine() when they are configured; a configured value that cannot
    be converted to the expected type is logged and ignored. Unqualified
    tables are mapped to DEFAULT_SCHEMA_NAME via ``schema_translate_map``.
    For ``oracledb`` URLs, Oracle thick mode is attempted; if the Oracle
    client cannot be initialised a warning is logged and thin mode is used.
    Dialect-specific event listeners are attached based on substrings of
    the connection URL.

    :param echo: not used; engine echo comes from the ``echo`` config
                 option. Kept for backward compatibility.
    :returns: engine
    """
    global _ENGINE
    if not _ENGINE:
        sql_connection = config_get(DATABASE_SECTION, 'default')
        config_params = [('pool_size', int), ('max_overflow', int), ('pool_timeout', int),
                         ('pool_recycle', int), ('echo', int), ('echo_pool', str),
                         ('pool_reset_on_return', str), ('use_threadlocal', int)]

        params = {}
        for param, param_type in config_params:
            try:
                raw_value = config_get(DATABASE_SECTION, param)
            except Exception:
                # Option not configured: keep SQLAlchemy's default for it.
                continue
            try:
                params[param] = param_type(raw_value)
            except (TypeError, ValueError) as err:
                LOG.warning('Ignoring invalid value %r for database option %s: %s', raw_value, param, err)
        params['execution_options'] = {'schema_translate_map': {None: DEFAULT_SCHEMA_NAME}}
        if 'oracledb' in sql_connection:
            try:
                import oracledb
                oracledb.init_oracle_client()
                params['thick_mode'] = True
            except Exception as err:
                LOG.warning('Could not start Oracle thick mode; falling back to thin: %s', err)

        _ENGINE = create_engine(sql_connection, **params)

        if 'mysql' in sql_connection:
            event.listen(_ENGINE, 'checkout', mysql_ping_listener)
            event.listen(_ENGINE, 'connect', mysql_convert_decimal_to_float)
        elif 'postgresql' in sql_connection:
            event.listen(_ENGINE, 'connect', psql_convert_decimal_to_float)
        elif 'sqlite' in sql_connection:
            event.listen(_ENGINE, 'connect', _fk_pragma_on_connect)
        elif 'oracle' in sql_connection:
            event.listen(_ENGINE, 'connect', my_on_connect)
    assert _ENGINE
    return _ENGINE
0170
0171
def _terminate_ddl_statement(statement, dialect_name):
    """Return ``statement`` terminated with ``;`` and a newline.

    Only the trailing ``)`` (or ``)\\n\\n``, as emitted for CREATE TABLE) is
    rewritten; parentheses inside the statement are left untouched. Oracle
    table definitions additionally get ``PCTFREE 0``. Statements not ending
    in ``)`` are returned unchanged.
    """
    table_suffix = ')\n\n'
    if statement.endswith(table_suffix):
        terminator = ') PCTFREE 0;\n' if dialect_name == 'oracle' else ');\n'
        return statement[:-len(table_suffix)] + terminator
    if statement.endswith(')'):
        return statement[:-1] + ');\n'
    return statement


def get_dump_engine(echo=False):
    """ Creates a mock engine that prints DDL instead of executing it.

    Each statement sent to the engine is compiled for the dialect of the
    configured ``default`` connection, printed once (duplicates are
    skipped) and terminated with ``;``.

    :param echo: not used by the mock engine; kept for backward compatibility.
    :returns: engine
    """
    # create_engine(strategy='mock') is deprecated since SQLAlchemy 1.4.
    from sqlalchemy import create_mock_engine

    seen = set()

    def dump(sql, *multiparams, **params):
        statement = str(sql.compile(dialect=engine.dialect))
        if statement in seen:
            return
        seen.add(statement)
        print(_terminate_ddl_statement(statement, engine.dialect.name))

    sql_connection = config_get(DATABASE_SECTION, 'default')

    engine = create_mock_engine(sql_connection, dump)
    return engine
0196
0197
def get_maker():
    """
    Return the module-wide SQLAlchemy sessionmaker.

    The engine must already exist (asserted). The sessionmaker is created
    and cached in _MAKER on the first call only.
    """
    global _MAKER
    assert _ENGINE
    if _MAKER:
        return _MAKER
    _MAKER = sessionmaker(bind=_ENGINE, autocommit=False, autoflush=False, expire_on_commit=True)
    return _MAKER
0208
0209
def get_session():
    """ Return a new scoped session bound to the module-wide sessionmaker.

    Assumes that the schema is already in place. The engine and
    sessionmaker are created under _LOCK on the first call. The returned
    scoped_session carries DEFAULT_SCHEMA_NAME as its ``schema`` attribute.
    :returns: session
    """
    if not _MAKER:
        with _LOCK:
            get_engine()
            get_maker()
    assert _MAKER
    scoped = scoped_session(_MAKER)
    scoped.schema = DEFAULT_SCHEMA_NAME
    return scoped
0226
0227
def retry_if_db_connection_error(exception):
    """Return True if ``exception`` looks like a failed or lost DB connection.

    Used as the ``retry_on_exception`` predicate of the session decorators.

    * sqlalchemy ``OperationalError``: True when its message contains one of
      the MySQL client codes or Oracle ORA- codes below.
    * ``DatabaseException``: True when its message reports a closed
      connection or contains one of the Oracle ORA- codes. The decorators
      convert every sqlalchemy ``DatabaseError`` (``OperationalError``
      included) into ``DatabaseException`` before this predicate runs, so
      the ORA- codes must be checked here too. The bare MySQL numbers are
      not matched here: a wrapped message also contains the SQL text and
      bound parameters, where such numbers can appear by coincidence.
    """
    mysql_codes = ('2002',  # can't connect through local socket
                   '2003',  # can't connect to server
                   '2006',)  # server has gone away
    oracle_codes = ('ORA-00028',  # session has been killed
                    'ORA-01012',  # not logged on
                    'ORA-03113',  # end-of-file on communication channel
                    'ORA-03114',  # not connected to ORACLE
                    'ORA-03135',  # connection lost contact
                    'ORA-25408',)  # can not safely replay call
    closed_markers = ('server closed the connection unexpectedly',
                      'closed the connection',)

    def _message(exc):
        # Guard against exceptions constructed without arguments.
        return str(exc.args[0]) if exc.args else str(exc)

    if isinstance(exception, OperationalError):
        message = _message(exception)
        if any(code in message for code in mysql_codes + oracle_codes):
            return True
    if isinstance(exception, DatabaseException):
        message = _message(exception)
        if any(marker in message for marker in closed_markers + oracle_codes):
            return True
    return False
0249
0250
def read_session(function):
    '''
    Decorator that provides a ``session`` keyword argument for read-only work.

    If the caller does not pass ``session=...``, a new scoped session is
    obtained from get_session(), passed to ``function`` and removed when the
    call ends. On error the session is rolled back; sqlalchemy TimeoutError
    and DatabaseError are re-raised as DatabaseException. A session passed
    by the caller is used as-is; its lifecycle is left to the caller.

    This is useful if only SELECTs and the like are being done; anything
    involving INSERTs, UPDATEs etc should use transactional_session.
    Generator functions are rejected at call time (use stream_session).

    A failing call is attempted at most twice when
    retry_if_db_connection_error() accepts the exception. ``wait_fixed`` is
    in milliseconds for ``retrying``, so the pause is 0.5 ms.
    NOTE(review): confirm 0.5 ms (not 500 ms) is intended.
    '''
    @retry(retry_on_exception=retry_if_db_connection_error,
           wait_fixed=0.5,
           stop_max_attempt_number=2,
           wrap_exception=False)
    @wraps(function)
    def new_funct(*args, **kwargs):

        if isgeneratorfunction(function):
            raise IDDSException('read_session decorator should not be used with generator. Use stream_session instead.')

        if kwargs.get('session'):
            # Caller-owned session: no rollback/remove here.
            return function(*args, **kwargs)

        session = get_session()
        try:
            kwargs['session'] = session
            return function(*args, **kwargs)
        except (TimeoutError, DatabaseError) as error:
            session.rollback()
            raise DatabaseException(str(error)) from error
        except BaseException:
            session.rollback()
            raise
        finally:
            # Single removal point for both success and failure.
            session.remove()
    new_funct.__doc__ = function.__doc__
    return new_funct
0293
0294
def stream_session(function):
    '''
    Decorator that provides a ``session`` keyword argument to a generator.

    If the caller does not pass ``session=...``, a new scoped session is
    obtained from get_session() and kept open while the rows are streamed;
    it is removed when the generator finishes, fails or is closed. On error
    the session is rolled back; sqlalchemy TimeoutError and DatabaseError
    are logged and re-raised as DatabaseException. A session passed by the
    caller is used as-is.

    This is useful if only SELECTs and the like are being done; anything
    involving INSERTs, UPDATEs etc should use transactional_session.
    Non-generator functions are rejected (use read_session); because the
    wrapper is itself a generator, that check runs on first iteration.

    NOTE(review): new_funct is a generator function, so calling the @retry
    wrapper only creates the generator; database errors surface later during
    iteration, outside the retry wrapper, so the retry is not expected to fire.
    '''
    @retry(retry_on_exception=retry_if_db_connection_error,
           wait_fixed=0.5,
           stop_max_attempt_number=2,
           wrap_exception=False)
    @wraps(function)
    def new_funct(*args, **kwargs):

        if not isgeneratorfunction(function):
            raise IDDSException('stream_session decorator should be used only with generator. Use read_session instead.')

        if kwargs.get('session'):
            # Caller-owned session: no rollback/remove here.
            yield from function(*args, **kwargs)
            return

        session = get_session()
        try:
            kwargs['session'] = session
            yield from function(*args, **kwargs)
        except (TimeoutError, DatabaseError) as error:
            LOG.error('Database error in stream_session: %s', error)
            session.rollback()
            raise DatabaseException(str(error)) from error
        except BaseException:
            # Also covers GeneratorExit when the consumer stops early.
            session.rollback()
            raise
        finally:
            session.remove()
    new_funct.__doc__ = function.__doc__
    return new_funct
0340
0341
def transactional_session(function):
    '''
    Decorator that runs ``function`` inside a committed transaction.

    If the caller does not pass ``session=...``, a new scoped session is
    obtained from get_session(), passed to ``function``, committed on
    success, rolled back on any error and removed in every case. sqlalchemy
    TimeoutError and DatabaseError (including failures during commit) are
    logged and re-raised as DatabaseException. A session passed by the
    caller is used as-is: nothing is committed, rolled back or removed here,
    so the outermost owner controls the transaction.

    A failing call is attempted at most twice when
    retry_if_db_connection_error() accepts the exception (``wait_fixed`` is
    in milliseconds for ``retrying``).
    '''
    @retry(retry_on_exception=retry_if_db_connection_error,
           wait_fixed=0.5,
           stop_max_attempt_number=2,
           wrap_exception=False)
    @wraps(function)
    def new_funct(*args, **kwargs):
        if kwargs.get('session'):
            return function(*args, **kwargs)

        session = get_session()
        try:
            kwargs['session'] = session
            result = function(*args, **kwargs)
            session.commit()
        except (TimeoutError, DatabaseError) as error:
            LOG.error('Database error in transactional_session: %s', error)
            session.rollback()
            raise DatabaseException(str(error)) from error
        except BaseException:
            session.rollback()
            raise
        finally:
            session.remove()
        return result
    new_funct.__doc__ = function.__doc__
    return new_funct
0378
0379
def safe_bulk_update_mappings(session, model, mappings, chunk_size=1000):
    """Bulk-update only the rows this session manages to lock.

    For a model with a single-column primary key, the primary-key values
    found in ``mappings`` are first claimed with
    ``SELECT ... FOR UPDATE SKIP LOCKED``. Rows currently locked by another
    transaction are skipped, and only mappings whose rows were claimed are
    passed to ``session.bulk_update_mappings``. Mappings lacking the primary
    key are dropped. Models with a composite primary key fall back to a
    plain ``bulk_update_mappings`` of all mappings, without claiming.

    :param session: SQLAlchemy session; the row locks are held by its
                    current transaction.
    :param model: mapped ORM class.
    :param mappings: list of dicts keyed by mapped attribute name.
    :param chunk_size: maximum number of key values per ``IN (...)`` list
                       (Oracle rejects lists longer than 1000, ORA-01795).
                       None or a non-positive value disables chunking.
    :returns: the result of ``bulk_update_mappings``, or None when there
              was nothing to update.
    """
    if not mappings:
        return None

    mapper = inspect(model)
    if len(mapper.primary_key) != 1:
        return session.bulk_update_mappings(model, mappings)

    # Use the mapped attribute key (what getattr() and bulk_update_mappings
    # expect), which may differ from the underlying column name.
    pk = mapper.get_property_by_column(mapper.primary_key[0]).key
    pk_attr = getattr(model, pk)

    ids = [m[pk] for m in mappings if pk in m]
    if not ids:
        return None

    step = chunk_size if chunk_size and chunk_size > 0 else len(ids)
    claimed_set = set()
    for start in range(0, len(ids), step):
        stmt = (
            select(pk_attr)
            .where(pk_attr.in_(ids[start:start + step]))
            .with_for_update(skip_locked=True)
        )
        claimed_set.update(session.execute(stmt).scalars().all())

    if not claimed_set:
        return None

    filtered = [m for m in mappings if m.get(pk) in claimed_set]
    if filtered:
        return session.bulk_update_mappings(model, filtered)
    return None