File indexing completed on 2026-05-15 08:35:37
0001 import copy
0002 import json
0003 import re
0004 import shlex
0005 import tempfile
0006
0007 from pandaclient import PhpoScript, PrunScript
0008
0009
0010
def merge_job_params(base_params, io_params):
    """Merge two jobParameters lists into a single list.

    Keeps everything in *base_params* except the execution section (the
    constant starting with "-p " up to and including the first plain constant
    without a "padding" key that follows it), then appends the entries of
    *io_params* located after the "__delimiter__" constant, dropping any
    "-a " source-archive constants.
    """
    merged = []

    # copy base params, skipping the exec section
    in_exec_section = False
    exec_section_closed = False
    for item in base_params:
        is_constant = item["type"] == "constant"
        if is_constant and item["value"].startswith("-p "):
            # "-p ..." opens the exec section; the marker itself is dropped
            in_exec_section = True
            continue
        if in_exec_section and not exec_section_closed:
            # the first plain constant (no padding) closes the section;
            # everything inside it, including the closer, is dropped
            if is_constant and "padding" not in item:
                exec_section_closed = True
            continue
        merged.append(item)

    # append I/O params that come after the delimiter, minus "-a " constants
    past_delimiter = False
    for item in io_params:
        is_constant = item["type"] == "constant"
        if is_constant and item["value"] == "__delimiter__":
            past_delimiter = True
        elif past_delimiter and not (is_constant and item["value"].startswith("-a ")):
            merged.append(item)
    return merged
