Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-15 08:35:37

0001 import copy
0002 import json
0003 import re
0004 import shlex
0005 import tempfile
0006 
0007 from pandaclient import PhpoScript, PrunScript
0008 
0009 
0010 # merge job parameters
def merge_job_params(base_params, io_params):
    """Combine two job-parameter lists into a new list.

    From ``base_params`` everything is kept except the execution section: the
    constant whose value starts with ``"-p "``, and every following item up to
    and including the first constant that has no ``"padding"`` key. Any later
    constant starting with ``"-p "`` is dropped as well.

    From ``io_params`` only the items after the ``"__delimiter__"`` constant
    are taken, and constants starting with ``"-a "`` are always ignored.

    :param base_params: list of job-parameter dicts with ``"type"``/``"value"`` keys
    :param io_params: list of job-parameter dicts with ``"type"``/``"value"`` keys
    :return: new list holding the kept base items followed by the kept IO items
    """
    merged = []

    # strip the execution section from the base parameters
    in_exec = False
    exec_closed = False
    for item in base_params:
        is_constant = item["type"] == "constant"
        if is_constant and item["value"].startswith("-p "):
            in_exec = True
            continue
        if in_exec and not exec_closed:
            # a constant without padding terminates the section and is dropped too
            if is_constant and "padding" not in item:
                exec_closed = True
            continue
        merged.append(item)

    # take what follows the delimiter from the IO parameters
    past_delimiter = False
    for item in io_params:
        if item["type"] == "constant":
            if item["value"] == "__delimiter__":
                past_delimiter = True
                continue
            # ignore archive option
            if item["value"].startswith("-a "):
                continue
        if past_delimiter:
            merged.append(item)
    return merged
