EIC code displayed by LXR

import copy
import json
import os
import re
import shlex
import sys
import tarfile
import tempfile
import traceback

import requests
from idds.atlas.workflowv2.atlaslocalpandawork import ATLASLocalPandaWork
from idds.atlas.workflowv2.atlaspandawork import ATLASPandaWork
from idds.workflowv2.workflow import AndCondition, Condition, OrCondition, Workflow
from pandaclient import PhpoScript, PrunScript
from pandacommon.pandalogger.LogWrapper import LogWrapper
from pandacommon.pandalogger.PandaLogger import PandaLogger
from ruamel.yaml import YAML

# from pandaserver.srvcore.CoreUtils import clean_user_id
from pandaserver.workflow import pcwl_utils, workflow_utils
from pandaserver.workflow.snakeparser import Parser

# supported workflow description languages
SUPPORTED_WORKFLOW_LANGUAGES = ["cwl", "snakemake"]

# main logger
logger = PandaLogger().getLogger(__name__.split(".")[-1])


# ==============================================================================
# Native PanDA workflow functions
# ==============================================================================


def json_serialize_default(obj):
    """
    Default JSON serializer for non-JSON-serializable attributes of Node objects

    Args:
        obj (Any): Object to serialize

    Returns:
        Any: JSON-serializable object
    """
    # convert set to list
    if isinstance(obj, set):
        return list(obj)
    elif isinstance(obj, workflow_utils.Node):
        return obj.id
    return obj
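# note (illustrative, not from the original file): this serializer is meant to be passed
# as the "default" hook of json.dumps, e.g. json.dumps(vars(node), default=json_serialize_default),
# so that set attributes become lists and Node references collapse to their IDs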


def parse_raw_request(sandbox_url, log_token, user_name, raw_request_dict) -> tuple[bool, bool, dict]:
    """
    Parse a raw request with files in a sandbox into a workflow definition

    Args:
        sandbox_url (str): URL to download the sandbox from
        log_token (str): Log token
        user_name (str): User name
        raw_request_dict (dict): Raw request dictionary

    Returns:
        bool: Whether the parsing was successful
        bool: Whether the failure is fatal
        dict: Workflow definition dictionary
    """
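    # keys read from raw_request_dict in this function: "sandbox", "language", "outDS",
    # "taskParams", "workflowSpecFile", and, for CWL workflows only, "workflowInputFile"
    # and "workflow_name"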
    tmp_log = LogWrapper(logger, log_token)
    is_ok = True
    is_fatal = False
    # request_id = None
    workflow_definition_dict = dict()

    def _is_within_directory(base_dir: str, target_path: str) -> bool:
        abs_base_dir = os.path.abspath(base_dir)
        abs_target_path = os.path.abspath(target_path)
        return os.path.commonpath([abs_base_dir, abs_target_path]) == abs_base_dir

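    # reject unsafe tar members (absolute paths, path traversal, links, special files,
    # or targets resolving outside extract_dir) before extracting anything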
    def _safe_extract_tar_gz(tar_path: str, extract_dir: str):
        with tarfile.open(tar_path, mode="r:gz") as tar:
            members = tar.getmembers()
            for member in members:
                member_name = member.name
                normalized_name = os.path.normpath(member_name)
                # security checks for tar member name
                if os.path.isabs(member_name):
                    raise ValueError(f"absolute path in tar member is not allowed: {member_name}")
                if normalized_name in ("", ".", "..") or normalized_name.startswith(".." + os.path.sep):
                    raise ValueError(f"path traversal in tar member is not allowed: {member_name}")
                if member.issym() or member.islnk():
                    raise ValueError(f"links in tar archive are not allowed: {member_name}")
                if member.ischr() or member.isblk() or member.isfifo():
                    raise ValueError(f"special file in tar archive is not allowed: {member_name}")
                # check that the extraction target is within the extract_dir
                extraction_target = os.path.join(extract_dir, normalized_name)
                if not _is_within_directory(extract_dir, extraction_target):
                    raise ValueError(f"tar member extracts outside target directory: {member_name}")
            # all checks passed, safe to extract
            tar.extractall(path=extract_dir, members=members)

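    # overall flow: download the sandbox tarball, validate its filename, safely extract it,
    # parse the workflow description, resolve the nodes, and build the workflow definition
    # dictionary that is returned to the caller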
    try:
        # use an isolated temp dir without changing process cwd
        with tempfile.TemporaryDirectory() as tmp_dirname:
            # download sandbox
            tmp_log.info(f"downloading sandbox from {sandbox_url}")
            with requests.get(sandbox_url, allow_redirects=True, stream=True) as r:
                if r.status_code == 400:
                    tmp_log.error("not found")
                    is_fatal = True
                    is_ok = False
                elif r.status_code != 200:
                    tmp_log.error(f"bad HTTP response {r.status_code}")
                    is_ok = False
                # validate sandbox filename
                sandbox_name = raw_request_dict.get("sandbox")
                if is_ok:
                    if not isinstance(sandbox_name, str):
                        tmp_log.error("sandbox filename is missing or not a string")
                        is_fatal = True
                        is_ok = False
                    else:
                        # sandbox filename must not contain any path separators
                        seps = [os.path.sep]
                        if os.path.altsep:
                            seps.append(os.path.altsep)
                        if any(sep in sandbox_name for sep in seps):
                            tmp_log.error("sandbox filename must not contain path separators")
                            is_fatal = True
                            is_ok = False
                        else:
                            sandbox_name = os.path.basename(sandbox_name)
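                # if the checks above passed, sandbox_name is a bare file name, so the
                # write below cannot escape tmp_dirname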
                # extract sandbox
                if is_ok:
                    sandbox_path = os.path.join(tmp_dirname, sandbox_name)
                    with open(sandbox_path, "wb") as fs:
                        for chunk in r.raw.stream(1024, decode_content=False):
                            if chunk:
                                fs.write(chunk)
                    try:
                        _safe_extract_tar_gz(sandbox_path, tmp_dirname)
                    except Exception:
                        dump_str = f"failed to extract {sandbox_name}: {traceback.format_exc()}"
                        tmp_log.error(dump_str)
                        is_fatal = True
                        is_ok = False
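                # CWL descriptions go through pcwl_utils, Snakemake ones through Parser;
                # both produce the node list and the root inputs (root_in)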
                # parse workflow files
                if is_ok:
                    tmp_log.info("parse workflow")
                    workflow_name = None
                    if (wf_lang := raw_request_dict["language"]) in SUPPORTED_WORKFLOW_LANGUAGES:
                        if wf_lang == "cwl":
                            workflow_name = raw_request_dict.get("workflow_name")
                            workflow_spec_file = os.path.join(tmp_dirname, raw_request_dict["workflowSpecFile"])
                            workflow_input_file = os.path.join(tmp_dirname, raw_request_dict["workflowInputFile"])
                            nodes, root_in = pcwl_utils.parse_workflow_file(workflow_spec_file, tmp_log)
                            with open(workflow_input_file) as workflow_input:
                                yaml = YAML(typ="safe", pure=True)
                                data = yaml.load(workflow_input)
                        elif wf_lang == "snakemake":
                            workflow_spec_file = os.path.join(tmp_dirname, raw_request_dict["workflowSpecFile"])
                            parser = Parser(workflow_spec_file, logger=tmp_log)
                            nodes, root_in = parser.parse_nodes()
                            data = dict()
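                        # data holds the root-level input values: loaded from the YAML input
                        # file for CWL, left empty for Snakemake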
                        # resolve nodes
                        s_id, t_nodes, nodes = workflow_utils.resolve_nodes(nodes, root_in, data, 0, set(), raw_request_dict["outDS"], tmp_log)
                        workflow_utils.set_workflow_outputs(nodes)
                        id_node_map = workflow_utils.get_node_id_map(nodes)
                        for node in nodes:
                            node.resolve_params(raw_request_dict["taskParams"], id_node_map)
                        dump_str = "the description was internally converted as follows\n" + workflow_utils.dump_nodes(nodes)
                        tmp_log.info(dump_str)
                        for node in nodes:
                            s_check, o_check = node.verify()
                            tmp_str = f"Verification failure in ID:{node.id} {o_check}"
                            if not s_check:
                                tmp_log.error(tmp_str)
                                dump_str += tmp_str
                                dump_str += "\n"
                                is_fatal = True
                                is_ok = False
                    else:
                        dump_str = f"{wf_lang} is not a supported workflow description language"
                        tmp_log.error(dump_str)
                        is_fatal = True
                        is_ok = False
                # generate workflow definition
                if is_ok:
                    # root inputs
                    root_inputs_dict = dict()
                    for k in root_in:
                        kk = k.split("#")[-1]
                        if kk in data:
                            root_inputs_dict[k] = data[kk]
                    # root outputs
                    root_outputs_dict = dict()
                    nodes_list = []
                    # nodes
                    for node in nodes:
                        nodes_list.append(vars(node))
                        if node.is_tail:
                            root_outputs_dict.update(node.outputs)
                            for out_val in root_outputs_dict.values():
                                out_val["output_types"] = node.output_types
                    # workflow definition
                    workflow_definition_dict = {
                        "workflow_name": workflow_name,
                        "user_name": user_name,
                        "root_inputs": root_inputs_dict,
                        "root_outputs": root_outputs_dict,
                        "nodes": nodes_list,
                    }
    except Exception as e:
        is_ok = False
        is_fatal = True
        tmp_log.error(f"failed to run with {str(e)} {traceback.format_exc()}")

    # with tempfile.NamedTemporaryFile(delete=False, mode="w") as tmp_json:
    #     json.dump([is_ok, is_fatal, request_id, tmp_log.dumpToString()], tmp_json)
    #     print(tmp_json.name)

    return is_ok, is_fatal, workflow_definition_dict
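

# Illustrative call (an assumption for documentation, not code from this module): the
# raw_request_dict keys mirror the ones read above; the URL, token, and values are made up.
#
# is_ok, is_fatal, wf_def = parse_raw_request(
#     sandbox_url="https://panda.example.org/cache/sandbox.tar.gz",
#     log_token="workflow-parse",
#     user_name="someuser",
#     raw_request_dict={
#         "sandbox": "sandbox.tar.gz",
#         "language": "cwl",
#         "workflow_name": "my_workflow",
#         "workflowSpecFile": "main.cwl",
#         "workflowInputFile": "inputs.yaml",
#         "outDS": "user.someuser.my_output",
#         "taskParams": {},
#     },
# )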