0042
0043
0044
class Node(object):
    """One vertex of a parsed workflow graph.

    A node is either a leaf (an executable step of type prun/phpo/junction/
    reana/gitlab that is eventually turned into a task parameter dict) or a
    composite node that carries ``sub_nodes`` (e.g. a sub-workflow).
    """

    def __init__(self, id, node_type, data, is_leaf, name):
        self.id = id
        # step type: "prun", "phpo", "junction", "reana", "gitlab", "workflow", ...
        self.type = node_type
        self.data = data
        # True for executable steps; False for composite nodes with sub_nodes
        self.is_leaf = is_leaf
        # True when the node feeds a workflow-level output
        self.is_tail = False
        # True when the node consumes a workflow-level (root) input
        self.is_head = False
        # {key: {"source": ..., "default": ..., "value": ...}}
        self.inputs = {}
        # {key: {"value": output dataset name}}
        self.outputs = {}
        # dataset-name suffixes of produced outputs (filled by make_task_params)
        self.output_types = []
        # input keys to scatter over, if any
        self.scatter = None
        # IDs of upstream nodes
        self.parents = set()
        self.name = name
        self.sub_nodes = set()
        # root-level inputs of the sub-workflow (composite nodes only)
        self.root_inputs = None
        # task parameter dict built by make_task_params (leaves only)
        self.task_params = None
        self.condition = None
        # True when no downstream node consumes this node's outputs
        self.is_workflow_output = False
        self.loop = False
        # True when the node runs inside a loop
        self.in_loop = False
        # root inputs of the enclosing workflow, propagated down to leaves
        self.upper_root_inputs = None

    def add_parent(self, id):
        """Register the node with the given ID as an upstream dependency."""
        self.parents.add(id)


    def set_input_value(self, key, src_key, src_value):
        """Resolve input *key* by substituting *src_value* for *src_key*.

        When the input's source is a list, only the matching entry is
        replaced; otherwise the whole value is replaced.
        """

        # when both the target and the source are loop parameters, store
        # per-index references instead of the concrete list so each loop
        # iteration can pick its own element
        if isinstance(src_value, list):
            src_loop_param_name = self.get_loop_param_name(src_key)
            loop_params = self.get_loop_param_name(key.split("/")[-1]) is not None and src_loop_param_name is not None
            if loop_params:
                src_value = [{"src": src_loop_param_name, "idx": i} for i in range(len(src_value))]

        if isinstance(self.inputs[key]["source"], list):
            self.inputs[key].setdefault("value", copy.copy(self.inputs[key]["source"]))
            tmp_list = []
            for k in self.inputs[key]["value"]:
                if k == src_key:
                    tmp_list.append(src_value)
                else:
                    tmp_list.append(k)
            self.inputs[key]["value"] = tmp_list
        else:
            self.inputs[key]["value"] = src_value


    def convert_dict_inputs(self, skip_suppressed=False):
        """Return inputs as {short name: value}, preferring "value" over "default".

        Raises ReferenceError when an input has neither a value nor a default.
        """
        data = {}
        for k, v in self.inputs.items():
            if skip_suppressed and "suppressed" in v and v["suppressed"]:
                continue
            # keys look like "<step name>/<param name>"; keep the param name only
            y_name = k.split("/")[-1]
            if "value" in v:
                data[y_name] = v["value"]
            elif "default" in v:
                data[y_name] = v["default"]
            else:
                raise ReferenceError(f"{k} is not resolved")
        return data


    def convert_set_outputs(self):
        """Return the set of resolved output dataset names."""
        data = set()
        for k, v in self.outputs.items():
            if "value" in v:
                data.add(v["value"])
        return data


    def verify(self):
        """Check that the node is fully resolved and uses only known parameters.

        Returns (True, "") on success or (False, error message) on failure.
        """
        if self.is_leaf:
            dict_inputs = self.convert_dict_inputs(True)
            # all inputs must be resolved
            for k, v in dict_inputs.items():
                if v is None:
                    return False, f"{k} is unresolved"
            # no leftover %{...DSn} / %{...DS*} placeholders in exec/args
            for k in ["opt_exec", "opt_args"]:
                test_str = dict_inputs.get(k)
                if test_str:
                    m = re.search(r"%{[A-Z]*DS(\d+|\*)}", test_str)
                    if m:
                        return False, f"{m.group(0)} is unresolved in {k}"
        # per-type whitelist of accepted input parameters
        if self.type == "prun":
            for k in dict_inputs:
                if k not in [
                    "opt_inDS",
                    "opt_inDsType",
                    "opt_secondaryDSs",
                    "opt_secondaryDsTypes",
                    "opt_args",
                    "opt_exec",
                    "opt_useAthenaPackages",
                    "opt_containerImage",
                ]:
                    return False, f"unknown input parameter {k} for {self.type}"
        elif self.type in ["junction", "reana"]:
            for k in dict_inputs:
                if k not in [
                    "opt_inDS",
                    "opt_inDsType",
                    "opt_args",
                    "opt_exec",
                    "opt_containerImage",
                ]:
                    return False, f"unknown input parameter {k} for {self.type}"
        elif self.type == "phpo":
            for k in dict_inputs:
                if k not in ["opt_trainingDS", "opt_trainingDsType", "opt_args"]:
                    return False, f"unknown input parameter {k} for {self.type}"
        elif self.type == "gitlab":
            for k in dict_inputs:
                if k not in [
                    "opt_inDS",
                    "opt_args",
                    "opt_api",
                    "opt_projectID",
                    "opt_ref",
                    "opt_triggerToken",
                    "opt_accessToken",
                    "opt_site",
                    "opt_input_location",
                ]:
                    return False, f"unknown input parameter {k} for {self.type}"
        elif self.type == "workflow":
            # "i" is reserved for the implicit loop-iteration counter
            reserved_params = ["i"]
            loop_global, workflow_global = self.get_global_parameters()
            if loop_global:
                for k in reserved_params:
                    if k in loop_global:
                        return (
                            False,
                            f"parameter {k} cannot be used since it is reserved by the system",
                        )
        return True, ""


    def __str__(self):
        """Multi-line, human-readable description of the node."""
        outstr = f"ID:{self.id} Name:{self.name} Type:{self.type}\n"
        outstr += f"  Parent:{','.join([str(p) for p in self.parents])}\n"
        outstr += "  Input:\n"
        for k, v in self.convert_dict_inputs().items():
            outstr += f"     {k}: {v}\n"
        outstr += "  Output:\n"
        for k, v in self.outputs.items():
            if "value" in v:
                v = v["value"]
            else:
                v = "NA"
            outstr += f"     {v}\n"
        return outstr


    def short_desc(self):
        """One-line description of the node."""
        return f"ID:{self.id} Name:{self.name} Type:{self.type}"


    def resolve_params(self, task_template=None, id_map=None, workflow=None):
        """Resolve secondary-dataset placeholders and build task params.

        For executable node types, %{SECDSn} / %{SECDS*} placeholders in
        opt_exec/opt_args are replaced with fully qualified dataset names.
        For leaves with a task template, self.task_params is produced; the
        resolution then recurses into sub-nodes.
        """
        if self.type in ["prun", "junction", "reana"]:
            dict_inputs = self.convert_dict_inputs()
            if "opt_secondaryDSs" in dict_inputs:
                # look for secondaryDsTypes if missing, taking each type from
                # the parent node that produced the dataset
                if "opt_secondaryDsTypes" not in dict_inputs:
                    dict_inputs["opt_secondaryDsTypes"] = []
                    for ds_name in dict_inputs["opt_secondaryDSs"]:
                        added = False
                        for pid in self.parents:
                            parent_node = id_map[pid]
                            if ds_name in parent_node.convert_set_outputs():
                                dict_inputs["opt_secondaryDsTypes"].append(parent_node.output_types[0])
                                added = True
                                break
                        if not added:
                            # use None when no parent produces the dataset
                            dict_inputs["opt_secondaryDsTypes"].append(None)
                # resolve %{SECDSn} placeholders
                idx = 1
                list_sec_ds = []
                for ds_name, ds_type in zip(dict_inputs["opt_secondaryDSs"], dict_inputs["opt_secondaryDsTypes"]):
                    if ds_type and "*" in ds_type:
                        ds_type = ds_type.replace("*", "XYZ")
                        ds_type += ".tgz"
                    src = f"%{{SECDS{idx}}}"
                    if ds_type:
                        dst = f"{ds_name}_{ds_type}/"
                    else:
                        dst = f"{ds_name}/"
                    dict_inputs["opt_exec"] = re.sub(src, dst, dict_inputs["opt_exec"])
                    dict_inputs["opt_args"] = re.sub(src, dst, dict_inputs["opt_args"])
                    idx += 1
                    list_sec_ds.append(src)
                # resolve %{SECDS*} with the comma-separated full list
                if list_sec_ds:
                    src = r"%{SECDS\*}"
                    if "opt_exec" in dict_inputs:
                        dict_inputs["opt_exec"] = re.sub(src, ",".join(list_sec_ds), dict_inputs["opt_exec"])
                    if "opt_args" in dict_inputs:
                        dict_inputs["opt_args"] = re.sub(src, ",".join(list_sec_ds), dict_inputs["opt_args"])
                # write the resolved strings back into the inputs
                for k, v in self.inputs.items():
                    if k.endswith("opt_exec"):
                        v["value"] = dict_inputs["opt_exec"]
                    elif k.endswith("opt_args"):
                        v["value"] = dict_inputs["opt_args"]
                    # secondary datasets must be complete before this node runs
                    if k.endswith("opt_secondaryDSs"):
                        v.setdefault("requirements", {})["requires_complete"] = True
        if self.is_leaf and task_template:
            self.task_params = self.make_task_params(task_template, id_map, workflow)
        # recurse into sub-nodes (side-effect only; the list result is discarded)
        [n.resolve_params(task_template, id_map, self) for n in self.sub_nodes]


    def make_task_params(self, task_template, id_map, workflow_node):
        """Build the task parameter dict for this leaf node.

        Uses dry-mode runs of PrunScript/PhpoScript to parse the assembled
        command line; returns the task parameter dict, or None for unknown
        node types. Also fills self.output_types as a side effect.
        """
        # task name comes from the (single) resolved output
        for k, v in self.outputs.items():
            task_name = v["value"]
            break
        if self.type in ["prun", "junction", "reana"]:
            dict_inputs = self.convert_dict_inputs(skip_suppressed=True)
            # check athena and container usage
            use_athena = False
            if "opt_useAthenaPackages" in dict_inputs and dict_inputs["opt_useAthenaPackages"] and self.type != "reana":
                use_athena = True
            container_image = None
            if "opt_containerImage" in dict_inputs and dict_inputs["opt_containerImage"]:
                container_image = dict_inputs["opt_containerImage"]
            if use_athena:
                task_params = copy.deepcopy(task_template["athena"])
            else:
                task_params = copy.deepcopy(task_template["container"])
            task_params["taskName"] = task_name
            # assemble the prun command line
            com = ["prun"]
            if self.type == "junction":
                # junctions must produce results.json among their outputs
                if "opt_args" not in dict_inputs:
                    dict_inputs["opt_args"] = ""
                results_json = "results.json"
                if "--outputs" not in dict_inputs["opt_args"]:
                    dict_inputs["opt_args"] += f" --outputs {results_json}"
                else:
                    m = re.search("(--outputs)( +|=)([^ ]+)", dict_inputs["opt_args"])
                    if results_json not in m.group(3):
                        tmp_dst = m.group(1) + "=" + m.group(3) + "," + results_json
                        dict_inputs["opt_args"] = re.sub(m.group(0), tmp_dst, dict_inputs["opt_args"])
            com += shlex.split(dict_inputs["opt_args"])
            if "opt_inDS" in dict_inputs and dict_inputs["opt_inDS"]:
                list_in_ds = self.get_input_ds_list(dict_inputs, id_map)
                if self.type not in ["reana"]:
                    in_ds_str = ",".join(list_in_ds)
                    com += ["--inDS", in_ds_str, "--notExpandInDS", "--notExpandSecDSs"]
                if self.type in ["junction"]:
                    com += ["--forceStaged", "--forceStagedSecondary"]
                if self.type in ["prun", "junction", "reana"]:
                    # resolve %{DSn} placeholders with the n-th input dataset
                    for idx, dst in enumerate(list_in_ds):
                        src = f"%{{DS{idx + 1}}}"
                        if "opt_exec" in dict_inputs:
                            dict_inputs["opt_exec"] = re.sub(src, dst, dict_inputs["opt_exec"])
                        if "opt_args" in dict_inputs:
                            dict_inputs["opt_args"] = re.sub(src, dst, dict_inputs["opt_args"])
                    if list_in_ds:
                        # resolve %{DS*} with the comma-separated full list
                        src = r"%{DS\*}"
                        if "opt_exec" in dict_inputs:
                            dict_inputs["opt_exec"] = re.sub(src, ",".join(list_in_ds), dict_inputs["opt_exec"])
                        if "opt_args" in dict_inputs:
                            dict_inputs["opt_args"] = re.sub(src, ",".join(list_in_ds), dict_inputs["opt_args"])
                    # write the resolved strings back into the inputs
                    for k, v in self.inputs.items():
                        if k.endswith("opt_exec"):
                            v["value"] = dict_inputs["opt_exec"]
                        elif k.endswith("opt_args"):
                            v["value"] = dict_inputs["opt_args"]
            # substitute workflow-global and loop parameters
            if workflow_node:
                tmp_global, tmp_workflow_global = workflow_node.get_global_parameters()
                src_dst_list = []
                # loop parameters become iDDS user placeholders
                if tmp_global:
                    for k in tmp_global:
                        tmp_src = f"%{{{k}}}"
                        tmp_dst = f"___idds___user_{k}___"
                        src_dst_list.append((tmp_src, tmp_dst))
                # workflow-global parameters are replaced with their values
                if tmp_workflow_global:
                    for k, v in tmp_workflow_global.items():
                        tmp_src = f"%{{{k}}}"
                        tmp_dst = f"{v}"
                        src_dst_list.append((tmp_src, tmp_dst))
                # %{i} is the reserved loop-iteration counter
                tmp_src = "%{i}"
                tmp_dst = "___idds___num_run___"
                src_dst_list.append((tmp_src, tmp_dst))

                for tmp_src, tmp_dst in src_dst_list:
                    if "opt_exec" in dict_inputs:
                        dict_inputs["opt_exec"] = re.sub(tmp_src, tmp_dst, dict_inputs["opt_exec"])
                    if "opt_args" in dict_inputs:
                        dict_inputs["opt_args"] = re.sub(tmp_src, tmp_dst, dict_inputs["opt_args"])
            com += ["--exec", dict_inputs["opt_exec"]]
            com += ["--outDS", task_name]
            if container_image:
                com += ["--containerImage", container_image]
                parse_com = copy.copy(com[1:])
            else:
                # pass a dummy container image only to the dry-mode parser
                parse_com = copy.copy(com[1:])
                parse_com += ["--containerImage", None]

            parse_com += ["--tmpDir", tempfile.gettempdir()]
            athena_tag = False
            if use_athena:
                com += ["--useAthenaPackages"]
                athena_tag = "--athenaTag" in com
            # take the cmtConfig from the template architecture when needed
            if athena_tag and "--cmtConfig" not in parse_com:
                parse_com += [
                    "--cmtConfig",
                    task_params["architecture"].split("@")[0],
                ]
            # parse the command line in dry mode to get task parameters
            parsed_params = PrunScript.main(True, parse_com, dry_mode=True)
            task_params["cliParams"] = " ".join(shlex.quote(x) for x in com)
            # merge the parsed parameters into the template
            for p_key, p_value in parsed_params.items():
                if p_key in ["buildSpec"]:
                    continue
                if p_key not in task_params or p_key in [
                    "log",
                    "container_name",
                    "multiStepExec",
                    "site",
                    "excludedSite",
                    "includedSite",
                    "nMaxFilesPerJob",
                    "nGBPerJob",
                ]:
                    task_params[p_key] = p_value
                elif p_key == "architecture":
                    task_params[p_key] = p_value
                    if not container_image:
                        if task_params[p_key] is None:
                            task_params[p_key] = ""
                        if "@" not in task_params[p_key] and "basePlatform" in task_params:
                            # append the base platform when only the cmtConfig part is given
                            task_params[p_key] = f"{task_params[p_key]}@{task_params['basePlatform']}"
                elif athena_tag:
                    if p_key in ["transUses", "transHome"]:
                        task_params[p_key] = p_value
            # merge job params from the template and the parsed command line
            task_params["jobParameters"] = merge_job_params(task_params["jobParameters"], parsed_params["jobParameters"])
            # extract output dataset suffixes from the output templates
            for tmp_item in task_params["jobParameters"]:
                if tmp_item["type"] == "template" and tmp_item["param_type"] == "output":
                    if tmp_item["value"].startswith("regex|"):
                        self.output_types.append(re.search(r"_([^_]+)/$", tmp_item["dataset"]).group(1))
                    else:
                        self.output_types.append(re.search(r"}\.(.+)$", tmp_item["value"]).group(1))
            # fall back to a dummy type when no output template exists
            if not self.output_types:
                self.output_types.append("dummy")
            # container-only attributes are dropped when no image is used
            if not container_image:
                if "container_name" in task_params:
                    del task_params["container_name"]
                if "multiStepExec" in task_params:
                    del task_params["multiStepExec"]
                if "basePlatform" in task_params:
                    del task_params["basePlatform"]
            # without a build step, use the source archive instead of ${LIB}
            if use_athena and "--noBuild" in parse_com:
                for tmp_item in task_params["jobParameters"]:
                    if tmp_item["type"] == "constant" and tmp_item["value"] == "-l ${LIB}":
                        tmp_item["value"] = f"-a {task_params['buildSpec']['archiveName']}"
                del task_params["buildSpec"]

            # notify the user only for final workflow outputs
            if not self.is_workflow_output:
                task_params["noEmail"] = True
            # junction/reana tasks start immediately
            if self.type in ["junction", "reana"]:
                task_params["runOnInstant"] = True

            return task_params
        elif self.type == "phpo":
            dict_inputs = self.convert_dict_inputs(skip_suppressed=True)
            # remember the source archive and URL from the template
            source_url = task_template["container"]["sourceURL"]
            source_name = None
            for tmp_item in task_template["container"]["jobParameters"]:
                if tmp_item["type"] == "constant" and tmp_item["value"].startswith("-a "):
                    source_name = tmp_item["value"].split()[-1]
            # assemble the phpo command line
            com = shlex.split(dict_inputs["opt_args"])
            if "opt_trainingDS" in dict_inputs and dict_inputs["opt_trainingDS"]:
                if "opt_trainingDsType" not in dict_inputs or not dict_inputs["opt_trainingDsType"]:
                    # take the suffix from the parent that produced the dataset
                    in_ds_suffix = None
                    for parent_id in self.parents:
                        parent_node = id_map[parent_id]
                        if dict_inputs["opt_trainingDS"] in parent_node.convert_set_outputs():
                            in_ds_suffix = parent_node.output_types[0]
                            break
                else:
                    in_ds_suffix = dict_inputs["opt_inDsType"]
                in_ds_str = f"{dict_inputs['opt_trainingDS']}_{in_ds_suffix}/"
                com += ["--trainingDS", in_ds_str]
            com += ["--outDS", task_name]
            # parse the command line in dry mode
            task_params = PhpoScript.main(True, com, dry_mode=True)
            # rewrite the "-a" constant to point at the template's source archive
            new_job_params = []
            for tmp_item in task_params["jobParameters"]:
                if tmp_item["type"] == "constant" and tmp_item["value"].startswith("-a "):
                    tmp_item["value"] = f"-a {source_name} --sourceURL {source_url}"
                new_job_params.append(tmp_item)
            task_params["jobParameters"] = new_job_params

            return task_params
        elif self.type == "gitlab":
            dict_inputs = self.convert_dict_inputs(skip_suppressed=True)
            list_in_ds = self.get_input_ds_list(dict_inputs, id_map)
            task_params = copy.copy(task_template["container"])
            task_params["taskName"] = task_name
            task_params["noInput"] = True
            task_params["nEventsPerJob"] = 1
            task_params["nEvents"] = 1
            task_params["processingType"] = re.sub(r"-[^-]+$", "-gitlab", task_params["processingType"])
            task_params["useSecrets"] = True
            task_params["site"] = dict_inputs["opt_site"]
            task_params["cliParams"] = ""
            task_params["log"]["container"] = task_params["log"]["dataset"] = f"{task_name}.log/"
            # the whole gitlab trigger payload is a single JSON constant
            task_params["jobParameters"] = [
                {
                    "type": "constant",
                    "value": json.dumps(
                        {
                            "project_api": dict_inputs["opt_api"],
                            "project_id": int(dict_inputs["opt_projectID"]),
                            "ref": dict_inputs["opt_ref"],
                            "trigger_token": dict_inputs["opt_triggerToken"],
                            "access_token": dict_inputs["opt_accessToken"],
                            "input_datasets": ",".join(list_in_ds),
                            "input_location": dict_inputs.get("opt_input_location"),
                        }
                    ),
                }
            ]
            # container attributes are not used for gitlab tasks
            del task_params["container_name"]
            del task_params["multiStepExec"]
            return task_params
        return None


    def get_global_parameters(self):
        """Split root inputs into loop parameters and plain workflow parameters.

        Returns (loop_params, workflow_params), or (None, None) when the node
        has no root inputs.
        """
        if self.is_leaf:
            root_inputs = self.upper_root_inputs
        else:
            root_inputs = self.root_inputs
        if root_inputs is None:
            return None, None
        loop_params = {}
        workflow_params = {}
        for k, v in root_inputs.items():
            m = self.get_loop_param_name(k)
            if m:
                loop_params[m] = v
            else:
                param = k.split("#")[-1]
                workflow_params[param] = v
        return loop_params, workflow_params


    def get_all_sub_node_ids(self, all_ids=None):
        """Collect this node's ID and the IDs of all (recursive) sub-nodes."""
        if all_ids is None:
            all_ids = set()
        all_ids.add(self.id)
        for sub_node in self.sub_nodes:
            all_ids.add(sub_node.id)
            if not sub_node.is_leaf:
                sub_node.get_all_sub_node_ids(all_ids)
        return all_ids


    def get_loop_param_name(self, k):
        """Return "xxx" when *k* names a loop parameter "param_xxx", else None."""
        param = k.split("#")[-1]
        m = re.search(r"^param_(.+)", param)
        if m:
            return m.group(1)
        return None


    def get_input_ds_list(self, dict_inputs, id_map):
        """Return the list of fully qualified input dataset names.

        The suffix comes from opt_inDsType when given; otherwise it is taken
        from the output type of the parent node that produced each dataset.
        """
        if "opt_inDS" not in dict_inputs:
            return []
        if isinstance(dict_inputs["opt_inDS"], list):
            is_list_in_ds = True
        else:
            is_list_in_ds = False
        if "opt_inDsType" not in dict_inputs or not dict_inputs["opt_inDsType"]:
            if is_list_in_ds:
                in_ds_suffix = []
                in_ds_list = dict_inputs["opt_inDS"]
            else:
                in_ds_suffix = None
                in_ds_list = [dict_inputs["opt_inDS"]]
            # take each suffix from the parent that produced the dataset
            for tmp_in_ds in in_ds_list:
                for parent_id in self.parents:
                    parent_node = id_map[parent_id]
                    if tmp_in_ds in parent_node.convert_set_outputs():
                        if is_list_in_ds:
                            in_ds_suffix.append(parent_node.output_types[0])
                        else:
                            in_ds_suffix = parent_node.output_types[0]
                        break
        else:
            in_ds_suffix = dict_inputs["opt_inDsType"]
            if "*" in in_ds_suffix:
                in_ds_suffix = in_ds_suffix.replace("*", "XYZ") + ".tgz"
        if is_list_in_ds:
            list_in_ds = [f"{s1}_{s2}/" if s2 else s1 for s1, s2 in zip(dict_inputs["opt_inDS"], in_ds_suffix)]
        else:
            list_in_ds = [f"{dict_inputs['opt_inDS']}_{in_ds_suffix}/" if in_ds_suffix else dict_inputs["opt_inDS"]]
        return list_in_ds