0042 
0043 
0044 # DAG vertex
0045 class Node(object):
0046     def __init__(self, id, node_type, data, is_leaf, name):
0047         self.id = id
0048         self.type = node_type
0049         self.data = data
0050         self.is_leaf = is_leaf
0051         self.is_tail = False
0052         self.is_head = False
0053         self.inputs = {}
0054         self.outputs = {}
0055         self.output_types = []
0056         self.scatter = None
0057         self.parents = set()
0058         self.name = name
0059         self.sub_nodes = set()
0060         self.root_inputs = None
0061         self.task_params = None
0062         self.condition = None
0063         self.is_workflow_output = False
0064         self.loop = False
0065         self.in_loop = False
0066         self.upper_root_inputs = None
0067 
0068     def add_parent(self, id):
0069         self.parents.add(id)
0070 
0071     # set real input values
0072     def set_input_value(self, key, src_key, src_value):
0073         # replace the value with a list of parameter names and indexes if value is a list,
0074         # and src and dst are looping params
0075         if isinstance(src_value, list):
0076             src_loop_param_name = self.get_loop_param_name(src_key)
0077             loop_params = self.get_loop_param_name(key.split("/")[-1]) is not None and src_loop_param_name is not None
0078             if loop_params:
0079                 src_value = [{"src": src_loop_param_name, "idx": i} for i in range(len(src_value))]
0080         # resolve values
0081         if isinstance(self.inputs[key]["source"], list):
0082             self.inputs[key].setdefault("value", copy.copy(self.inputs[key]["source"]))
0083             tmp_list = []
0084             for k in self.inputs[key]["value"]:
0085                 if k == src_key:
0086                     tmp_list.append(src_value)
0087                 else:
0088                     tmp_list.append(k)
0089             self.inputs[key]["value"] = tmp_list
0090         else:
0091             self.inputs[key]["value"] = src_value
0092 
0093     # convert inputs to dict inputs
0094     def convert_dict_inputs(self, skip_suppressed=False):
0095         data = {}
0096         for k, v in self.inputs.items():
0097             if skip_suppressed and "suppressed" in v and v["suppressed"]:
0098                 continue
0099             y_name = k.split("/")[-1]
0100             if "value" in v:
0101                 data[y_name] = v["value"]
0102             elif "default" in v:
0103                 data[y_name] = v["default"]
0104             else:
0105                 raise ReferenceError(f"{k} is not resolved")
0106         return data
0107 
0108     # convert outputs to set
0109     def convert_set_outputs(self):
0110         data = set()
0111         for k, v in self.outputs.items():
0112             if "value" in v:
0113                 data.add(v["value"])
0114         return data
0115 
0116     # verify
0117     def verify(self):
0118         if self.is_leaf:
0119             dict_inputs = self.convert_dict_inputs(True)
0120             # check input
0121             for k, v in dict_inputs.items():
0122                 if v is None:
0123                     return False, f"{k} is unresolved"
0124             # check args
0125             for k in ["opt_exec", "opt_args"]:
0126                 test_str = dict_inputs.get(k)
0127                 if test_str:
0128                     m = re.search(r"%{[A-Z]*DS(\d+|\*)}", test_str)
0129                     if m:
0130                         return False, f"{m.group(0)} is unresolved in {k}"
0131             if self.type == "prun":
0132                 for k in dict_inputs:
0133                     if k not in [
0134                         "opt_inDS",
0135                         "opt_inDsType",
0136                         "opt_secondaryDSs",
0137                         "opt_secondaryDsTypes",
0138                         "opt_args",
0139                         "opt_exec",
0140                         "opt_useAthenaPackages",
0141                         "opt_containerImage",
0142                     ]:
0143                         return False, f"unknown input parameter {k} for {self.type}"
0144             elif self.type in ["junction", "reana"]:
0145                 for k in dict_inputs:
0146                     if k not in [
0147                         "opt_inDS",
0148                         "opt_inDsType",
0149                         "opt_args",
0150                         "opt_exec",
0151                         "opt_containerImage",
0152                     ]:
0153                         return False, f"unknown input parameter {k} for {self.type}"
0154             elif self.type == "phpo":
0155                 for k in dict_inputs:
0156                     if k not in ["opt_trainingDS", "opt_trainingDsType", "opt_args"]:
0157                         return False, f"unknown input parameter {k} for {self.type}"
0158             elif self.type == "gitlab":
0159                 for k in dict_inputs:
0160                     if k not in [
0161                         "opt_inDS",
0162                         "opt_args",
0163                         "opt_api",
0164                         "opt_projectID",
0165                         "opt_ref",
0166                         "opt_triggerToken",
0167                         "opt_accessToken",
0168                         "opt_site",
0169                         "opt_input_location",
0170                     ]:
0171                         return False, f"unknown input parameter {k} for {self.type}"
0172         elif self.type == "workflow":
0173             reserved_params = ["i"]
0174             loop_global, workflow_global = self.get_global_parameters()
0175             if loop_global:
0176                 for k in reserved_params:
0177                     if k in loop_global:
0178                         return (
0179                             False,
0180                             f"parameter {k} cannot be used since it is reserved by the system",
0181                         )
0182         return True, ""
0183 
0184     # string representation
0185     def __str__(self):
0186         outstr = f"ID:{self.id} Name:{self.name} Type:{self.type}\n"
0187         outstr += f"  Parent:{','.join([str(p) for p in self.parents])}\n"
0188         outstr += "  Input:\n"
0189         for k, v in self.convert_dict_inputs().items():
0190             outstr += f"     {k}: {v}\n"
0191         outstr += "  Output:\n"
0192         for k, v in self.outputs.items():
0193             if "value" in v:
0194                 v = v["value"]
0195             else:
0196                 v = "NA"
0197             outstr += f"     {v}\n"
0198         return outstr
0199 
0200     # short description
0201     def short_desc(self):
0202         return f"ID:{self.id} Name:{self.name} Type:{self.type}"
0203 
0204     # resolve workload-specific parameters
0205     def resolve_params(self, task_template=None, id_map=None, workflow=None):
0206         if self.type in ["prun", "junction", "reana"]:
0207             dict_inputs = self.convert_dict_inputs()
0208             if "opt_secondaryDSs" in dict_inputs:
0209                 # look for secondaryDsTypes if missing
0210                 if "opt_secondaryDsTypes" not in dict_inputs:
0211                     dict_inputs["opt_secondaryDsTypes"] = []
0212                     for ds_name in dict_inputs["opt_secondaryDSs"]:
0213                         added = False
0214                         for pid in self.parents:
0215                             parent_node = id_map[pid]
0216                             if ds_name in parent_node.convert_set_outputs():
0217                                 dict_inputs["opt_secondaryDsTypes"].append(parent_node.output_types[0])
0218                                 added = True
0219                                 break
0220                         if not added:
0221                             # use None if not found
0222                             dict_inputs["opt_secondaryDsTypes"].append(None)
0223                 # resolve secondary dataset names
0224                 idx = 1
0225                 list_sec_ds = []
0226                 for ds_name, ds_type in zip(dict_inputs["opt_secondaryDSs"], dict_inputs["opt_secondaryDsTypes"]):
0227                     if ds_type and "*" in ds_type:
0228                         ds_type = ds_type.replace("*", "XYZ")
0229                         ds_type += ".tgz"
0230                     src = f"%{{SECDS{idx}}}"
0231                     if ds_type:
0232                         dst = f"{ds_name}_{ds_type}/"
0233                     else:
0234                         dst = f"{ds_name}/"
0235                     dict_inputs["opt_exec"] = re.sub(src, dst, dict_inputs["opt_exec"])
0236                     dict_inputs["opt_args"] = re.sub(src, dst, dict_inputs["opt_args"])
0237                     idx += 1
0238                     list_sec_ds.append(src)
0239                 if list_sec_ds:
0240                     src = r"%{SECDS\*}"
0241                     if "opt_exec" in dict_inputs:
0242                         dict_inputs["opt_exec"] = re.sub(src, ",".join(list_sec_ds), dict_inputs["opt_exec"])
0243                     if "opt_args" in dict_inputs:
0244                         dict_inputs["opt_args"] = re.sub(src, ",".join(list_sec_ds), dict_inputs["opt_args"])
0245                 for k, v in self.inputs.items():
0246                     if k.endswith("opt_exec"):
0247                         v["value"] = dict_inputs["opt_exec"]
0248                     elif k.endswith("opt_args"):
0249                         v["value"] = dict_inputs["opt_args"]
0250                     # Set requirement for secondary datasets
0251                     if k.endswith("opt_secondaryDSs"):
0252                         v.setdefault("requirements", {})["requires_complete"] = True
0253         if self.is_leaf and task_template:
0254             self.task_params = self.make_task_params(task_template, id_map, workflow)
0255         [n.resolve_params(task_template, id_map, self) for n in self.sub_nodes]
0256 
0257     # create task params
0258     def make_task_params(self, task_template, id_map, workflow_node):
0259         # task name
0260         for k, v in self.outputs.items():
0261             task_name = v["value"]
0262             break
0263         if self.type in ["prun", "junction", "reana"]:
0264             dict_inputs = self.convert_dict_inputs(skip_suppressed=True)
0265             # check type
0266             use_athena = False
0267             if "opt_useAthenaPackages" in dict_inputs and dict_inputs["opt_useAthenaPackages"] and self.type != "reana":
0268                 use_athena = True
0269             container_image = None
0270             if "opt_containerImage" in dict_inputs and dict_inputs["opt_containerImage"]:
0271                 container_image = dict_inputs["opt_containerImage"]
0272             if use_athena:
0273                 task_params = copy.deepcopy(task_template["athena"])
0274             else:
0275                 task_params = copy.deepcopy(task_template["container"])
0276             task_params["taskName"] = task_name
0277             # cli params
0278             com = ["prun"]
0279             if self.type == "junction":
0280                 # add default output for junction
0281                 if "opt_args" not in dict_inputs:
0282                     dict_inputs["opt_args"] = ""
0283                 results_json = "results.json"
0284                 if "--outputs" not in dict_inputs["opt_args"]:
0285                     dict_inputs["opt_args"] += f" --outputs {results_json}"
0286                 else:
0287                     m = re.search("(--outputs)( +|=)([^ ]+)", dict_inputs["opt_args"])
0288                     if results_json not in m.group(3):
0289                         tmp_dst = m.group(1) + "=" + m.group(3) + "," + results_json
0290                         dict_inputs["opt_args"] = re.sub(m.group(0), tmp_dst, dict_inputs["opt_args"])
0291             com += shlex.split(dict_inputs["opt_args"])
0292             if "opt_inDS" in dict_inputs and dict_inputs["opt_inDS"]:
0293                 list_in_ds = self.get_input_ds_list(dict_inputs, id_map)
0294                 if self.type not in ["reana"]:
0295                     in_ds_str = ",".join(list_in_ds)
0296                     com += ["--inDS", in_ds_str, "--notExpandInDS", "--notExpandSecDSs"]
0297                     if self.type in ["junction"]:
0298                         com += ["--forceStaged", "--forceStagedSecondary"]
0299                 if self.type in ["prun", "junction", "reana"]:
0300                     # replace placeholders in opt_exec and opt_args
0301                     for idx, dst in enumerate(list_in_ds):
0302                         src = f"%{{DS{idx + 1}}}"
0303                         if "opt_exec" in dict_inputs:
0304                             dict_inputs["opt_exec"] = re.sub(src, dst, dict_inputs["opt_exec"])
0305                         if "opt_args" in dict_inputs:
0306                             dict_inputs["opt_args"] = re.sub(src, dst, dict_inputs["opt_args"])
0307                     if list_in_ds:
0308                         src = r"%{DS\*}"
0309                         if "opt_exec" in dict_inputs:
0310                             dict_inputs["opt_exec"] = re.sub(src, ",".join(list_in_ds), dict_inputs["opt_exec"])
0311                         if "opt_args" in dict_inputs:
0312                             dict_inputs["opt_args"] = re.sub(src, ",".join(list_in_ds), dict_inputs["opt_args"])
0313                     for k, v in self.inputs.items():
0314                         if k.endswith("opt_exec"):
0315                             v["value"] = dict_inputs["opt_exec"]
0316                         elif k.endswith("opt_args"):
0317                             v["value"] = dict_inputs["opt_args"]
0318             # global parameters
0319             if workflow_node:
0320                 tmp_global, tmp_workflow_global = workflow_node.get_global_parameters()
0321                 src_dst_list = []
0322                 # looping globals
0323                 if tmp_global:
0324                     for k in tmp_global:
0325                         tmp_src = f"%{{{k}}}"
0326                         tmp_dst = f"___idds___user_{k}___"
0327                         src_dst_list.append((tmp_src, tmp_dst))
0328                 # workflow globls
0329                 if tmp_workflow_global:
0330                     for k, v in tmp_workflow_global.items():
0331                         tmp_src = f"%{{{k}}}"
0332                         tmp_dst = f"{v}"
0333                         src_dst_list.append((tmp_src, tmp_dst))
0334                 # iteration count
0335                 tmp_src = "%{i}"
0336                 tmp_dst = "___idds___num_run___"
0337                 src_dst_list.append((tmp_src, tmp_dst))
0338                 # replace
0339                 for tmp_src, tmp_dst in src_dst_list:
0340                     if "opt_exec" in dict_inputs:
0341                         dict_inputs["opt_exec"] = re.sub(tmp_src, tmp_dst, dict_inputs["opt_exec"])
0342                     if "opt_args" in dict_inputs:
0343                         dict_inputs["opt_args"] = re.sub(tmp_src, tmp_dst, dict_inputs["opt_args"])
0344             com += ["--exec", dict_inputs["opt_exec"]]
0345             com += ["--outDS", task_name]
0346             if container_image:
0347                 com += ["--containerImage", container_image]
0348                 parse_com = copy.copy(com[1:])
0349             else:
0350                 # add dummy container to keep build step consistent
0351                 parse_com = copy.copy(com[1:])
0352                 parse_com += ["--containerImage", None]
0353             # force a writable temp base for dry parsing regardless of process cwd
0354             parse_com += ["--tmpDir", tempfile.gettempdir()]
0355             athena_tag = False
0356             if use_athena:
0357                 com += ["--useAthenaPackages"]
0358                 athena_tag = "--athenaTag" in com
0359                 # add cmtConfig
0360                 if athena_tag and "--cmtConfig" not in parse_com:
0361                     parse_com += [
0362                         "--cmtConfig",
0363                         task_params["architecture"].split("@")[0],
0364                     ]
0365             # parse args without setting --useAthenaPackages since it requires real Athena runtime
0366             parsed_params = PrunScript.main(True, parse_com, dry_mode=True)
0367             task_params["cliParams"] = " ".join(shlex.quote(x) for x in com)
0368             # set parsed parameters
0369             for p_key, p_value in parsed_params.items():
0370                 if p_key in ["buildSpec"]:
0371                     continue
0372                 if p_key not in task_params or p_key in [
0373                     "log",
0374                     "container_name",
0375                     "multiStepExec",
0376                     "site",
0377                     "excludedSite",
0378                     "includedSite",
0379                     "nMaxFilesPerJob",
0380                     "nGBPerJob",
0381                 ]:
0382                     task_params[p_key] = p_value
0383                 elif p_key == "architecture":
0384                     task_params[p_key] = p_value
0385                     if not container_image:
0386                         if task_params[p_key] is None:
0387                             task_params[p_key] = ""
0388                         if "@" not in task_params[p_key] and "basePlatform" in task_params:
0389                             task_params[p_key] = f"{task_params[p_key]}@{task_params['basePlatform']}"
0390                 elif athena_tag:
0391                     if p_key in ["transUses", "transHome"]:
0392                         task_params[p_key] = p_value
0393             # merge job params
0394             task_params["jobParameters"] = merge_job_params(task_params["jobParameters"], parsed_params["jobParameters"])
0395             # outputs
0396             for tmp_item in task_params["jobParameters"]:
0397                 if tmp_item["type"] == "template" and tmp_item["param_type"] == "output":
0398                     if tmp_item["value"].startswith("regex|"):
0399                         self.output_types.append(re.search(r"_([^_]+)/$", tmp_item["dataset"]).group(1))
0400                     else:
0401                         self.output_types.append(re.search(r"}\.(.+)$", tmp_item["value"]).group(1))
0402             # add a dummy output if empty. this is to allow association to downstream steps which is described through outputs
0403             if not self.output_types:
0404                 self.output_types.append("dummy")
0405             # container
0406             if not container_image:
0407                 if "container_name" in task_params:
0408                     del task_params["container_name"]
0409                 if "multiStepExec" in task_params:
0410                     del task_params["multiStepExec"]
0411             if "basePlatform" in task_params:
0412                 del task_params["basePlatform"]
0413             # no build
0414             if use_athena and "--noBuild" in parse_com:
0415                 for tmp_item in task_params["jobParameters"]:
0416                     if tmp_item["type"] == "constant" and tmp_item["value"] == "-l ${LIB}":
0417                         tmp_item["value"] = f"-a {task_params['buildSpec']['archiveName']}"
0418                 del task_params["buildSpec"]
0419             # parent
0420             # if self.parents and len(self.parents) == 1:
0421             #     task_params["noWaitParent"] = True
0422             #     task_params["parentTaskName"] = id_map[list(self.parents)[0]].task_params["taskName"]
0423             # notification
0424             if not self.is_workflow_output:
0425                 task_params["noEmail"] = True
0426             # use instant PQs
0427             if self.type in ["junction", "reana"]:
0428                 task_params["runOnInstant"] = True
0429             # return
0430             return task_params
0431         elif self.type == "phpo":
0432             dict_inputs = self.convert_dict_inputs(skip_suppressed=True)
0433             # extract source and base URL
0434             source_url = task_template["container"]["sourceURL"]
0435             source_name = None
0436             for tmp_item in task_template["container"]["jobParameters"]:
0437                 if tmp_item["type"] == "constant" and tmp_item["value"].startswith("-a "):
0438                     source_name = tmp_item["value"].split()[-1]
0439             # cli params
0440             com = shlex.split(dict_inputs["opt_args"])
0441             if "opt_trainingDS" in dict_inputs and dict_inputs["opt_trainingDS"]:
0442                 if "opt_trainingDsType" not in dict_inputs or not dict_inputs["opt_trainingDsType"]:
0443                     in_ds_suffix = None
0444                     for parent_id in self.parents:
0445                         parent_node = id_map[parent_id]
0446                         if dict_inputs["opt_trainingDS"] in parent_node.convert_set_outputs():
0447                             in_ds_suffix = parent_node.output_types[0]
0448                             break
0449                 else:
0450                     in_ds_suffix = dict_inputs["opt_inDsType"]
0451                 in_ds_str = f"{dict_inputs['opt_trainingDS']}_{in_ds_suffix}/"
0452                 com += ["--trainingDS", in_ds_str]
0453             com += ["--outDS", task_name]
0454             # get task params
0455             task_params = PhpoScript.main(True, com, dry_mode=True)
0456             # change sandbox
0457             new_job_params = []
0458             for tmp_item in task_params["jobParameters"]:
0459                 if tmp_item["type"] == "constant" and tmp_item["value"].startswith("-a "):
0460                     tmp_item["value"] = f"-a {source_name} --sourceURL {source_url}"
0461                 new_job_params.append(tmp_item)
0462             task_params["jobParameters"] = new_job_params
0463             # return
0464             return task_params
0465         elif self.type == "gitlab":
0466             dict_inputs = self.convert_dict_inputs(skip_suppressed=True)
0467             list_in_ds = self.get_input_ds_list(dict_inputs, id_map)
0468             task_params = copy.copy(task_template["container"])
0469             task_params["taskName"] = task_name
0470             task_params["noInput"] = True
0471             task_params["nEventsPerJob"] = 1
0472             task_params["nEvents"] = 1
0473             task_params["processingType"] = re.sub(r"-[^-]+$", "-gitlab", task_params["processingType"])
0474             task_params["useSecrets"] = True
0475             task_params["site"] = dict_inputs["opt_site"]
0476             task_params["cliParams"] = ""
0477             task_params["log"]["container"] = task_params["log"]["dataset"] = f"{task_name}.log/"
0478             # set gitlab parameters
0479             task_params["jobParameters"] = [
0480                 {
0481                     "type": "constant",
0482                     "value": json.dumps(
0483                         {
0484                             "project_api": dict_inputs["opt_api"],
0485                             "project_id": int(dict_inputs["opt_projectID"]),
0486                             "ref": dict_inputs["opt_ref"],
0487                             "trigger_token": dict_inputs["opt_triggerToken"],
0488                             "access_token": dict_inputs["opt_accessToken"],
0489                             "input_datasets": ",".join(list_in_ds),
0490                             "input_location": dict_inputs.get("opt_input_location"),
0491                         }
0492                     ),
0493                 }
0494             ]
0495 
0496             del task_params["container_name"]
0497             del task_params["multiStepExec"]
0498             return task_params
0499         return None
0500 
0501     # get global parameters in the workflow
0502     def get_global_parameters(self):
0503         if self.is_leaf:
0504             root_inputs = self.upper_root_inputs
0505         else:
0506             root_inputs = self.root_inputs
0507         if root_inputs is None:
0508             return None, None
0509         loop_params = {}
0510         workflow_params = {}
0511         for k, v in root_inputs.items():
0512             m = self.get_loop_param_name(k)
0513             if m:
0514                 loop_params[m] = v
0515             else:
0516                 param = k.split("#")[-1]
0517                 workflow_params[param] = v
0518         return loop_params, workflow_params
0519 
0520     # get all sub node IDs
0521     def get_all_sub_node_ids(self, all_ids=None):
0522         if all_ids is None:
0523             all_ids = set()
0524         all_ids.add(self.id)
0525         for sub_node in self.sub_nodes:
0526             all_ids.add(sub_node.id)
0527             if not sub_node.is_leaf:
0528                 sub_node.get_all_sub_node_ids(all_ids)
0529         return all_ids
0530 
0531     # get loop param name
0532     def get_loop_param_name(self, k):
0533         param = k.split("#")[-1]
0534         m = re.search(r"^param_(.+)", param)
0535         if m:
0536             return m.group(1)
0537         return None
0538 
    # get input dataset list
