Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:58:19

0001 #!/usr/bin/env python
0002 #
0003 # Licensed under the Apache License, Version 2.0 (the "License");
0004 # You may not use this file except in compliance with the License.
0005 # You may obtain a copy of the License at
0006 # http://www.apache.org/licenses/LICENSE-2.0
0007 #
0008 # Authors:
0009 # - Wen Guan, <wen.guan@cern.ch>, 2019
0010 
0011 """
0012 session decorator
0013 
0014 Borrowed from:
0015 https://github.com/rucio/rucio/blob/master/lib/rucio/db/sqla/session.py
0016 """
0017 
0018 import logging
0019 import sys
0020 
0021 from functools import wraps
0022 from inspect import isgeneratorfunction
0023 from retrying import retry
0024 from threading import Lock
0025 from os.path import basename
0026 
0027 from sqlalchemy import create_engine, event, inspect, select
0028 from sqlalchemy.exc import DatabaseError, DisconnectionError, OperationalError, TimeoutError
0029 from sqlalchemy.ext.declarative import declarative_base
0030 from sqlalchemy.orm import sessionmaker, scoped_session
0031 
0032 from idds.common.config import config_get, config_has_option
0033 from idds.common.exceptions import IDDSException, DatabaseException
0034 
0035 
LOG = logging.getLogger(__name__)

# Name of the configuration section that holds every database setting.
DATABASE_SECTION = 'database'

# Declarative base for ORM model classes.
BASE = declarative_base()

# Optional schema name from the config. When it is set and non-empty it becomes
# the default schema of tables defined against BASE.metadata.
if config_has_option(DATABASE_SECTION, 'schema'):
    DEFAULT_SCHEMA_NAME = config_get(DATABASE_SECTION, 'schema')
else:
    DEFAULT_SCHEMA_NAME = None
if DEFAULT_SCHEMA_NAME:
    BASE.metadata.schema = DEFAULT_SCHEMA_NAME

# Lazily created engine / sessionmaker singletons, plus the lock that
# get_session() holds while creating them.
_MAKER = None
_ENGINE = None
_LOCK = Lock()
0049 
0050 
def _fk_pragma_on_connect(dbapi_con, con_record):
    """Enable SQLite foreign-key enforcement on a freshly opened DBAPI connection.

    Registered as a 'connect' listener for sqlite URLs. Connection objects that
    have no ``execute`` method (older sqlite3 versions) are skipped silently.

    :param dbapi_con: raw DBAPI connection
    :param con_record: connection record (unused)
    """
    pragma = 'pragma foreign_keys=ON'
    try:
        dbapi_con.execute(pragma)
    except AttributeError:
        # Hack for previous versions of sqlite3 lacking Connection.execute.
        return
0057 
0058 
def mysql_ping_listener(dbapi_conn, connection_rec, connection_proxy):
    """
    Ensures that MySQL connections checked out of the pool are alive.

    Runs ``select 1`` on a new cursor and closes that cursor afterwards.
    MySQL client error codes that indicate a lost server connection are
    translated into ``DisconnectionError``. SQLAlchemy's pool treats that as a
    signal to discard the connection. Any other ``OperationalError`` (including
    one without error-code args) is re-raised unchanged.
    Borrowed from:
    http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f
    :param dbapi_conn: DBAPI connection
    :param connection_rec: connection record
    :param connection_proxy: connection proxy
    :raises DisconnectionError: if the ping fails with a "gone away" style code.
    """
    try:
        cursor = dbapi_conn.cursor()
        cursor.execute('select 1')
    except dbapi_conn.OperationalError as ex:
        # Guard against exceptions without args instead of failing with IndexError.
        code = ex.args[0] if ex.args else None
        if code in (2006, 2013, 2014, 2045, 2055):
            msg = 'Got mysql server has gone away: %s' % ex
            raise DisconnectionError(msg) from ex
        raise
    # Release the probe cursor. On the failure path the connection is being
    # discarded anyway, so the cursor is only closed after a successful ping.
    cursor.close()
