File indexing completed on 2026-04-10 08:39:08
0001 import copy
0002 import os.path
0003 import re
0004 from pathlib import Path
0005 from urllib.parse import urlparse
0006
0007 from .workflow_utils import ConditionItem, Node
0008
0009 WORKFLOW_NAMES = ["prun", "phpo", "junction", "reana", "gitlab"]
0010
0011
0012
def extract_id(id_str):
    """Strip the leading path portion from a CWL id, keeping ``file#fragment``.

    Accepts a single string or a list of strings and returns the same shape.
    Falsy input (None, "", []) is returned unchanged.
    """
    if not id_str:
        return id_str
    # remember whether the caller passed a bare string or a list
    single = not isinstance(id_str, list)
    candidates = [id_str] if single else id_str
    # keep everything from the last path component containing "#" onwards
    extracted = [re.search(r"[^/]+#.+$", entry).group(0) for entry in candidates]
    return extracted[0] if single else extracted
0025
0026
0027
def top_sort(list_data, visited):
    """Topologically sort workflow nodes by their parent dependencies.

    Parameters
    ----------
    list_data : list
        Nodes to sort; each node exposes ``id`` and an iterable ``parents``
        of ids that must appear earlier in the order.
    visited : set
        Ids already placed in the order. Mutated in place.

    Returns
    -------
    list
        The nodes reordered so that every node follows all of its parents.

    Raises
    ------
    ValueError
        When no node becomes ready in a pass (circular dependency). The
        previous implementation recursed forever on such input because the
        unsortable remainder was passed down unchanged.
    """
    if not list_data:
        return []
    new_list = []
    new_visited = []
    for node in list_data:
        # a node is ready once every parent id has been visited
        if all(p in visited for p in node.parents):
            new_visited.append(node)
            visited.add(node.id)
        else:
            new_list.append(node)
    if not new_visited:
        # no progress was made, so the next recursion would see the exact
        # same input: the remaining nodes form a dependency cycle
        raise ValueError("circular dependency detected among workflow nodes")
    return new_visited + top_sort(new_list, visited)