0572
0573
0574
def dump_nodes(node_list, dump_str=None, only_leaves=False):
    """Render a node tree as a text dump.

    Leaves are always included (with their task parameters as pretty-printed
    JSON when available); composite nodes are included unless *only_leaves*
    is set, and their sub-nodes are dumped recursively.
    """
    text = "\n" if dump_str is None else dump_str
    for entry in node_list:
        if not entry.is_leaf:
            if not only_leaves:
                text += f"{entry}\n"
            # recurse into the composite node's children
            text = dump_nodes(entry.sub_nodes, text, only_leaves)
            continue
        text += f"{entry}"
        if entry.task_params is not None:
            text += json.dumps(entry.task_params, indent=4, sort_keys=True) + "\n\n"
    return text
0589
0590
0591
def get_node_id_map(node_list, id_map=None):
    """Build a {node ID: node} map over a node tree, recursing into sub-nodes."""
    mapping = {} if id_map is None else id_map
    for entry in node_list:
        mapping[entry.id] = entry
        if entry.sub_nodes:
            mapping = get_node_id_map(entry.sub_nodes, mapping)
    return mapping
0600
0601
0602
def get_all_parents(node_list, all_parents=None):
    """Collect the union of all parent IDs across a node tree."""
    collected = set() if all_parents is None else all_parents
    for entry in node_list:
        collected |= entry.parents
        if entry.sub_nodes:
            collected = get_all_parents(entry.sub_nodes, collected)
    return collected