0540     def get_input_ds_list(self, dict_inputs, id_map):
0541         if "opt_inDS" not in dict_inputs:
0542             return []
0543         if isinstance(dict_inputs["opt_inDS"], list):
0544             is_list_in_ds = True
0545         else:
0546             is_list_in_ds = False
0547         if "opt_inDsType" not in dict_inputs or not dict_inputs["opt_inDsType"]:
0548             if is_list_in_ds:
0549                 in_ds_suffix = []
0550                 in_ds_list = dict_inputs["opt_inDS"]
0551             else:
0552                 in_ds_suffix = None
0553                 in_ds_list = [dict_inputs["opt_inDS"]]
0554             for tmp_in_ds in in_ds_list:
0555                 for parent_id in self.parents:
0556                     parent_node = id_map[parent_id]
0557                     if tmp_in_ds in parent_node.convert_set_outputs():
0558                         if is_list_in_ds:
0559                             in_ds_suffix.append(parent_node.output_types[0])
0560                         else:
0561                             in_ds_suffix = parent_node.output_types[0]
0562                         break
0563         else:
0564             in_ds_suffix = dict_inputs["opt_inDsType"]
0565             if "*" in in_ds_suffix:
0566                 in_ds_suffix = in_ds_suffix.replace("*", "XYZ") + ".tgz"
0567         if is_list_in_ds:
0568             list_in_ds = [f"{s1}_{s2}/" if s2 else s1 for s1, s2 in zip(dict_inputs["opt_inDS"], in_ds_suffix)]
0569         else:
0570             list_in_ds = [f"{dict_inputs['opt_inDS']}_{in_ds_suffix}/" if in_ds_suffix else dict_inputs["opt_inDS"]]
0571         return list_in_ds
0572 
0573 
0574 # dump nodes
def dump_nodes(node_list, dump_str=None, only_leaves=False):
    """Render the nodes, descending into sub nodes, as one string.

    A leaf node contributes its string form, followed by its task parameters as
    indented, key-sorted JSON and a blank line when it has any. A non-leaf node
    contributes its string form plus a newline (unless ``only_leaves`` is set),
    followed by the dump of its sub nodes.

    :param node_list: iterable of nodes
    :param dump_str: text to append to; a single newline is used when None
    :param only_leaves: when true, non-leaf nodes themselves are not printed
    :return: the accumulated text
    """
    pieces = ["\n" if dump_str is None else dump_str]
    for node in node_list:
        if not node.is_leaf:
            if not only_leaves:
                pieces.append(f"{node}\n")
            pieces.append(dump_nodes(node.sub_nodes, "", only_leaves))
            continue
        pieces.append(f"{node}")
        if node.task_params is not None:
            pieces.append(json.dumps(node.task_params, indent=4, sort_keys=True))
            pieces.append("\n\n")
    return "".join(pieces)