0045
0046
0047
def parse_workflow_file(workflow_file, log_stream, in_loop=False):
    """Parse a CWL workflow file into a topologically sorted list of Nodes.

    Parameters
    ----------
    workflow_file : str
        Path to the CWL workflow file.
    log_stream :
        Logger-like object; only ``.error`` is used here.
    in_loop : bool
        True when this workflow is nested inside a looped parent workflow.

    Returns
    -------
    tuple
        ``(node_list, root_inputs)`` on success. On an unrecognized step
        executable it returns ``(False, None)`` instead.
        NOTE(review): the inconsistent error-return shape forces callers to
        check the first element before unpacking — confirm callers do so.
    """
    cwl_file = Path(os.path.abspath(workflow_file))
    cwl_dir = cwl_file.parent

    # imported lazily so this module can be loaded without cwl-utils installed
    from cwl_utils.parser import load_document_by_uri

    # create empty placeholder executables next to the workflow file so the
    # CWL parser does not reject steps whose "run" target is one of the
    # known workflow names rather than an actual file
    for tmp_exec in WORKFLOW_NAMES:
        tmp_exec_path = cwl_dir / tmp_exec
        if not tmp_exec_path.exists():
            tmp_exec_path.touch()

    root_obj = load_document_by_uri(cwl_file)

    # workflow-level inputs: short id -> default value
    root_inputs = {extract_id(s.id): s.default for s in root_obj.inputs}

    # workflow-level output sources with the document-path prefix removed:
    # the inner re.sub derives the prefix (full id minus its extracted tail),
    # the outer re.sub strips that prefix from outputSource
    root_outputs = set([re.sub(re.sub(extract_id(s.id), "", s.id), "", s.outputSource) for s in root_obj.outputs])

    # convert each step into a Node
    node_list = []
    output_map = {}
    serial_id = 0
    for step in root_obj.steps:
        cwl_name = os.path.basename(step.run)
        # only .cwl files or the known workflow executables are supported
        if not cwl_name.endswith(".cwl") and cwl_name not in WORKFLOW_NAMES:
            log_stream.error(f"Unknown workflow {step.run}")
            return False, None
        serial_id += 1
        workflow_name = step.id.split("#")[-1]
        # leaf nodes run a concrete executable; anything else is a
        # sub-workflow to be parsed recursively below
        if cwl_name == "prun.cwl":
            node = Node(serial_id, "prun", None, True, workflow_name)
        elif cwl_name == "phpo.cwl":
            node = Node(serial_id, "phpo", None, True, workflow_name)
        elif cwl_name in WORKFLOW_NAMES:
            node = Node(serial_id, cwl_name, None, True, workflow_name)
        else:
            node = Node(serial_id, "workflow", None, False, workflow_name)
        node.inputs = {extract_id(s.id): {"default": s.default, "source": extract_id(s.source)} for s in step.in_}
        node.outputs = {extract_id(s): {} for s in step.out}
        # a step without declared outputs still gets an implicit outDS output
        if not node.outputs:
            node.outputs = {extract_id(step.id + "/outDS"): {}}
        output_map.update({name: serial_id for name in node.outputs})
        if step.scatter:
            node.scatter = [extract_id(s) for s in step.scatter]
        if hasattr(step, "when") and step.when:
            # conditional execution ("when" expression)
            node.condition = parse_condition_string(step.when)
            # inputs that only appear negated in the condition are not
            # treated as real data inputs
            suppress_inputs_based_on_condition(node.condition, node.inputs)
        if step.hints and "loop" in step.hints:
            node.loop = True
        if node.loop or in_loop:
            node.in_loop = True
        # recursively parse sub-workflow files referenced via file URI
        if not node.is_leaf:
            p = urlparse(step.run)
            tmp_path = os.path.abspath(os.path.join(p.netloc, p.path))
            node.sub_nodes, node.root_inputs = parse_workflow_file(tmp_path, log_stream, node.in_loop)
        # a node that produces any workflow-level output is a tail node
        if root_outputs & set(node.outputs):
            node.is_tail = True
        node_list.append(node)

    # link each node to the parents that produce its input sources
    for node in node_list:
        for tmp_name, tmp_data in node.inputs.items():
            if not tmp_data["source"]:
                continue
            if isinstance(tmp_data["source"], list):
                sources = tmp_data["source"]
                is_str = False
            else:
                sources = [tmp_data["source"]]
                is_str = True
            parent_ids = []
            for tmp_source in sources:
                if tmp_source in output_map:
                    parent_id = output_map[tmp_source]
                    node.add_parent(parent_id)
                    parent_ids.append(parent_id)
            if parent_ids:
                # keep scalar shape when the source was a bare string
                if is_str:
                    parent_ids = parent_ids[0]
                tmp_data["parent_id"] = parent_ids

    # order nodes so parents always precede children
    node_list = top_sort(node_list, set())
    return node_list, root_inputs