0078 
0079 
def mysql_convert_decimal_to_float(dbapi_conn, connection_rec):
    """
    The default datatype returned by mysql-python for numerics is decimal.Decimal.
    This type cannot be serialised to JSON, therefore we need to autoconvert to floats.
    Even worse, there's two types of decimals created by the MySQLdb driver, so we must
    override both.

    Registered as a 'connect' listener for mysql URLs. It installs a copy of
    MySQLdb's default conversion table, with DECIMAL and NEWDECIMAL mapped to float,
    as ``dbapi_conn.converter``.

    :param dbapi_conn: DBAPI connection
    :param connection_rec: connection record
    :raises IDDSException: if MySQLdb (mysql-python / mysqlclient) cannot be imported.
    """

    try:
        import MySQLdb.converters  # pylint: disable=import-error
        from MySQLdb.constants import FIELD_TYPE  # pylint: disable=import-error
    except ImportError as err:
        # Only a missing driver is translated; KeyboardInterrupt/SystemExit and
        # unrelated errors are no longer swallowed by a bare except.
        raise IDDSException('Trying to use MySQL without mysql-python installed!') from err
    conv = MySQLdb.converters.conversions.copy()
    conv[FIELD_TYPE.DECIMAL] = float
    conv[FIELD_TYPE.NEWDECIMAL] = float
    dbapi_conn.converter = conv
0099 
0100 
def psql_convert_decimal_to_float(dbapi_conn, connection_rec):
    """
    The default datatype returned by psycopg2 for numerics is decimal.Decimal.
    This type cannot be serialised to JSON, therefore we need to autoconvert to floats.

    Registered as a 'connect' listener for postgresql URLs. NULL values stay None.

    :param dbapi_conn: DBAPI connection
    :param connection_rec: connection record
    :raises IDDSException: if psycopg2 cannot be imported.
    """

    try:
        import psycopg2.extensions  # pylint: disable=import-error
    except ImportError as err:
        # Only a missing driver is translated; KeyboardInterrupt/SystemExit and
        # unrelated errors are no longer swallowed by a bare except.
        raise IDDSException('Trying to use PostgreSQL without psycopg2 installed!') from err

    DEC2FLOAT = psycopg2.extensions.new_type(psycopg2.extensions.DECIMAL.values,
                                             'DEC2FLOAT',
                                             lambda value, curs: float(value) if value is not None else None)
    # No scope argument is passed, so psycopg2 registers the caster globally
    # (for every psycopg2 connection in the process), not only for dbapi_conn.
    psycopg2.extensions.register_type(DEC2FLOAT)
0118 
0119 
def my_on_connect(dbapi_con, connection_record):
    """ Tags a new Oracle DBAPI connection with the name of the running program.

        The basename of sys.argv[0] is written to the connection's clientinfo,
        client_identifier, module and action attributes, so performance and
        resource usage can be attributed per module (V$SESSION / V$SQLAREA).
    """
    program = basename(sys.argv[0])
    for attribute in ('clientinfo', 'client_identifier', 'module', 'action'):
        setattr(dbapi_con, attribute, program)
0129 
0130 
def _engine_params_from_config():
    """Build optional ``create_engine`` keyword arguments from the database config section.

    Each option is converted with its listed type. Options that are missing or fail
    conversion are skipped, so SQLAlchemy keeps its own default for them.
    """
    config_params = [('pool_size', int), ('max_overflow', int), ('pool_timeout', int),
                     ('pool_recycle', int), ('echo', int), ('echo_pool', str),
                     ('pool_reset_on_return', str), ('use_threadlocal', int)]
    params = {}
    for param, param_type in config_params:
        try:
            params[param] = param_type(config_get(DATABASE_SECTION, param))
        except Exception:  # best effort: absent/malformed option -> keep default
            continue
    return params