0611
0612
0613
def set_workflow_outputs(node_list, all_parents=None):
    """Mark leaf nodes that no other node depends on as workflow outputs."""
    parent_ids = get_all_parents(node_list) if all_parents is None else all_parents
    for current in node_list:
        # a leaf nobody lists as a parent produces a final workflow output
        if current.is_leaf and current.id not in parent_ids:
            current.is_workflow_output = True
        if current.sub_nodes:
            set_workflow_outputs(current.sub_nodes, parent_ids)
0622
0623
0624
0625
0626
0627
0628
0629
0630
0631
0632
0633
0634
0635
0636
0637
0638
0639
0640
0641
0642
0643
0644
0645
0646
0647
0648
0649
0650
0651
0652
0653
0654
0655
0656
0657
0658
0659
0660
0661
0662
0663
0664
def resolve_nodes(node_list, root_inputs, data, serial_id, parent_ids, out_ds_name, log_stream):
    """Resolve input values, expand scatter nodes, and assign final node IDs.

    Walks *node_list* in order (callers are expected to pass it topologically
    sorted so parents appear before children), filling each node's input
    values from *root_inputs*/*data* or from already-resolved parents,
    duplicating scatter nodes per element, recursing into composite nodes,
    and renumbering every materialized node with a fresh *serial_id*.
    Returns (next serial_id, tail nodes, all materialized nodes).
    """
    # overwrite root inputs with caller-supplied data keyed by short name
    for k in root_inputs:
        kk = k.split("#")[-1]
        if kk in data:
            root_inputs[k] = data[kk]
    # original (pre-renumbering) node ID -> set of final IDs it expanded into
    tmp_to_real_id_map = {}
    # original node ID -> list of resolved leaf/tail nodes providing outputs
    resolved_map = {}
    # object identity of each materialized node -> its original node ID
    node_key_map = {}
    all_nodes = []
    for node in node_list:
        # resolve each input of the node
        for tmp_name, tmp_data in node.inputs.items():
            if not tmp_data["source"]:
                continue
            if isinstance(tmp_data["source"], list):
                tmp_sources = tmp_data["source"]
                if "parent_id" in tmp_data:
                    # pad with None so sources and parent IDs zip up evenly
                    tmp_parent_ids = list(tmp_data["parent_id"])
                    tmp_parent_ids += [None] * (len(tmp_sources) - len(tmp_parent_ids))
                else:
                    tmp_parent_ids = [None] * len(tmp_sources)
            else:
                tmp_sources = [tmp_data["source"]]
                if "parent_id" in tmp_data:
                    tmp_parent_ids = [tmp_data["parent_id"]]
                else:
                    tmp_parent_ids = [None] * len(tmp_sources)
            for tmp_source, tmp_parent_id in zip(tmp_sources, tmp_parent_ids):
                isOK = False
                # source is a workflow-level (root) input
                if tmp_source in root_inputs:
                    node.is_head = True
                    node.set_input_value(tmp_name, tmp_source, root_inputs[tmp_source])
                    continue
                # source is an output of an already-resolved parent
                for i in node.parents:
                    for r_node in resolved_map[i]:
                        if tmp_source in r_node.outputs:
                            node.set_input_value(
                                tmp_name,
                                tmp_source,
                                r_node.outputs[tmp_source]["value"],
                            )
                            isOK = True
                            break
                    if isOK:
                        break
                if isOK:
                    continue
                # fall back to the first output of the designated parent;
                # multiple resolved instances yield a list of values
                if tmp_parent_id is not None:
                    values = [list(r_node.outputs.values())[0]["value"] for r_node in resolved_map[tmp_parent_id]]
                    if len(values) == 1:
                        values = values[0]
                    node.set_input_value(tmp_name, tmp_source, values)
                    continue

        if node.scatter:
            # make one deep copy of the node per scatter element, pinning the
            # scattered input values and tagging sub-nodes with their index
            scatters = None
            sc_nodes = []
            for item in node.scatter:
                if scatters is None:
                    scatters = [{item: v} for v in node.inputs[item]["value"]]
                else:
                    [i.update({item: v}) for i, v in zip(scatters, node.inputs[item]["value"])]
            for idx, item in enumerate(scatters):
                sc_node = copy.deepcopy(node)
                for k, v in item.items():
                    sc_node.inputs[k]["value"] = v
                for tmp_node in sc_node.sub_nodes:
                    tmp_node.scatter_index = idx
                    tmp_node.upper_root_inputs = sc_node.root_inputs
                sc_nodes.append(sc_node)
        else:
            sc_nodes = [node]

        # renumber each materialized node and record the ID bookkeeping
        for sc_node in sc_nodes:
            original_node_id = sc_node.id
            all_nodes.append(sc_node)
            node_key_map[id(sc_node)] = original_node_id

            resolved_map.setdefault(original_node_id, [])
            tmp_to_real_id_map.setdefault(original_node_id, set())
            # translate parent IDs from original to final numbering
            real_parens = set()
            for i in sc_node.parents:
                real_parens |= tmp_to_real_id_map[i]
            sc_node.parents = real_parens
            if sc_node.is_head:
                sc_node.parents |= parent_ids
            if sc_node.is_leaf:
                resolved_map[original_node_id].append(sc_node)
                tmp_to_real_id_map[original_node_id].add(serial_id)
                sc_node.id = serial_id
                serial_id += 1
            else:
                # composite node: resolve the sub-workflow recursively
                serial_id, sub_tail_nodes, sc_node.sub_nodes = resolve_nodes(
                    sc_node.sub_nodes,
                    sc_node.root_inputs,
                    sc_node.convert_dict_inputs(),
                    serial_id,
                    sc_node.parents,
                    out_ds_name,
                    log_stream,
                )
                resolved_map[original_node_id] += sub_tail_nodes
                tmp_to_real_id_map[original_node_id] |= set([n.id for n in sub_tail_nodes])
                sc_node.id = serial_id
                serial_id += 1

            # conditions are not handled here (placeholder)
            if sc_node.condition:
                pass

            # set output dataset names from the final node ID
            if sc_node.is_leaf:
                for tmp_name, tmp_data in sc_node.outputs.items():
                    tmp_data["value"] = f"{out_ds_name}_{sc_node.id:03d}_{sc_node.name}"
                    # nodes in a loop get a per-iteration suffix
                    if sc_node.in_loop:
                        tmp_data["value"] += ".___idds___num_run___"
    # collect tail nodes: explicit tails, plus the resolved providers of the rest
    tail_nodes = []
    for node in all_nodes:
        original_node_id = node_key_map.get(id(node), node.id)
        if node.is_tail:
            tail_nodes.append(node)
        else:
            tail_nodes += resolved_map[original_node_id]
    return serial_id, tail_nodes, all_nodes