0142
0143
0144
def resolve_nodes(node_list, root_inputs, data, serial_id, parent_ids, out_ds_name, log_stream):
    """Resolve a parsed node list into concrete, uniquely-numbered nodes.

    Fills input values from ``data``/parent outputs, expands scatter steps
    into one copy per scatter value, recursively resolves sub-workflows,
    assigns final serial ids, and names leaf outputs after *out_ds_name*.

    Parameters
    ----------
    node_list : list[Node]
        Topologically sorted nodes from parse_workflow_file.
    root_inputs : dict
        Workflow-level inputs; updated in place from ``data``.
    data : dict
        Caller-supplied values keyed by short input name.
    serial_id : int
        Next free real node id.
    parent_ids : set
        Real ids of the resolved parents of this (sub-)workflow.
    out_ds_name : str
        Base name used to build leaf output dataset names.
    log_stream :
        Logger-like object (passed through to recursive calls).

    Returns
    -------
    tuple
        ``(serial_id, tail_nodes, all_nodes)`` — the next free id, the
        resolved leaf nodes backing the workflow outputs, and every
        resolved node.
    """
    # override root input defaults with caller-supplied values
    for k in root_inputs:
        kk = k.split("#")[-1]
        if kk in data:
            root_inputs[k] = data[kk]
    tmp_to_real_id_map = {}
    resolved_map = {}
    all_nodes = []
    for node in node_list:
        # resolve this node's inputs
        for tmp_name, tmp_data in node.inputs.items():
            if not tmp_data["source"]:
                continue
            # normalize sources and their parent ids to parallel lists
            if isinstance(tmp_data["source"], list):
                tmp_sources = tmp_data["source"]
                if "parent_id" in tmp_data:
                    tmp_parent_ids = tmp_data["parent_id"]
                    # pad so zip() below covers every source
                    tmp_parent_ids += [None] * (len(tmp_sources) - len(tmp_parent_ids))
                else:
                    tmp_parent_ids = [None] * len(tmp_sources)
            else:
                tmp_sources = [tmp_data["source"]]
                if "parent_id" in tmp_data:
                    tmp_parent_ids = [tmp_data["parent_id"]]
                else:
                    tmp_parent_ids = [None] * len(tmp_sources)
            for tmp_source, tmp_parent_id in zip(tmp_sources, tmp_parent_ids):
                isOK = False
                # 1) source is a workflow-level input
                if tmp_source in root_inputs:
                    node.is_head = True
                    node.set_input_value(tmp_name, tmp_source, root_inputs[tmp_source])
                    continue
                # 2) source is an output of an already-resolved parent
                for i in node.parents:
                    for r_node in resolved_map[i]:
                        if tmp_source in r_node.outputs:
                            node.set_input_value(
                                tmp_name,
                                tmp_source,
                                r_node.outputs[tmp_source]["value"],
                            )
                            isOK = True
                            break
                    if isOK:
                        break
                if isOK:
                    continue
                # 3) fall back to the first output of the recorded parent
                if tmp_parent_id is not None:
                    values = [list(r_node.outputs.values())[0]["value"] for r_node in resolved_map[tmp_parent_id]]
                    if len(values) == 1:
                        values = values[0]
                    node.set_input_value(tmp_name, tmp_source, values)
                    continue

        if node.scatter:
            # expand a scatter step into one deep copy per scatter value
            scatters = None
            sc_nodes = []
            for item in node.scatter:
                if scatters is None:
                    scatters = [{item: v} for v in node.inputs[item]["value"]]
                else:
                    # dot-product scatter: pair the n-th values of each input
                    [i.update({item: v}) for i, v in zip(scatters, node.inputs[item]["value"])]
            for idx, item in enumerate(scatters):
                sc_node = copy.deepcopy(node)
                for k, v in item.items():
                    sc_node.inputs[k]["value"] = v
                for tmp_node in sc_node.sub_nodes:
                    tmp_node.scatter_index = idx
                    tmp_node.upper_root_inputs = sc_node.root_inputs
                sc_nodes.append(sc_node)
        else:
            sc_nodes = [node]

        # assign real ids and resolve sub-workflows
        for sc_node in sc_nodes:
            all_nodes.append(sc_node)
            resolved_map.setdefault(sc_node.id, [])
            tmp_to_real_id_map.setdefault(sc_node.id, set())
            # translate temporary parent ids into real ids
            real_parens = set()
            for i in sc_node.parents:
                real_parens |= tmp_to_real_id_map[i]
            sc_node.parents = real_parens
            # head nodes inherit the enclosing workflow's parents
            if sc_node.is_head:
                sc_node.parents |= parent_ids
            if sc_node.is_leaf:
                resolved_map[sc_node.id].append(sc_node)
                tmp_to_real_id_map[sc_node.id].add(serial_id)
                sc_node.id = serial_id
                serial_id += 1
            else:
                # resolve the sub-workflow; its tail nodes stand in for
                # this node when children look up outputs
                serial_id, sub_tail_nodes, sc_node.sub_nodes = resolve_nodes(
                    sc_node.sub_nodes,
                    sc_node.root_inputs,
                    sc_node.convert_dict_inputs(),
                    serial_id,
                    sc_node.parents,
                    out_ds_name,
                    log_stream,
                )
                resolved_map[sc_node.id] += sub_tail_nodes
                tmp_to_real_id_map[sc_node.id] |= set([n.id for n in sub_tail_nodes])
                sc_node.id = serial_id
                serial_id += 1
            # rewrite condition parameters into real parent id sets
            if sc_node.condition:
                convert_params_in_condition_to_parent_ids(sc_node.condition, sc_node.inputs, tmp_to_real_id_map)
            # name leaf outputs using the final real id
            if sc_node.is_leaf:
                for tmp_name, tmp_data in sc_node.outputs.items():
                    tmp_data["value"] = f"{out_ds_name}_{sc_node.id:03d}_{sc_node.name}"
                    # looped outputs carry an iDDS run-number placeholder
                    if sc_node.in_loop:
                        tmp_data["value"] += ".___idds___num_run___"

    # collect tail nodes: leaves directly, sub-workflows via their
    # resolved tail nodes.
    # BUG FIX: the inner test previously repeated "node.is_tail", which made
    # the else branch unreachable so non-leaf tails were never expanded.
    tail_nodes = []
    for node in all_nodes:
        if node.is_tail:
            if node.is_leaf:
                tail_nodes.append(node)
            else:
                tail_nodes += resolved_map[node.id]
    return serial_id, tail_nodes, all_nodes