def get_engine(echo=True):
    """ Creates (on first call) and returns the shared engine for the configured database.

        The connection URL is the ``default`` option of the database config section.
        Optional pool/echo settings come from the same section. Tables are mapped
        to DEFAULT_SCHEMA_NAME via ``schema_translate_map``. Driver-specific
        listeners are attached based on substrings of the URL.

        :param echo: unused; kept for backward compatibility. Echoing is controlled
                     by the ``echo`` option of the database config section.
        :returns: engine
    """
    global _ENGINE
    if not _ENGINE:
        sql_connection = config_get(DATABASE_SECTION, 'default')
        params = _engine_params_from_config()
        params['execution_options'] = {'schema_translate_map': {None: DEFAULT_SCHEMA_NAME}}
        if 'oracledb' in sql_connection:
            # Prefer python-oracledb thick mode; fall back to thin mode if the
            # Oracle client libraries cannot be initialised.
            try:
                import oracledb  # pylint: disable=import-error
                oracledb.init_oracle_client()
                params['thick_mode'] = True
            except Exception as err:
                LOG.warning('Could not start Oracle thick mode; falling back to thin: %s', err)

        _ENGINE = create_engine(sql_connection, **params)

        if 'mysql' in sql_connection:
            event.listen(_ENGINE, 'checkout', mysql_ping_listener)
            event.listen(_ENGINE, 'connect', mysql_convert_decimal_to_float)
        elif 'postgresql' in sql_connection:
            event.listen(_ENGINE, 'connect', psql_convert_decimal_to_float)
        elif 'sqlite' in sql_connection:
            event.listen(_ENGINE, 'connect', _fk_pragma_on_connect)
        elif 'oracle' in sql_connection:
            event.listen(_ENGINE, 'connect', my_on_connect)
    assert _ENGINE
    return _ENGINE
0170 
0171 
def get_dump_engine(echo=False):
    """ Creates a dump (mock) engine that prints SQL instead of executing it.

        Each distinct statement sent to the engine is compiled with the dialect of
        the configured database URL and printed once, terminated with ';'.
        For statements ending in ')\\n\\n' (presumably CREATE TABLE), the Oracle
        dialect additionally appends ``PCTFREE 0``.

        :param echo: only forwarded on SQLAlchemy versions without
                     ``create_mock_engine`` (< 1.4); a mock engine never executes.
        :returns: engine
    """
    seen = set()
    table_suffix = ')\n\n'

    def dump(sql, *multiparams, **params):
        statement = str(sql.compile(dialect=engine.dialect))
        if statement in seen:
            return
        seen.add(statement)
        if statement.endswith(table_suffix):
            body = statement[:-len(table_suffix)]
            if engine.dialect.name == 'oracle':
                print(body + ') PCTFREE 0;\n')
            else:
                print(body + ');\n')
        elif statement.endswith(')'):
            # Terminate only the trailing ')' -- replacing every ')' used to
            # corrupt statements such as "FOREIGN KEY(a) REFERENCES b (id)".
            print(statement + ';\n')
        else:
            print(statement)

    sql_connection = config_get(DATABASE_SECTION, 'default')

    try:
        from sqlalchemy import create_mock_engine
    except ImportError:
        # SQLAlchemy < 1.4: the legacy mock strategy (removed in 2.0).
        engine = create_engine(sql_connection, echo=echo, strategy='mock', executor=dump)
    else:
        engine = create_mock_engine(sql_connection, dump)
    return engine
0196 
0197 
def get_maker():
    """
        Return the module-wide SQLAlchemy sessionmaker, creating it on first use.
        The engine must already exist (see get_engine()). Sessions it produces do
        not autoflush and expire their instances on commit.
    """
    global _MAKER
    assert _ENGINE
    if _MAKER:
        return _MAKER
    _MAKER = sessionmaker(bind=_ENGINE, autocommit=False, autoflush=False, expire_on_commit=True)
    return _MAKER
0208 
0209 
def get_session():
    """ Returns a new scoped session bound to the shared engine; the schema must already exist.

        On first use the engine and sessionmaker are created while holding the
        module lock. The returned object carries the configured schema name in
        its ``schema`` attribute.
        :returns: session
    """
    if not _MAKER:
        with _LOCK:
            get_engine()
            get_maker()
    assert _MAKER
    scoped = scoped_session(_MAKER)
    scoped.schema = DEFAULT_SCHEMA_NAME
    return scoped