0798
0799
0800
def parse_workflow_data(data, log_stream):
    """Convert a workflow description dict into a topologically sorted node list.

    Parameters
    ----------
    data : dict
        Workflow description; either the workflow body itself or a dict with
        a top-level "workflow" key wrapping it.
    log_stream : object
        Logger-like object; kept for interface compatibility (unused here).

    Returns
    -------
    tuple
        (node list sorted parents-first, workflow-level root inputs dict)
    """
    # the description may be wrapped in a top-level "workflow" key
    workflow_data = data.get("workflow", data)

    # workflow-level inputs and outputs
    root_inputs = workflow_data.get("inputs", {})
    root_outputs = workflow_data.get("outputs", {})
    # steps that feed a workflow-level output are tail nodes
    tail_node_names = {output_spec["from"].split("/")[0] for output_spec in root_outputs.values() if isinstance(output_spec, dict) and "from" in output_spec}

    # create one node per step
    steps = workflow_data.get("steps", {})
    node_list = []
    node_name_map = {}
    serial_id = 0

    for step_name, step_spec in steps.items():
        serial_id += 1
        step_type = step_spec.get("type", "prun")
        is_leaf = step_type in ["prun", "phpo", "junction", "reana", "gitlab"]
        node = Node(serial_id, step_type, None, is_leaf, step_name)
        node_name_map[step_name] = node

        # map step attributes to opt_* parameters; dataset-like keys
        # ("inDS", "secondaryDSs") become sources to be resolved later,
        # everything else is a plain default
        inputs = {}
        for key, yaml_key in [
            ("inDS", "opt_inDS"),
            ("args", "opt_args"),
            ("exec", "opt_exec"),
            ("containerImage", "opt_containerImage"),
            ("useAthenaPackages", "opt_useAthenaPackages"),
            ("secondaryDSs", "opt_secondaryDSs"),
            ("secondaryDsTypes", "opt_secondaryDsTypes"),
        ]:
            if key in step_spec:
                inputs[f"{step_name}/{yaml_key}"] = {
                    "default": step_spec.get(key) if key not in ["inDS", "secondaryDSs"] else None,
                    "source": step_spec.get(key) if key in ["inDS", "secondaryDSs"] else None,
                }

        node.inputs = inputs
        node.outputs = {f"{step_name}/outDS": {}}
        node.is_tail = step_name in tail_node_names
        node_list.append(node)

    # wire up parent links from the input sources
    for node in node_list:
        for input_name, input_data in node.inputs.items():
            source = input_data.get("source")
            if not source:
                continue

            if isinstance(source, str):
                if source.startswith("{") and source.endswith("}"):
                    # "{name}" references a workflow-level input
                    input_data["source"] = source[1:-1]
                elif "/" in source:
                    # "<step>/<output>" references another step's output
                    source_node_name = source.split("/")[0]
                    if source_node_name in node_name_map:
                        parent = node_name_map[source_node_name]
                        node.add_parent(parent.id)
                        input_data["parent_id"] = parent.id

            elif isinstance(source, list):
                parent_ids = []
                for src in source:
                    if isinstance(src, str) and "/" in src:
                        source_node_name = src.split("/")[0]
                        if source_node_name in node_name_map:
                            parent = node_name_map[source_node_name]
                            node.add_parent(parent.id)
                            parent_ids.append(parent.id)
                if parent_ids:
                    input_data["parent_id"] = parent_ids

    # depth-first topological sort so parents come before their children.
    # look parents up in an ID map instead of scanning node_list per parent
    # (IDs are unique serials, so this matches the former linear scan)
    id_to_node = {n.id: n for n in node_list}
    visited = set()
    sorted_nodes = []

    def visit(n):
        # NOTE(review): assumes the dependency graph is acyclic — a cycle
        # would recurse indefinitely (unchanged from the previous behavior)
        if n.id in visited:
            return
        for parent_id in n.parents:
            parent = id_to_node.get(parent_id)
            if parent is not None:
                visit(parent)
        visited.add(n.id)
        sorted_nodes.append(n)

    for node in node_list:
        visit(node)

    return sorted_nodes, root_inputs