0589 
0590 
0591 # get id map
def get_node_id_map(node_list, id_map=None):
    """Map node IDs to nodes for the given nodes and everything below them.

    :param node_list: iterable of nodes
    :param id_map: optional dict to fill; a new dict is created when None
    :return: the dict of node ID to node
    """
    mapping = {} if id_map is None else id_map
    for node in node_list:
        mapping[node.id] = node
        children = node.sub_nodes
        if children:
            mapping = get_node_id_map(children, mapping)
    return mapping
0600 
0601 
0602 # get all parents
def get_all_parents(node_list, all_parents=None):
    """Collect the parent IDs of the given nodes and of everything below them.

    :param node_list: iterable of nodes
    :param all_parents: optional set to add to; a new set is created when None
    :return: the set of parent IDs
    """
    collected = set() if all_parents is None else all_parents
    for node in node_list:
        collected |= node.parents
        children = node.sub_nodes
        if children:
            collected = get_all_parents(children, collected)
    return collected
0611 
0612 
0613 # set workflow outputs
def set_workflow_outputs(node_list, all_parents=None):
    """Flag every leaf node that is nobody's parent as a workflow output.

    Sub nodes are processed too. Nodes that do not qualify are left as they are;
    the flag is never cleared here.

    :param node_list: iterable of nodes
    :param all_parents: set of all parent IDs; computed with ``get_all_parents`` when None
    """
    parent_ids = get_all_parents(node_list) if all_parents is None else all_parents
    for node in node_list:
        if node.is_leaf and node.id not in parent_ids:
            node.is_workflow_output = True
        children = node.sub_nodes
        if children:
            set_workflow_outputs(children, parent_ids)