0226 
0227 
def retry_if_db_connection_error(exception):
    """Return True if ``exception`` looks like a lost or failed database connection.

    Used as the ``retry_on_exception`` predicate of the session decorators.

    * ``OperationalError`` (raw SQLAlchemy error, e.g. when the caller passed its
      own session) is matched against MySQL client codes and Oracle
      connection-loss codes.
    * ``DatabaseException`` (raised by the decorators after wrapping a
      ``DatabaseError``, which includes ``OperationalError``) is matched against
      PostgreSQL "closed the connection" messages and the same Oracle codes.
      Previously the Oracle codes were never checked on this path. The bare
      numeric MySQL codes are not applied here because the wrapped text may also
      contain SQL and bind parameters, where such digits can occur by chance.
    """
    oracle_codes = ('ORA-00028',  # Oracle session has been killed
                    'ORA-01012',  # not logged on
                    'ORA-03113',  # end-of-file on communication channel
                    'ORA-03114',  # not connected to ORACLE
                    'ORA-03135',  # connection lost contact
                    'ORA-25408',)  # can not safely replay call
    if isinstance(exception, OperationalError):
        conn_err_codes = ('2002', '2003', '2006') + oracle_codes  # MySQL + Oracle
    elif isinstance(exception, DatabaseException):
        conn_err_codes = ('server closed the connection unexpectedly',
                          'closed the connection',) + oracle_codes
    else:
        return False
    # Tolerate exceptions with no args or a non-str first arg.
    message = str(exception.args[0]) if exception.args else ''
    return any(err_code in message for err_code in conn_err_codes)
0249 
0250 
def read_session(function):
    '''
    decorator that set the session variable to use inside a function.
    With that decorator it's possible to use the session variable like if a global variable session is declared.
    session is a sqlalchemy session, and you can get one calling get_session().
    This is useful if only SELECTs and the like are being done; anything involving
    INSERTs, UPDATEs etc should use transactional_session.

    Without a caller-supplied (truthy) ``session`` keyword, a new session is opened.
    On any exception it is rolled back; SQLAlchemy ``TimeoutError``/``DatabaseError``
    are re-raised as ``DatabaseException`` (chained to the original). The session
    is always removed afterwards. A caller-supplied session is used untouched.
    Raises IDDSException when applied to a generator function (use stream_session).
    NOTE(review): retrying's ``wait_fixed`` is in milliseconds, so retries wait
    0.5 ms -- confirm whether 0.5 s was intended.
    '''
    @retry(retry_on_exception=retry_if_db_connection_error,
           wait_fixed=0.5,
           stop_max_attempt_number=2,
           wrap_exception=False)
    @wraps(function)
    def new_funct(*args, **kwargs):

        if isgeneratorfunction(function):
            raise IDDSException('read_session decorator should not be used with generator. Use stream_session instead.')

        if kwargs.get('session'):
            # Caller owns the session: no rollback or cleanup here.
            return function(*args, **kwargs)

        session = get_session()
        try:
            kwargs['session'] = session
            return function(*args, **kwargs)
        except (TimeoutError, DatabaseError) as error:
            session.rollback()  # pylint: disable=maybe-no-member
            raise DatabaseException(str(error)) from error
        except BaseException:
            session.rollback()  # pylint: disable=maybe-no-member
            raise
        finally:
            # Single cleanup point for success and failure paths.
            session.remove()
    new_funct.__doc__ = function.__doc__
    return new_funct
0293 
0294 
def stream_session(function):
    '''
    decorator that set the session variable to use inside a function.
    With that decorator it's possible to use the session variable like if a global variable session is declared.
    session is a sqlalchemy session, and you can get one calling get_session().
    This is useful if only SELECTs and the like are being done; anything involving
    INSERTs, UPDATEs etc should use transactional_session.

    Generator variant of read_session: rows produced by the wrapped generator are
    re-yielded. Without a caller-supplied (truthy) ``session`` keyword, a new
    session is opened. On any exception it is rolled back, SQLAlchemy
    ``TimeoutError``/``DatabaseError`` are logged and re-raised as
    ``DatabaseException`` (chained), and the session is removed once iteration ends.
    Raises IDDSException (on first iteration) for non-generator functions.
    NOTE(review): new_funct is itself a generator function, so calling it only
    creates a generator object and its body runs during iteration, outside the
    retry wrapper. The retry policy therefore never fires here.
    '''
    @retry(retry_on_exception=retry_if_db_connection_error,
           wait_fixed=0.5,
           stop_max_attempt_number=2,
           wrap_exception=False)
    @wraps(function)
    def new_funct(*args, **kwargs):

        if not isgeneratorfunction(function):
            raise IDDSException('stream_session decorator should be used only with generator. Use read_session instead.')

        if kwargs.get('session'):
            # Caller owns the session: just delegate.
            yield from function(*args, **kwargs)
            return

        session = get_session()
        try:
            kwargs['session'] = session
            # yield from also closes the inner generator (running its cleanup)
            # before the session is removed if the consumer stops early.
            yield from function(*args, **kwargs)
        except (TimeoutError, DatabaseError) as error:
            LOG.error('%s: %s', function.__name__, error)
            session.rollback()  # pylint: disable=maybe-no-member
            raise DatabaseException(str(error)) from error
        except BaseException:
            session.rollback()  # pylint: disable=maybe-no-member
            raise
        finally:
            session.remove()
    new_funct.__doc__ = function.__doc__
    return new_funct