0271
0272
0273
def parse_condition_string(cond_string):
    """Parse a CWL ``when`` expression into a ConditionItem tree.

    The surrounding ``$(...)`` wrapper is removed first. Innermost
    parenthesized sub-expressions are then converted one by one and
    replaced with ``___N___`` placeholders until the string is flat,
    at which point the remainder is converted as a whole.
    """
    # drop the JS-style "$( ... )" wrapper, keeping only its contents
    cond_string = re.sub(r"\$\((?P<aaa>.+)\)", r"\g<aaa>", cond_string)
    placeholder_map = {}
    next_placeholder = 0
    while True:
        # innermost groups are those containing no nested parentheses
        inner_groups = re.findall(r"\(([^\(\)]+)\)", cond_string)
        if not inner_groups:
            # fully flattened: convert what remains and return the tree
            return convert_plain_condition_string(cond_string, placeholder_map)
        for group in inner_groups:
            converted = convert_plain_condition_string(group, placeholder_map)
            key = f"___{next_placeholder}___"
            next_placeholder += 1
            placeholder_map[key] = converted
            cond_string = cond_string.replace("(" + group + ")", key)
0291
0292
0293
def extract_parameter(token):
    """Return the parameter name following ``self.`` in *token*.

    The name ends at the first ``!`` or ``=`` so comparison suffixes are
    excluded, e.g. ``"!self.foo"`` -> ``"foo"``, ``"self.a!=b"`` -> ``"a"``.
    """
    return re.search(r"self\.([^!=]+)", token).group(1)
