Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-10 08:39:07

0001 """
WrappedCursor: a proxy around a DB-API cursor that adapts Oracle-style SQL to the configured backend (Oracle, PostgreSQL or MySQL)
0003 
0004 """
0005 
0006 import os
0007 import re
0008 import warnings
0009 
0010 from pandacommon.pandalogger.PandaLogger import PandaLogger
0011 
0012 from pandaserver.config import panda_config
0013 
# NOTE(review): with no category/module arguments this filter silences every warning in the
# whole interpreter, not only those raised from this file -- consider narrowing it
warnings.filterwarnings("ignore")

# module-level logger used by WrappedCursor (debug dumps of converted SQL)
_logger = PandaLogger().getLogger("WrappedCursor")
0017 
0018 
0019 # extract table names from sql query
def extract_table_names(sql):
    """
    Extract table names (and their aliases) from the FROM clauses of a SQL statement.

    Each `` FROM ... WHERE`` span (or `` FROM ...`` up to the end of the string) is treated as a
    comma-separated table list. Every whitespace-separated token in it is lower-cased, stripped of
    parentheses and returned, so aliases are returned alongside table names. A span that itself
    contains a `` FROM `` (a sub-query) is processed recursively instead.

    The regex ``.`` does not match newlines, so multi-line SQL should be flattened first.

    :param sql: SQL statement
    :return: list of lower-cased table names / aliases in order of appearance (may contain duplicates)
    """
    table_names = []
    for from_body, _ in re.findall(r" FROM (.+?)(WHERE|$)", sql, flags=re.IGNORECASE):
        # sub-query: recurse using the same case-insensitive " FROM " the outer pattern matches,
        # so lower-case "from" is handled like "FROM" and the recursion always finds a clause
        if re.search(r" FROM ", from_body, flags=re.IGNORECASE):
            table_names += extract_table_names(from_body)
        else:
            for table_str in from_body.split(","):
                # str.split() without arguments never yields empty/whitespace-only tokens
                table_names += [re.sub(r"\(|\)", "", token.lower()) for token in table_str.split()]
    return table_names