0340 
0341 
def transactional_session(function):
    '''
    decorator that set the session variable to use inside a function.
    With that decorator it's possible to use the session variable like if a global variable session is declared.
    session is a sqlalchemy session, and you can get one calling get_session().

    Without a caller-supplied (truthy) ``session`` keyword, a new session is opened
    and committed after the wrapped function returns. On any exception it is rolled
    back; SQLAlchemy ``TimeoutError``/``DatabaseError`` (including failures in
    commit) are logged and re-raised as ``DatabaseException`` (chained). The session
    is always removed afterwards. With a caller-supplied session, committing is left
    to the caller.
    NOTE(review): retrying's ``wait_fixed`` is in milliseconds, so retries wait
    0.5 ms -- confirm whether 0.5 s was intended.
    '''
    @retry(retry_on_exception=retry_if_db_connection_error,
           wait_fixed=0.5,
           stop_max_attempt_number=2,
           wrap_exception=False)
    @wraps(function)
    def new_funct(*args, **kwargs):
        if kwargs.get('session'):
            # Caller owns the transaction.
            return function(*args, **kwargs)

        session = get_session()
        try:
            kwargs['session'] = session
            result = function(*args, **kwargs)
            session.commit()  # pylint: disable=maybe-no-member
        except (TimeoutError, DatabaseError) as error:
            LOG.error('%s: %s', function.__name__, error)
            session.rollback()  # pylint: disable=maybe-no-member
            raise DatabaseException(str(error)) from error
        except BaseException:
            session.rollback()  # pylint: disable=maybe-no-member
            raise
        finally:
            session.remove()  # pylint: disable=maybe-no-member
        return result
    new_funct.__doc__ = function.__doc__
    return new_funct
0378 
0379 
def safe_bulk_update_mappings(session, model, mappings, chunk_size=1000):
    """
    Bulk-update only the rows that are not currently locked by another transaction.

    The primary keys referenced by ``mappings`` are first claimed with
    ``SELECT ... FOR UPDATE SKIP LOCKED``. Mappings whose rows could not be claimed
    are dropped from the update. Models with a composite primary key fall back to
    a plain ``bulk_update_mappings``.

    :param session: SQLAlchemy session the update runs in.
    :param model: mapped class to update.
    :param mappings: list of dicts keyed by mapped attribute name; entries without
                     the primary key are ignored.
    :param chunk_size: maximum number of keys per ``IN (...)`` clause when claiming
                       rows (Oracle rejects lists longer than 1000, ORA-01795).
    :returns: the result of ``session.bulk_update_mappings`` when an update is
              issued, otherwise None.
    :raises ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError('chunk_size must be >= 1, got %r' % (chunk_size,))
    if not mappings:
        return

    mapper = inspect(model)
    pk_columns = list(mapper.primary_key)

    if len(pk_columns) != 1:
        # Row claiming only supports single-column primary keys.
        return session.bulk_update_mappings(model, mappings)

    # Mappings and class attributes use the mapped attribute key, which can differ
    # from the database column name.
    pk = mapper.get_property_by_column(pk_columns[0]).key
    pk_attr = getattr(model, pk)

    # Unique ids, in first-seen order, to keep the IN lists small.
    ids = list(dict.fromkeys(m[pk] for m in mappings if pk in m))
    if not ids:
        return

    # Claim only rows that are not locked, in bounded IN (...) batches.
    claimed_set = set()
    for start in range(0, len(ids), chunk_size):
        stmt = (
            select(pk_attr)
            .where(pk_attr.in_(ids[start:start + chunk_size]))
            .with_for_update(skip_locked=True)
        )
        claimed_set.update(session.execute(stmt).scalars().all())

    if not claimed_set:
        return

    # Keep only mappings for rows this transaction managed to lock.
    filtered = [m for m in mappings if m.get(pk) in claimed_set]

    if filtered:
        return session.bulk_update_mappings(model, filtered)