0297
0298
0299
def convert_plain_condition_string(cond_string, cond_map):
    """Convert a parenthesis-free condition string into a ConditionItem tree.

    *cond_string* contains tokens of the form ``self.param``, ``!self.param``,
    or ``___N___`` placeholders registered in *cond_map* (optionally negated
    as ``!___N___``), joined by ``||`` / ``&&``. Tokens are folded left to
    right into a left-leaning ConditionItem tree.

    Raises TypeError on an unrecognized token.
    """
    # normalize spacing: glue "!" to its operand, isolate "||" and "&&"
    # so a plain split() yields one token per operand/operator
    cond_string = re.sub(r" *! *", r"!", cond_string)
    cond_string = re.sub(r"\|\|", r" || ", cond_string)
    cond_string = re.sub(r"&&", r" && ", cond_string)

    tokens = cond_string.split()
    left = None
    operator = None
    for token in tokens:
        token = token.strip()
        if token == "||":
            operator = "or"
            continue
        elif token == "&&":
            operator = "and"
            continue
        elif token.startswith("self."):
            # plain parameter reference
            param = extract_parameter(token)
            right = ConditionItem(param)
            if not left:
                # first operand: nothing to combine with yet
                left = right
                continue
        elif token.startswith("!self."):
            # negated parameter reference
            param = extract_parameter(token)
            right = ConditionItem(param, operator="not")
            if not left:
                left = right
                continue
        elif re.search(r"^___\d+___$", token) and token in cond_map:
            # placeholder for an already-converted sub-expression
            right = cond_map[token]
            if not left:
                left = right
                continue
        elif re.search(r"^!___\d+___$", token) and token[1:] in cond_map:
            # negated placeholder
            right = ConditionItem(cond_map[token[1:]], operator="not")
            if not left:
                left = right
                continue
        else:
            raise TypeError(f'unknown token "{token}"')

        # combine the accumulated left side with the new right operand
        left = ConditionItem(left, right, operator)
    return left
0343
0344
0345
def convert_params_in_condition_to_parent_ids(condition_item, input_data, id_map):
    """Rewrite parameter names in a condition tree into real parent id sets.

    Each string operand of *condition_item* names an input parameter; it is
    replaced (through *id_map*) by the real ids of the parent node(s) that
    produce that input. The form ``param[N]`` selects the N-th parent when
    the input has a list of parents. ConditionItem operands are resolved
    recursively.

    Raises ReferenceError when a parameter matches no input.
    """
    for item in ["left", "right"]:
        param = getattr(condition_item, item)
        if isinstance(param, str):
            # indexed form "name[N]" picks one parent out of a list
            m = re.search(r"^[^\[]+\[(\d+)\]", param)
            if m:
                param = param.split("[")[0]
                idx = int(m.group(1))
            else:
                idx = None
            isOK = False
            for tmp_name, tmp_data in input_data.items():
                # match on the short name after the last "/"
                if param == tmp_name.split("/")[-1]:
                    isOK = True
                    # NOTE(review): assumes tmp_data carries "parent_id" for
                    # any input referenced in a condition — confirm upstream
                    # always sets it for such inputs
                    if isinstance(tmp_data["parent_id"], list):
                        if idx is not None:
                            setattr(condition_item, item, id_map[tmp_data["parent_id"][idx]])
                        else:
                            setattr(condition_item, item, id_map[tmp_data["parent_id"]])
                    else:
                        setattr(condition_item, item, id_map[tmp_data["parent_id"]])
                    break
            if not isOK:
                raise ReferenceError(f"unresolved parameter {param} in the condition string")
        elif isinstance(param, ConditionItem):
            # nested sub-condition: resolve recursively
            convert_params_in_condition_to_parent_ids(param, input_data, id_map)
0372
0373
0374
def suppress_inputs_based_on_condition(condition_item, input_data):
    """Flag inputs as suppressed when the condition simply negates them.

    A leaf condition of the form ``not <param>`` marks the matching entry in
    *input_data* with ``"suppressed": True``; compound conditions are walked
    recursively through their ConditionItem operands.
    """
    leaf_negation = (
        condition_item.right is None
        and condition_item.operator == "not"
        and isinstance(condition_item.left, str)
    )
    if leaf_negation:
        target = condition_item.left
        # match on the short input name after the last "/"
        for tmp_name, tmp_data in input_data.items():
            if target == tmp_name.split("/")[-1]:
                tmp_data["suppressed"] = True
    else:
        # descend into any nested sub-conditions
        for child in (condition_item.left, condition_item.right):
            if isinstance(child, ConditionItem):
                suppress_inputs_based_on_condition(child, input_data)