0029 
0030 
# convert an Oracle-style SQL statement and its bind variables into printf (%s) format
def convert_query_in_printf_format(sql, var_dict_list, sql_conv_map):
    """
    Convert an Oracle-style SQL statement into PostgreSQL syntax with printf-style placeholders.

    The syntax conversion is cached in ``sql_conv_map`` keyed by the original statement, so repeated
    statements skip the regex rewrites. Bind variables (``:name``) are then replaced by ``%s`` and
    their values are collected, in order of appearance, from each dictionary in ``var_dict_list``.

    When the statement contains the ``/* use_json_type */`` marker, dotted ``column.field``
    references are rewritten to ``column->>'field'`` and wrapped in a CAST when the compared value
    looks numeric (a numeric literal, or a bind variable holding an int/float in the first
    dictionary). NOTE: that CAST choice is cached with the statement and reused on later calls.

    :param sql: SQL statement in Oracle syntax
    :param var_dict_list: list of dictionaries mapping ":name" to a value, one per execution
    :param sql_conv_map: cache {original SQL: converted SQL}; updated in place
    :return: tuple (converted SQL, list of parameter lists -- one per dictionary in var_dict_list)
    :raises KeyError: when a bind variable used in the SQL is missing from a dictionary
    """
    # bind-variable pattern shared by parameter extraction and substitution below, so that both
    # steps see exactly the same placeholders; a name ends at any whitespace (including newlines)
    # or at one of $ , ) + -
    placeholder_pattern = r":[^\s$,)+\-]+"
    if sql in sql_conv_map:
        sql = sql_conv_map[sql]
    else:
        old_sql = sql
        # only the first set of bind values is used to guess JSON field types
        var_dict = var_dict_list[0] if var_dict_list else {}
        # % is special in printf-style SQL
        sql = re.sub(r"%", r"%%", sql)
        # current date except for being used for interval
        if re.search(r"CURRENT_DATE\s*[\+-]", sql, flags=re.IGNORECASE) is None:
            sql = re.sub(r"CURRENT_DATE", r"CURRENT_TIMESTAMP", sql, flags=re.IGNORECASE)
        # sequence: seq.currval / seq.nextval -> currval('seq') / nextval('seq')
        sql = re.sub(r"""([^ $,()]+).currval""", r"currval('\1')", sql, flags=re.IGNORECASE)
        sql = re.sub(r"""([^ $,()]+).nextval""", r"nextval('\1')", sql, flags=re.IGNORECASE)
        # RETURNING x INTO :y -> RETURNING x
        sql = re.sub(r"(RETURNING\s+\S+\s+)INTO\s+\S+", r"\1", sql, flags=re.IGNORECASE)
        # sub query + rownum: give the sub query an alias
        sql = re.sub(r"\)\s+WHERE\s+rownum", r") tmp_sub WHERE rownum", sql, flags=re.IGNORECASE)
        # sub query + GROUP BY: give the sub query an alias
        if re.search(r"FROM\s+\(\s*SELECT", sql, flags=re.IGNORECASE):
            sql = re.sub(r"\)\s+GROUP\s+BY", r") tmp_sub GROUP BY", sql, flags=re.IGNORECASE)
        # rownum -> LIMIT
        sql = re.sub(
            r"(WHERE|AND)\s+rownum[^\d:]+(\d+|:[^ \)]+)",
            r" LIMIT \2",
            sql,
            flags=re.IGNORECASE,
        )
        # NVL
        sql = re.sub(r"NVL\(", r"COALESCE(", sql, flags=re.IGNORECASE)
        # INSTR
        sql = re.sub(r"INSTR\(", r"STRPOS(", sql, flags=re.IGNORECASE)
        # random
        sql = re.sub(r"DBMS_RANDOM.value", r"RANDOM()", sql, flags=re.IGNORECASE)
        # MINUS
        sql = re.sub(r" MINUS ", r" EXCEPT ", sql, flags=re.IGNORECASE)
        # CONNECT BY level sequence -> GENERATE_SERIES
        sql = re.sub(
            r"\(SELECT\s+level\s+FROM\s+dual\s+CONNECT\s+BY\s+level\s*<=\s*(:[^ \)]+)\)*",
            r"GENERATE_SERIES(1,\1)",
            sql,
            flags=re.IGNORECASE,
        )
        # dual
        sql = re.sub(r"FROM dual", "", sql, flags=re.IGNORECASE)
        # NOWAIT
        sql = re.sub(r"FOR UPDATE NOWAIT", "FOR UPDATE SKIP LOCKED", sql, flags=re.IGNORECASE)
        # json
        if "/* use_json_type */" in sql:
            # remove \n to make regexp easier
            sql = re.sub(r"\n", r" ", sql)
            # collect table names (and aliases)
            table_names = set(extract_table_names(sql))
            checked_items = set()
            # look for a.b(.c)*
            for item in re.findall(r"(\w+\.\w+\.*\w*)", sql):
                # skip if already checked
                if item in checked_items:
                    continue
                checked_items.add(item)
                item_l = item.lower()
                # ignore tables
                if item_l in table_names:
                    continue
                # ignore float
                if item.replace(".", "", 1).isdigit():
                    continue
                to_skip = False
                new_pat = None
                # check if table.column.field
                for table_name in table_names:
                    if item_l.startswith(f"{table_name}."):
                        item_body = re.sub("^" + re.escape(table_name) + r"\.", "", item, flags=re.IGNORECASE)
                        # no json field
                        if item_body.count(".") == 0:
                            to_skip = True
                            break
                        # convert . to ->>''
                        new_body = re.sub(r"\.(?P<pat>\w+)", r"->>'\1'", item_body)
                        # prepend the table name
                        new_pat = ".".join(item.split(".")[: -(1 + item_body.count("."))]) + "." + new_body
                        break
                # ignore table.column
                if to_skip:
                    continue
                old_pat = item
                # column.field
                if not new_pat:
                    new_pat = re.sub(r"\.(?P<pat>\w+)", r"->>'\1'", item)
                # guess type from the value the field is compared with; escape the dots so
                # they only match literal dots
                right_vals = re.findall(re.escape(old_pat) + r"\s*[=<>!*]+\s*([\w:\']+)", sql)
                for right_val in right_vals:
                    # string
                    if "'" in right_val:
                        break
                    # integer
                    if right_val.isdigit():
                        new_pat = f"CAST({new_pat} AS integer)"
                        break
                    # float
                    if right_val.replace(".", "", 1).isdigit():
                        new_pat = f"CAST({new_pat} AS float)"
                        break
                    # bind variable
                    if right_val.startswith(":"):
                        if right_val not in var_dict:
                            raise KeyError(f"{right_val} is missing to guess data type")
                        if isinstance(var_dict[right_val], int):
                            new_pat = f"CAST({new_pat} AS integer)"
                            break
                        if isinstance(var_dict[right_val], float):
                            new_pat = f"CAST({new_pat} AS float)"
                            break
                # replace
                sql = sql.replace(old_pat, new_pat)
        # cache every converted statement, not only JSON ones
        sql_conv_map[old_sql] = sql
    # collect bind values in placeholder order (repeated placeholders included)
    items = re.findall(placeholder_pattern, sql)
    params_list = []
    for var_dict in var_dict_list:
        params = []
        for item in items:
            if item not in var_dict:
                raise KeyError(f"{item} is missing in SQL parameters")
            params.append(var_dict[item])
        params_list.append(params)
    # using the printf style syntax
    sql = re.sub(placeholder_pattern, "%s", sql)
    return sql, params_list