0622 
0623 
0624 # NOTE: condition features are not yet implemented
0625 # TODO: implement condition support
0626 # def convert_params_in_condition_to_parent_ids(condition_item, input_data, id_map):
0627 #     for item in ["left", "right"]:
0628 #         param = getattr(condition_item, item)
0629 #         if isinstance(param, str):
0630 #             m = re.search(r"^[^\[]+\[(\d+)\]", param)
0631 #             if m:
0632 #                 param = param.split("[")[0]
0633 #                 idx = int(m.group(1))
0634 #             else:
0635 #                 idx = None
0636 #             isOK = False
0637 #             for tmp_name, tmp_data in input_data.items():
0638 #                 if param == tmp_name.split("/")[-1]:
0639 #                     isOK = True
0640 #                     if isinstance(tmp_data["parent_id"], list):
0641 #                         if idx is not None:
0642 #                             if idx < 0 or idx >= len(tmp_data["parent_id"]):
0643 #                                 raise IndexError(f"index {idx} is out of bounds for parameter {param} with {len(tmp_data['parent_id'])} parents")
0644 #                             parent_id = tmp_data["parent_id"][idx]
0645 #                             if parent_id not in id_map:
0646 #                                 raise ReferenceError(f"unresolved parent_id {parent_id} for parameter {param}[{idx}]")
0647 #                             setattr(condition_item, item, id_map[parent_id])
0648 #                         else:
0649 #                             resolved_parent_ids = set()
0650 #                             for parent_id in tmp_data["parent_id"]:
0651 #                                 if parent_id not in id_map:
0652 #                                     raise ReferenceError(f"unresolved parent_id {parent_id} for parameter {param}")
0653 #                                 resolved_parent_ids |= id_map[parent_id]
0654 #                             setattr(condition_item, item, list(resolved_parent_ids))
0655 #                     else:
0656 #                         if tmp_data["parent_id"] not in id_map:
0657 #                             raise ReferenceError(f"unresolved parent_id {tmp_data['parent_id']} for parameter {param}")
0658 #                         setattr(condition_item, item, id_map[tmp_data["parent_id"]])
0659 #                     break
0660 #             if not isOK:
0661 #                 raise ReferenceError(f"unresolved parameter {param} in the condition string")
0662 
0663 
0664 # resolve nodes
def resolve_nodes(node_list, root_inputs, data, serial_id, parent_ids, out_ds_name, log_stream):
    """Resolve inputs, scatter, parents, real IDs, and output names of the nodes.

    Nodes are processed in the given order, so the parents of a node must appear
    before it in node_list. Non-leaf nodes are resolved recursively through their sub-nodes.

    :param node_list: nodes carrying temporary IDs
    :param root_inputs: dict of root inputs; values are overwritten in place with the ones found in data
    :param data: dict of actual input values, looked up with the part of the root input name after the last '#'
    :param serial_id: next real node ID to assign
    :param parent_ids: set of real IDs added to the parents of every head node
    :param out_ds_name: prefix of output dataset names
    :param log_stream: passed through to recursive calls; not used otherwise
    :return: tuple of (next free serial ID, list of tail nodes, list of all nodes at this level)
    """
    # overwrite root inputs with the supplied values
    for root_key in root_inputs:
        short_key = root_key.split("#")[-1]
        if short_key in data:
            root_inputs[root_key] = data[short_key]
    # temporary node ID -> set of real node IDs
    real_ids_by_tmp_id = {}
    # temporary node ID -> list of resolved nodes
    resolved_by_tmp_id = {}
    # object identity -> temporary node ID used as key in resolved_by_tmp_id
    tmp_id_by_identity = {}
    all_nodes = []
    for node in node_list:
        # resolve input
        for in_name, in_spec in node.inputs.items():
            raw_source = in_spec["source"]
            if not raw_source:
                continue
            if isinstance(raw_source, list):
                sources = raw_source
                # work on a copy to avoid mutating the list stored in node.inputs when padding
                linked_ids = list(in_spec["parent_id"]) if "parent_id" in in_spec else []
                linked_ids += [None] * (len(sources) - len(linked_ids))
            else:
                sources = [raw_source]
                linked_ids = [in_spec["parent_id"]] if "parent_id" in in_spec else [None]
            for src, linked_id in zip(sources, linked_ids):
                # check root input
                if src in root_inputs:
                    node.is_head = True
                    node.set_input_value(in_name, src, root_inputs[src])
                    continue
                # check parent output: take the first resolved parent node providing the source
                provider = None
                for parent_tmp_id in node.parents:
                    provider = next((r for r in resolved_by_tmp_id[parent_tmp_id] if src in r.outputs), None)
                    if provider is not None:
                        break
                if provider is not None:
                    node.set_input_value(in_name, src, provider.outputs[src]["value"])
                    continue
                # check resolved parent outputs: fall back to the first output of each node resolved for the parent
                if linked_id is not None:
                    values = [list(r.outputs.values())[0]["value"] for r in resolved_by_tmp_id[linked_id]]
                    if len(values) == 1:
                        values = values[0]
                    node.set_input_value(in_name, src, values)
        # scatter
        if node.scatter:
            # resolve scattered parameters; values of all scattered inputs are combined position by position
            combos = None
            sc_nodes = []
            for param in node.scatter:
                param_values = node.inputs[param]["value"]
                if combos is None:
                    combos = [{param: value} for value in param_values]
                else:
                    for combo, value in zip(combos, param_values):
                        combo.update({param: value})
            # one deep copy of the node per combination
            for position, combo in enumerate(combos):
                sc_node = copy.deepcopy(node)
                for param, value in combo.items():
                    sc_node.inputs[param]["value"] = value
                for sub_node in sc_node.sub_nodes:
                    sub_node.scatter_index = position
                    sub_node.upper_root_inputs = sc_node.root_inputs
                sc_nodes.append(sc_node)
        else:
            sc_nodes = [node]
        # loop over scattered nodes
        for sc_node in sc_nodes:
            tmp_id = sc_node.id
            all_nodes.append(sc_node)
            tmp_id_by_identity[id(sc_node)] = tmp_id
            # set real node ID
            resolved_here = resolved_by_tmp_id.setdefault(tmp_id, [])
            real_ids_here = real_ids_by_tmp_id.setdefault(tmp_id, set())
            # resolve parents: replace temporary parent IDs with real ones
            sc_node.parents = set().union(*(real_ids_by_tmp_id[parent_tmp_id] for parent_tmp_id in sc_node.parents))
            if sc_node.is_head:
                sc_node.parents |= parent_ids
            if sc_node.is_leaf:
                resolved_here.append(sc_node)
                real_ids_here.add(serial_id)
            else:
                serial_id, sub_tails, sc_node.sub_nodes = resolve_nodes(
                    sc_node.sub_nodes,
                    sc_node.root_inputs,
                    sc_node.convert_dict_inputs(),
                    serial_id,
                    sc_node.parents,
                    out_ds_name,
                    log_stream,
                )
                resolved_here.extend(sub_tails)
                real_ids_here.update({tail.id for tail in sub_tails})
            # a non-leaf node gets its real ID after all of its sub-nodes
            sc_node.id = serial_id
            serial_id += 1
            # convert parameters to parent IDs in conditions
            # TODO: condition features not yet implemented
            if sc_node.condition:
                pass
                # convert_params_in_condition_to_parent_ids(sc_node.condition, sc_node.inputs, real_ids_by_tmp_id)
            # resolve outputs
            if sc_node.is_leaf:
                for out_spec in sc_node.outputs.values():
                    out_spec["value"] = f"{out_ds_name}_{sc_node.id:03d}_{sc_node.name}"
                    # add loop count for nodes in a loop
                    if sc_node.in_loop:
                        out_spec["value"] += ".___idds___num_run___"
    # return tails
    tail_nodes = []
    for node in all_nodes:
        if node.is_tail:
            tail_nodes.append(node)
        else:
            # node.id is already the real ID here, so go back to the temporary one for the lookup
            tail_nodes.extend(resolved_by_tmp_id[tmp_id_by_identity.get(id(node), node.id)])
    return serial_id, tail_nodes, all_nodes