0166 
0167 
0168 # proxy
class WrappedCursor:
    """
    Proxy around a DB-API cursor that hides differences between database backends.

    Callers write SQL in Oracle syntax with ":name" bind variables. ``execute`` and ``executemany``
    apply schema renaming (``change_schema``) and, for the "postgres" and "mysql" backends, rewrite
    the statement and its bind variables before running it. The remaining methods delegate to the
    wrapped cursor with small per-backend adaptations.
    """

    # constructor
    def __init__(self, connection):
        """
        Open a cursor on ``connection`` and read backend settings from ``panda_config``.

        :param connection: DB-API connection object
        """
        # connection object
        self.conn = connection
        # cursor object
        self.cur = self.conn.cursor()
        # backend name: "oracle", "postgres" or "mysql"
        self.backend = panda_config.backend
        # statement registered by prepare(), used when executemany() gets sql=None
        self.statement = None
        # log SQL before/after conversion when panda_config.cursor_dump is set and truthy
        self.dump = bool(getattr(panda_config, "cursor_dump", False))
        # SQL conversion cache {original SQL: converted SQL}
        self.sql_conv_map = {}
        # executemany implementation: psycopg2's execute_batch for postgres, the cursor's own otherwise
        if self.backend == "postgres":
            from psycopg2.extras import execute_batch

            self.alt_executemany = execute_batch
        else:
            self.alt_executemany = None

    # __iter__
    def __iter__(self):
        """Iterate over the rows of the wrapped cursor."""
        return iter(self.cur)

    # serialize
    def __str__(self):
        return f"WrappedCursor[{self.conn}]"

    # initialize
    def initialize(self):
        """
        Configure the database session (time zone, date format, autocommit) for the backend.

        :return: host name reported by the database (Oracle and MySQL), otherwise None
        """
        hostname = None
        if self.backend == "oracle":
            # get hostname
            self.execute("SELECT SYS_CONTEXT('USERENV','HOST') FROM dual")
            res = self.fetchone()
            if res is not None:
                hostname = res[0]
            # set TZ
            self.execute("ALTER SESSION SET TIME_ZONE='UTC'")
            # set DATE format
            self.execute("ALTER SESSION SET NLS_DATE_FORMAT='YYYY/MM/DD HH24:MI:SS'")
            # set Oracle optimizer version. This is done only temporarily for controlled migration to 19c
            self.execute("ALTER SESSION SET optimizer_features_enable='19.1.0'")
        elif self.backend == "postgres":
            # disable autocommit
            # make sure that always have commit() since any query execution, including SELECT will start a transaction
            self.conn.set_session(autocommit=False)
            # encoding
            self.conn.set_client_encoding("UTF-8")
            # TZ
            self.execute("SET timezone=0")
            # commit to set session params permanently
            self.conn.commit()
        else:
            # get hostname
            self.execute("SELECT SUBSTRING_INDEX(USER(),'@',-1)")
            res = self.fetchone()
            if res is not None:
                hostname = res[0]
            # set TZ
            self.execute("SET @@SESSION.TIME_ZONE = '+00:00'")
            # disable autocommit
            self.execute("SET autocommit=0")
        return hostname

    # execute query on cursor
    def execute(self, sql, varDict=None, cur=None):
        """
        Execute an Oracle-style statement on the configured backend.

        :param sql: SQL statement with ":name" bind variables
        :param varDict: bind variables {":name": value}; may be modified in place for RETURNING INTO
        :param cur: cursor to execute on; defaults to the wrapped cursor
        :return: the cursor's execute() result (LAST_INSERT_ID() for MySQL RETURNING INTO);
                 None when the backend is not oracle/postgres/mysql
        """
        if varDict is None:
            varDict = {}
        if cur is None:
            cur = self.cur
        ret = None
        # schema names
        sql = self.change_schema(sql)
        # remove `
        sql = re.sub("`", "", sql)
        if self.backend == "oracle":
            ret = cur.execute(sql, varDict)
        elif self.backend == "postgres":
            if self.dump:
                _logger.debug(f"OLD: {sql} {str(varDict)}")
            sql, vars_list = convert_query_in_printf_format(sql, [varDict], self.sql_conv_map)
            varList = vars_list[0]
            if self.dump:
                _logger.debug(f"NEW: {sql} {str(varList)}")
            # counting the number of unlocked rows and raise when nothing available, to mimic NOWAIT behavior in Oracle
            if "FOR UPDATE SKIP LOCKED" in sql:
                cur.execute(f"SELECT COUNT(*) FROM ({sql}) t", varList)
                if cur.fetchone()[0] == 0:
                    from psycopg2.errors import LockNotAvailable

                    raise LockNotAvailable("could not obtain lock on row")
            ret = cur.execute(sql, varList)
        elif self.backend == "mysql":
            ret = self._execute_mysql(sql, varDict, cur)
        return ret

    # MySQL flavour of execute
    def _execute_mysql(self, sql, varDict, cur):
        """
        Rewrite an Oracle-style statement for MySQL and execute it.

        :param sql: SQL statement (schema names already changed by execute())
        :param varDict: bind variables {":name" or "name": value}; RETURNING INTO entries are updated in place
        :param cur: cursor to execute on
        :return: cur.execute() result, or the LAST_INSERT_ID() value for RETURNING INTO statements
        """
        # CURRENT_DATE - N -> DATE_SUB(CURRENT_TIMESTAMP, INTERVAL N DAY)
        sql = re.sub(
            r"CURRENT_DATE\s*-\s*(\d+|:[^\s\)]+)",
            r"DATE_SUB(CURRENT_TIMESTAMP,INTERVAL \g<1> DAY)",
            sql,
        )
        # CURRENT_DATE
        sql = re.sub("CURRENT_DATE", "CURRENT_TIMESTAMP", sql)
        # SYSDATE - N -> DATE_SUB(SYSDATE, INTERVAL N DAY)
        sql = re.sub(
            r"SYSDATE\s*-\s*(\d+|:[^\s\)]+)",
            r"DATE_SUB(SYSDATE,INTERVAL \g<1> DAY)",
            sql,
        )
        # SYSDATE -> SYSDATE() (also applies to the SYSDATE inside DATE_SUB above)
        sql = re.sub("SYSDATE", "SYSDATE()", sql)
        # EMPTY_CLOB() -> empty string
        sql = re.sub(r"EMPTY_CLOB\(\)", "''", sql)
        # ROWNUM <op> N -> LIMIT N; ".*?" is non-greedy so every digit of N is kept
        # (a greedy ".*" backtracks only to the last digit, e.g. ROWNUM<=100 -> LIMIT 0)
        sql = re.sub(r"(?i)(AND)*\s*ROWNUM.*?(\d+)", r" LIMIT \g<2>", sql)
        sql = re.sub(r"(?i)(WHERE)\s*LIMIT\s*(\d+)", r" LIMIT \g<2>", sql)
        # NOWAIT
        sql = re.sub("NOWAIT", "", sql)
        # RETURNING INTO: emulated with LAST_INSERT_ID() after execution
        returningInto = None
        m = re.search(r"RETURNING ([^\s]+) INTO ([^\s]+)", sql, re.I)
        if m is not None:
            returningInto = [{"returning": m.group(1), "into": m.group(2)}]
            self._returningIntoMySQLpre(returningInto, varDict, cur)
            # remove the clause as a literal string; the matched text is not a regex pattern
            sql = sql.replace(m.group(0), "")
        # Addressing sequence: seq.nextval -> NULL in INSERT statements
        if "INSERT" in sql:
            sql = re.sub(r"[a-zA-Z\._]+\.nextval", "NULL", sql)
        # schema names
        sql = re.sub(r"ATLAS_PANDA\.", panda_config.schemaPANDA + ".", sql)
        sql = re.sub(r"ATLAS_PANDAMETA\.", panda_config.schemaMETA + ".", sql)
        sql = re.sub(r"ATLAS_GRISLI\.", panda_config.schemaGRISLI + ".", sql)
        sql = re.sub(r"ATLAS_PANDAARCH\.", panda_config.schemaPANDAARCH + ".", sql)
        # bind variables ":name" -> "%(name)s"; longest names first so that
        # :prodDBlockToken will not be replaced by %(prodDBlock)sToken
        newVarDict = {}
        for key in sorted(varDict, key=lambda s: -len(str(s))):
            val = varDict[key]
            if key[0] == ":":
                newKey = key[1:]
                sql = sql.replace(key, "%(" + newKey + ")s")
            else:
                newKey = key
                sql = sql.replace(":" + key, "%(" + newKey + ")s")
            newVarDict[newKey] = val
        # from PanDA monitor it is hard to log queries sometimes, so let's debug with hardcoded query dumps;
        # best effort: enabled by the presence of a flag file, any failure is ignored
        try:
            import time

            if os.path.exists("/data/atlpan/oracle/panda/monitor/logs/write_queries.txt"):
                with open(
                    "/data/atlpan/oracle/panda/monitor/logs/mysql_queries_WrappedCursor.txt",
                    "a",
                ) as f:
                    f.write(f"mysql|{str(time.time())}|{str(sql)}|{str(newVarDict)}\n")
        except Exception:
            pass
        _logger.debug(f"execute : SQL     {sql} ")
        _logger.debug(f"execute : varDict {newVarDict} ")
        ret = cur.execute(sql, newVarDict)
        if returningInto is not None:
            ret = self._returningIntoMySQLpost(returningInto, varDict, cur)
        return ret

    def _returningIntoOracle(self, returningInputData, varDict, cur, dryRun=False):
        """
        Build an Oracle " RETURNING ... INTO ... " clause and register output variables.

        :param returningInputData: e.g. [{'returning': 'PandaID', 'into': ':newPandaID'}, {'returning': 'row_ID', 'into': ':newRowID'}]
        :param varDict: bind variables; gets one cur.var(...) per INTO target unless dryRun
        :param cur: cursor used to create the output variables
        :param dryRun: only build the clause, do not touch varDict
        :return: the clause, or "" when returningInputData is None or an error occurs
        """
        result = ""
        if returningInputData is not None:
            try:
                valReturning = ",".join([x["returning"] for x in returningInputData])
                listInto = [x["into"] for x in returningInputData]
                valInto = ",".join(listInto)
                # assuming that we use RETURNING INTO only for PandaID or row_ID columns
                if not dryRun:
                    for x in listInto:
                        # NOTE(review): `oracledb` is not imported in this module, so this raises
                        # NameError, which the except below turns into an empty clause -- confirm
                        varDict[x] = cur.var(oracledb.NUMBER)  # noqa: F821
                result = f" RETURNING {valReturning} INTO {valInto} "
            except Exception:
                pass
        return result

    def _returningIntoMySQLpre(self, returningInputData, varDict, cur):
        """
        Prepare varDict for a MySQL INSERT emulating RETURNING INTO.

        Removes the INTO targets from varDict and, for a single RETURNING column, sets ":<column>"
        to None so that auto_increment assigns the value. ``cur`` is unused.

        :param returningInputData: e.g. [{'returning': 'PandaID', 'into': ':newPandaID'}]
        :param varDict: bind variables, modified in place
        """
        if returningInputData is not None:
            try:
                # get rid of "returning into" items in varDict
                for x in [item["into"] for item in returningInputData]:
                    varDict.pop(x, None)
                if len(returningInputData) == 1:
                    # and set original value in varDict to null, let auto_increment do the work
                    for x in [item["returning"] for item in returningInputData]:
                        varDict[":" + x] = None
            except Exception:
                pass

    def _returningIntoMySQLpost(self, returningInputData, varDict, cur):
        """
        Fetch LAST_INSERT_ID() after a MySQL INSERT and store it in the RETURNING INTO targets.

        :param returningInputData: e.g. [{'returning': 'PandaID', 'into': ':newPandaID'}]
        :param varDict: bind variables; INTO targets are set to the new id
        :param cur: cursor the INSERT ran on (LAST_INSERT_ID() is queried on the same cursor)
        :return: the inserted id, or 0 when there is not exactly one RETURNING column
        """
        result = int(0)
        if returningInputData is not None and len(returningInputData) == 1:
            cur.execute(""" SELECT LAST_INSERT_ID() """)
            (result,) = cur.fetchone()
            try:
                # update of "returning into" items in varDict
                for x in [item["into"] for item in returningInputData]:
                    varDict[x] = int(result)
            except (TypeError, ValueError):
                # LAST_INSERT_ID() did not return an integer-like value; leave varDict untouched
                pass
        return result

    # fetchall
    def fetchall(self):
        return self.cur.fetchall()

    # fetchmany
    def fetchmany(self, arraysize=1000):
        """Fetch up to ``arraysize`` rows; also leaves the cursor's arraysize set to that value."""
        self.cur.arraysize = arraysize
        return self.cur.fetchmany()

    # fetchone
    def fetchone(self):
        return self.cur.fetchone()

    # var
    def var(self, dataType, *args, **kwargs):
        """
        Create an output variable for RETURNING INTO.

        :return: ``dataType(0)`` for MySQL, None for PostgreSQL, ``cur.var(...)`` otherwise
        """
        if self.backend == "mysql":
            # was the Python-2-only apply(dataType, [0])
            return dataType(0)
        elif self.backend == "postgres":
            return None
        else:
            return self.cur.var(dataType, *args, **kwargs)

    # get value
    def getvalue(self, dataItem):
        """
        Read an output variable created by var().

        :return: the item itself for MySQL; for PostgreSQL the first column of the next fetched row
                 (presumably the row produced by the RETURNING clause -- confirm); otherwise dataItem.getvalue()
        """
        if self.backend == "mysql":
            return dataItem
        elif self.backend == "postgres":
            return self.cur.fetchone()[0]
        else:
            return dataItem.getvalue()

    # next
    def next(self):
        """Return the next row; MySQL uses fetchone(), other backends advance the cursor iterator."""
        if self.backend == "mysql":
            return self.cur.fetchone()
        # builtin next() calls __next__; Python 3 iterators have no .next() method
        return next(self.cur)

    # close
    def close(self):
        return self.cur.close()

    # prepare
    def prepare(self, statement):
        """Register a statement to be used by executemany(None, params)."""
        self.statement = statement

    # executemany
    def executemany(self, sql, params):
        """
        Execute a statement once per parameter dictionary.

        :param sql: SQL statement, or None to use the one registered with prepare()
        :param params: list of bind-variable dictionaries
        """
        if sql is None:
            sql = self.statement
        sql = self.change_schema(sql)
        if self.backend == "postgres":
            sql, vars_list = convert_query_in_printf_format(sql, params, self.sql_conv_map)
            self.alt_executemany(self.cur, sql, vars_list)
        else:
            self.cur.executemany(sql, params)

    # get_description
    @property
    def description(self):
        return self.cur.description

    # rowcount
    @property
    def rowcount(self):
        return self.cur.rowcount

    # arraysize
    @property
    def arraysize(self):
        return self.cur.arraysize

    @arraysize.setter
    def arraysize(self, val):
        self.cur.arraysize = val

    # change schema
    def change_schema(self, sql):
        """Replace the default ATLAS_* schema prefixes with the names configured in panda_config."""
        if panda_config.schemaPANDA != "ATLAS_PANDA":
            sql = re.sub(r"ATLAS_PANDA\.", panda_config.schemaPANDA + ".", sql)
        if panda_config.schemaMETA != "ATLAS_PANDAMETA":
            sql = re.sub(r"ATLAS_PANDAMETA\.", panda_config.schemaMETA + ".", sql)
        if panda_config.schemaGRISLI != "ATLAS_GRISLI":
            sql = re.sub(r"ATLAS_GRISLI\.", panda_config.schemaGRISLI + ".", sql)
        if panda_config.schemaPANDAARCH != "ATLAS_PANDAARCH":
            sql = re.sub(r"ATLAS_PANDAARCH\.", panda_config.schemaPANDAARCH + ".", sql)
        return sql