0798 
0799 
0800 # parse workflow data for native YAML workflow
def parse_workflow_data(data, log_stream):
    """Convert a native YAML workflow description into dependency-ordered nodes.

    Every entry under "steps" becomes a Node with a temporary serial ID starting at 1.
    A step is a tail node when a workflow output refers to it through "from".
    "inDS" and "secondaryDSs" are treated as sources which may refer to the output
    of another step as "<step name>/<output>"; the referred step becomes a parent.
    For a list of sources, "parent_id" is a list aligned position by position with
    the sources, holding None for entries that do not refer to a known step.

    :param data: parsed workflow description, either nested under a "workflow" key or flat
    :param log_stream: not used by this function
    :return: tuple of (list of nodes sorted so that parents precede their children, root inputs)
    :raises ValueError: if the dependencies among steps are circular
    """
    # Handle both nested (workflow:{...}) and flat ({...}) structures
    workflow_data = data.get("workflow", data)

    # extract root inputs and outputs; a section given without content is parsed as None
    root_inputs = workflow_data.get("inputs", {})
    if root_inputs is None:
        root_inputs = {}
    root_outputs = workflow_data.get("outputs") or {}
    tail_node_names = {output_spec["from"].split("/")[0] for output_spec in root_outputs.values() if isinstance(output_spec, dict) and "from" in output_spec}

    # parse steps
    steps = workflow_data.get("steps") or {}
    node_list = []
    node_name_map = {}
    serial_id = 0

    # step options converted to node inputs, and the subset of them which are sources rather than defaults
    option_keys = ("inDS", "args", "exec", "containerImage", "useAthenaPackages", "secondaryDSs", "secondaryDsTypes")
    source_keys = ("inDS", "secondaryDSs")

    # first pass: create all nodes
    for step_name, step_spec in steps.items():
        serial_id += 1
        step_type = step_spec.get("type", "prun")
        is_leaf = step_type in ["prun", "phpo", "junction", "reana", "gitlab"]
        node = Node(serial_id, step_type, None, is_leaf, step_name)
        node_name_map[step_name] = node

        # parse inputs
        inputs = {}
        for key in option_keys:
            if key not in step_spec:
                continue
            value = step_spec.get(key)
            if key in source_keys:
                inputs[f"{step_name}/opt_{key}"] = {"default": None, "source": value}
            else:
                inputs[f"{step_name}/opt_{key}"] = {"default": value, "source": None}

        node.inputs = inputs
        node.outputs = {f"{step_name}/outDS": {}}
        node.is_tail = step_name in tail_node_names
        node_list.append(node)

    # second pass: resolve parent relationships; note that the parent_id is not used in core workflow execution but only for parameter resolution
    for node in node_list:
        for input_data in node.inputs.values():
            source = input_data.get("source")
            if not source:
                continue

            # resolve single source
            if isinstance(source, str):
                if source.startswith("{") and source.endswith("}"):
                    # strip the enclosing braces
                    input_data["source"] = source[1:-1]
                elif "/" in source:
                    parent = node_name_map.get(source.split("/")[0])
                    if parent is not None:
                        node.add_parent(parent.id)
                        input_data["parent_id"] = parent.id
            # resolve list of sources
            elif isinstance(source, list):
                parent_ids = []
                for src in source:
                    parent = None
                    if isinstance(src, str) and "/" in src:
                        parent = node_name_map.get(src.split("/")[0])
                    if parent is None:
                        # placeholder to keep parent IDs aligned with sources since they are consumed pairwise
                        parent_ids.append(None)
                    else:
                        node.add_parent(parent.id)
                        parent_ids.append(parent.id)
                if any(parent_id is not None for parent_id in parent_ids):
                    input_data["parent_id"] = parent_ids

    # topological sort
    nodes_by_id = {n.id: n for n in node_list}
    visited = set()
    in_progress = set()
    sorted_nodes = []

    def visit(n):
        if n.id in visited:
            return
        if n.id in in_progress:
            # reached again while its own parents are still being visited
            raise ValueError(f"circular dependency detected at step '{n.name}'")
        in_progress.add(n.id)
        for parent_id in n.parents:
            parent = nodes_by_id.get(parent_id)
            if parent is not None:
                visit(parent)
        in_progress.discard(n.id)
        visited.add(n.id)
        sorted_nodes.append(n)

    for node in node_list:
        visit(node)

    return sorted_nodes, root_inputs