Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-10 08:39:08

0001 import copy
0002 import os.path
0003 import re
0004 from pathlib import Path
0005 from urllib.parse import urlparse
0006 
0007 from .workflow_utils import ConditionItem, Node
0008 
# Step commands accepted even without a ".cwl" extension. parse_workflow_file
# also creates empty placeholder files with these names next to the workflow
# file so that CWL validation does not reject them.
WORKFLOW_NAMES = [
    "prun",
    "phpo",
    "junction",
    "reana",
    "gitlab",
]
0010 
0011 
def extract_id(id_str):
    """Strip the document URI prefix from a CWL identifier.

    For example ``"file:///path/wf.cwl#step/out"`` becomes ``"wf.cwl#step/out"``
    (the path segment containing ``#`` and everything after it).

    :param id_str: an identifier string, a list of identifier strings, or a
        falsy value.
    :return: the shortened identifier(s) in the same shape as the input; a
        falsy input is returned unchanged.
    """
    if not id_str:
        return id_str

    def _shorten(full_id):
        return re.search(r"[^/]+#.+$", full_id).group(0)

    if isinstance(id_str, list):
        return [_shorten(one_id) for one_id in id_str]
    return _shorten(id_str)
0025 
0026 
def top_sort(list_data, visited=None):
    """Order nodes so that every node comes after all of its parents.

    Nodes are scanned repeatedly in their given order. A node is emitted as
    soon as all of its ``parents`` are in ``visited``, which includes nodes
    emitted earlier in the same scan. The sort is iterative, so long dependency
    chains do not hit Python's recursion limit.

    :param list_data: nodes exposing ``id`` and ``parents`` attributes.
    :param visited: set of already-visited node IDs. It is updated in place
        with the IDs of emitted nodes. Defaults to a new empty set.
    :return: list of nodes in dependency order.
    :raises ValueError: if some nodes can never be emitted because of a
        dependency cycle or a parent ID that is never visited.
    """
    if visited is None:
        visited = set()
    ordered = []
    pending = list(list_data) if list_data else []
    while pending:
        remaining = []
        for node in pending:
            if all(p in visited for p in node.parents):
                ordered.append(node)
                # visible to later nodes in this same scan
                visited.add(node.id)
            else:
                remaining.append(node)
        if len(remaining) == len(pending):
            # no progress: the previous recursive version looped until RecursionError
            raise ValueError(f"cannot sort nodes {[n.id for n in remaining]}: cyclic or unresolved dependencies")
        pending = remaining
    return ordered
0045 
0046 
def parse_workflow_file(workflow_file, log_stream, in_loop=False):
    """Parse a CWL workflow file into a topologically sorted list of nodes.

    Each step becomes a ``Node``. Steps running ``prun.cwl``, ``phpo.cwl`` or
    one of ``WORKFLOW_NAMES`` become leaf nodes. Any other ``.cwl`` step is
    treated as a sub-workflow and parsed recursively into ``node.sub_nodes``.

    Side effect: empty placeholder files named after ``WORKFLOW_NAMES`` are
    created next to the workflow file if they do not exist yet.

    :param workflow_file: path to the CWL workflow file.
    :param log_stream: logger used to report errors.
    :param in_loop: whether this workflow is nested inside a loop. If True,
        every node's ``in_loop`` flag is set.
    :return: ``(node_list, root_inputs)``, where ``root_inputs`` maps short
        input IDs to their default values. Returns ``(False, None)`` if a step
        here or in any sub-workflow uses an unknown workflow.
    """
    # read the file from yaml
    cwl_file = Path(os.path.abspath(workflow_file))
    cwl_dir = cwl_file.parent

    from cwl_utils.parser import load_document_by_uri

    # make fake executables to skip validation check in CWL
    for tmp_exec in WORKFLOW_NAMES:
        tmp_exec_path = cwl_dir / tmp_exec
        if not tmp_exec_path.exists():
            tmp_exec_path.touch()

    # Import CWL Object
    root_obj = load_document_by_uri(cwl_file)

    # root inputs: short ID -> default value
    root_inputs = {extract_id(s.id): s.default for s in root_obj.inputs}

    # Root outputs, expressed as short step-output IDs so they can be matched
    # against node.outputs below.
    root_outputs = set()
    for s in root_obj.outputs:
        # extract_id() returns a suffix of s.id; the rest is the document URI
        # prefix. Strip it literally: using it as a regex pattern breaks on
        # paths that contain metacharacters such as "+", "(" or "[".
        uri_prefix = s.id[: len(s.id) - len(extract_id(s.id))]
        root_outputs.add(s.outputSource.replace(uri_prefix, ""))

    # loop over steps
    node_list = []
    output_map = {}
    serial_id = 0
    for step in root_obj.steps:
        cwl_name = os.path.basename(step.run)
        # check cwl command
        if not cwl_name.endswith(".cwl") and cwl_name not in WORKFLOW_NAMES:
            log_stream.error(f"Unknown workflow {step.run}")
            return False, None
        serial_id += 1
        workflow_name = step.id.split("#")[-1]
        # leaf workflow and sub-workflow
        if cwl_name == "prun.cwl":
            node = Node(serial_id, "prun", None, True, workflow_name)
        elif cwl_name == "phpo.cwl":
            node = Node(serial_id, "phpo", None, True, workflow_name)
        elif cwl_name in WORKFLOW_NAMES:
            node = Node(serial_id, cwl_name, None, True, workflow_name)
        else:
            node = Node(serial_id, "workflow", None, False, workflow_name)
        node.inputs = {extract_id(s.id): {"default": s.default, "source": extract_id(s.source)} for s in step.in_}
        node.outputs = {extract_id(s): {} for s in step.out}
        # add outDS if no output is defined
        if not node.outputs:
            node.outputs = {extract_id(step.id + "/outDS"): {}}
        output_map.update({name: serial_id for name in node.outputs})
        if step.scatter:
            node.scatter = [extract_id(s) for s in step.scatter]
        if hasattr(step, "when") and step.when:
            # parse condition
            node.condition = parse_condition_string(step.when)
            # suppress inputs based on condition
            suppress_inputs_based_on_condition(node.condition, node.inputs)
        if step.hints and "loop" in step.hints:
            node.loop = True
        if node.loop or in_loop:
            node.in_loop = True
        # expand sub-workflow
        if not node.is_leaf:
            p = urlparse(step.run)
            tmp_path = os.path.abspath(os.path.join(p.netloc, p.path))
            sub_nodes, sub_root_inputs = parse_workflow_file(tmp_path, log_stream, node.in_loop)
            if sub_nodes is False:
                # The recursive call has already logged the reason. Propagate the
                # failure instead of storing sub_nodes=False on the node.
                return False, None
            node.sub_nodes, node.root_inputs = sub_nodes, sub_root_inputs
        # check if tail
        if root_outputs & set(node.outputs):
            node.is_tail = True
        node_list.append(node)

    # look for parents
    for node in node_list:
        for tmp_data in node.inputs.values():
            if not tmp_data["source"]:
                continue
            if isinstance(tmp_data["source"], list):
                sources = tmp_data["source"]
                is_str = False
            else:
                sources = [tmp_data["source"]]
                is_str = True
            parent_ids = []
            for tmp_source in sources:
                if tmp_source in output_map:
                    parent_id = output_map[tmp_source]
                    node.add_parent(parent_id)
                    parent_ids.append(parent_id)
            if parent_ids:
                if is_str:
                    parent_ids = parent_ids[0]
                tmp_data["parent_id"] = parent_ids

    # sort
    node_list = top_sort(node_list, set())
    return node_list, root_inputs
0142 
0143 
def resolve_nodes(node_list, root_inputs, data, serial_id, parent_ids, out_ds_name, log_stream):
    """Resolve parsed nodes into concrete nodes with real IDs and input/output values.

    The function performs these steps:

    - Root input values are overridden by ``data``, matched on the part of the
      key after ``#``.
    - Step inputs are filled from root inputs or from outputs of
      already-resolved parents.
    - Scattered nodes are expanded into one deep copy per scatter element.
    - Sub-workflows are resolved recursively.

    Every resulting node gets a new ID taken from ``serial_id``. Leaf outputs
    are named ``f"{out_ds_name}_{id:03d}_{name}"``.

    :param node_list: nodes in topological order (as returned by
        parse_workflow_file); parents must be resolved before their children.
    :param root_inputs: mapping of root input IDs to values; updated in place
        from ``data``.
    :param data: user-supplied input values keyed by short input name.
    :param serial_id: first real ID to assign.
    :param parent_ids: set of real IDs added as parents of head nodes (nodes
        fed by a root input).
    :param out_ds_name: base name used to build leaf output values.
    :param log_stream: logger, passed through to recursive calls.
    :return: ``(next_serial_id, tail_nodes, all_nodes)``.
    """
    for k in root_inputs:
        kk = k.split("#")[-1]
        if kk in data:
            root_inputs[k] = data[kk]
    # temporary (parse-time) node ID -> set of real IDs of its resolved tail nodes
    tmp_to_real_id_map = {}
    # temporary node ID -> resolved nodes whose outputs feed its children
    resolved_map = {}
    all_nodes = []
    for node in node_list:
        # resolve input
        for tmp_name, tmp_data in node.inputs.items():
            if not tmp_data["source"]:
                continue
            if isinstance(tmp_data["source"], list):
                tmp_sources = tmp_data["source"]
                if "parent_id" in tmp_data:
                    # Pad a copy: padding the stored list would leak None entries
                    # into node.inputs.
                    tmp_parent_ids = list(tmp_data["parent_id"])
                    tmp_parent_ids += [None] * (len(tmp_sources) - len(tmp_parent_ids))
                else:
                    tmp_parent_ids = [None] * len(tmp_sources)
            else:
                tmp_sources = [tmp_data["source"]]
                tmp_parent_ids = [tmp_data.get("parent_id")]
            for tmp_source, tmp_parent_id in zip(tmp_sources, tmp_parent_ids):
                # check root input
                if tmp_source in root_inputs:
                    node.is_head = True
                    node.set_input_value(tmp_name, tmp_source, root_inputs[tmp_source])
                    continue
                # check parent output
                found = False
                for i in node.parents:
                    for r_node in resolved_map[i]:
                        if tmp_source in r_node.outputs:
                            node.set_input_value(
                                tmp_name,
                                tmp_source,
                                r_node.outputs[tmp_source]["value"],
                            )
                            found = True
                            break
                    if found:
                        break
                if found:
                    continue
                # fall back to the first output of each resolved node of the parent
                if tmp_parent_id is not None:
                    values = [list(r_node.outputs.values())[0]["value"] for r_node in resolved_map[tmp_parent_id]]
                    if len(values) == 1:
                        values = values[0]
                    node.set_input_value(tmp_name, tmp_source, values)
        # scatter
        if node.scatter:
            # combine scattered parameters element-wise (zip truncates to the shortest)
            scatters = None
            sc_nodes = []
            for item in node.scatter:
                if scatters is None:
                    scatters = [{item: v} for v in node.inputs[item]["value"]]
                else:
                    for combo, v in zip(scatters, node.inputs[item]["value"]):
                        combo[item] = v
            for idx, item in enumerate(scatters):
                sc_node = copy.deepcopy(node)
                for k, v in item.items():
                    sc_node.inputs[k]["value"] = v
                for tmp_node in sc_node.sub_nodes:
                    tmp_node.scatter_index = idx
                    tmp_node.upper_root_inputs = sc_node.root_inputs
                sc_nodes.append(sc_node)
        else:
            sc_nodes = [node]
        # loop over scattered nodes
        for sc_node in sc_nodes:
            all_nodes.append(sc_node)
            # register under the temporary ID before the real ID is assigned
            resolved_map.setdefault(sc_node.id, [])
            tmp_to_real_id_map.setdefault(sc_node.id, set())
            # translate temporary parent IDs into real IDs
            real_parents = set()
            for i in sc_node.parents:
                real_parents |= tmp_to_real_id_map[i]
            sc_node.parents = real_parents
            if sc_node.is_head:
                sc_node.parents |= parent_ids
            if sc_node.is_leaf:
                resolved_map[sc_node.id].append(sc_node)
                tmp_to_real_id_map[sc_node.id].add(serial_id)
                sc_node.id = serial_id
                serial_id += 1
            else:
                serial_id, sub_tail_nodes, sc_node.sub_nodes = resolve_nodes(
                    sc_node.sub_nodes,
                    sc_node.root_inputs,
                    sc_node.convert_dict_inputs(),
                    serial_id,
                    sc_node.parents,
                    out_ds_name,
                    log_stream,
                )
                # children of a sub-workflow depend on its tail nodes
                resolved_map[sc_node.id] += sub_tail_nodes
                tmp_to_real_id_map[sc_node.id] |= {n.id for n in sub_tail_nodes}
                sc_node.id = serial_id
                serial_id += 1
            # convert parameters to parent IDs in conditions
            if sc_node.condition:
                convert_params_in_condition_to_parent_ids(sc_node.condition, sc_node.inputs, tmp_to_real_id_map)
            # resolve outputs
            if sc_node.is_leaf:
                for tmp_data in sc_node.outputs.values():
                    tmp_data["value"] = f"{out_ds_name}_{sc_node.id:03d}_{sc_node.name}"
                    # add loop count for nodes in a loop
                    if sc_node.in_loop:
                        tmp_data["value"] += ".___idds___num_run___"
    # return tails
    tail_nodes = [n for n in all_nodes if n.is_tail]
    return serial_id, tail_nodes, all_nodes
0271 
0272 
def parse_condition_string(cond_string):
    """Parse a CWL ``when`` expression into a ConditionItem tree.

    The enclosing ``$( ... )`` wrapper is stripped. Innermost parenthesised
    groups are then parsed one layer at a time. Each group is replaced in the
    string by a placeholder ``___N___`` whose parsed ConditionItem is kept in a
    lookup map. Finally the remaining flat string is parsed.
    """
    # remove the $( ... ) wrapper
    cond_string = re.sub(r"\$\((?P<aaa>.+)\)", r"\g<aaa>", cond_string)
    placeholders = {}
    counter = 0
    innermost = re.findall(r"\(([^\(\)]+)\)", cond_string)
    while innermost:
        for group in innermost:
            placeholder = f"___{counter}___"
            counter += 1
            placeholders[placeholder] = convert_plain_condition_string(group, placeholders)
            cond_string = cond_string.replace("(" + group + ")", placeholder)
        innermost = re.findall(r"\(([^\(\)]+)\)", cond_string)
    return convert_plain_condition_string(cond_string, placeholders)
0291 
0292 
def extract_parameter(token):
    """Return the parameter name referenced by a ``self.<name>`` condition token.

    The token may carry a leading ``!``. The name runs from after ``self.`` up
    to the first ``!`` or ``=``, so an index suffix such as ``[1]`` is kept.

    :raises TypeError: if the token has no name after ``self.``. TypeError is
        the same error type used for unknown condition tokens.
    """
    m = re.search(r"self\.([^!=]+)", token)
    if m is None:
        raise TypeError(f'no parameter name in token "{token}"')
    return m.group(1)
0297 
0298 
def convert_plain_condition_string(cond_string, cond_map):
    """Build a ConditionItem tree from a condition string without parentheses.

    Recognised operands:

    - ``self.<param>`` and ``!self.<param>``
    - placeholders ``___N___`` and ``!___N___``, which must be keys of
      ``cond_map``

    Operators ``||`` and ``&&`` are applied strictly left to right, with no
    precedence between them. An empty string yields None.

    :raises TypeError: on any token that is not an operator or a known operand.
    """
    # glue "!" to its operand and make "||" / "&&" standalone tokens
    normalized = re.sub(r" *! *", r"!", cond_string)
    normalized = re.sub(r"\|\|", r" || ", normalized)
    normalized = re.sub(r"&&", r" && ", normalized)

    logical_ops = {"||": "or", "&&": "and"}
    left = None
    operator = None
    for token in normalized.split():
        if token in logical_ops:
            operator = logical_ops[token]
            continue
        if token.startswith("self."):
            right = ConditionItem(extract_parameter(token))
        elif token.startswith("!self."):
            right = ConditionItem(extract_parameter(token), operator="not")
        elif re.search(r"^___\d+___$", token) and token in cond_map:
            right = cond_map[token]
        elif re.search(r"^!___\d+___$", token) and token[1:] in cond_map:
            right = ConditionItem(cond_map[token[1:]], operator="not")
        else:
            raise TypeError(f'unknown token "{token}"')
        # the first operand starts the chain; later ones fold into it
        left = right if not left else ConditionItem(left, right, operator)
    return left
0343 
0344 
def _real_parent_ids(param, tmp_data, idx, id_map):
    """Return the real node IDs feeding one input referenced by a condition."""
    parent_id = tmp_data.get("parent_id")
    if parent_id is None:
        raise ReferenceError(f"parameter {param} in the condition string is not produced by a parent node")
    if not isinstance(parent_id, list):
        return id_map[parent_id]
    if idx is not None:
        return id_map[parent_id[idx]]
    # No index: the condition depends on every parent feeding this list input.
    # None entries are skipped because resolve_nodes may pad the list with them.
    real_ids = set()
    for p in parent_id:
        if p is not None:
            real_ids |= id_map[p]
    return real_ids


def convert_params_in_condition_to_parent_ids(condition_item, input_data, id_map):
    """Replace parameter names in a condition tree with real IDs of producing nodes.

    Each string operand (``name`` or ``name[idx]``) is matched against the
    last ``/``-separated component of the keys of ``input_data``. It is
    replaced by the real IDs of the parent(s) feeding that input. Nested
    ConditionItem operands are processed recursively. The tree is modified in
    place.

    :param condition_item: root of the condition tree.
    :param input_data: node inputs (key -> dict possibly holding ``parent_id``).
    :param id_map: temporary parent ID -> set of real node IDs.
    :raises ReferenceError: if an operand matches no input, or the matched
        input has no parent.
    """
    for item in ["left", "right"]:
        param = getattr(condition_item, item)
        if isinstance(param, str):
            m = re.search(r"^[^\[]+\[(\d+)\]", param)
            if m:
                param = param.split("[")[0]
                idx = int(m.group(1))
            else:
                idx = None
            for tmp_name, tmp_data in input_data.items():
                if param == tmp_name.split("/")[-1]:
                    setattr(condition_item, item, _real_parent_ids(param, tmp_data, idx, id_map))
                    break
            else:
                raise ReferenceError(f"unresolved parameter {param} in the condition string")
        elif isinstance(param, ConditionItem):
            convert_params_in_condition_to_parent_ids(param, input_data, id_map)
0372 
0373 
def suppress_inputs_based_on_condition(condition_item, input_data):
    """Mark inputs that appear negated in a condition as suppressed.

    A leaf of the form ``not <param>`` (right operand None, operator ``"not"``,
    string left operand) sets ``"suppressed": True`` on every input whose key
    ends with ``/<param>``. Any other item is searched recursively through its
    ConditionItem operands. ``input_data`` is modified in place.
    """
    is_negated_param = (
        condition_item.right is None
        and condition_item.operator == "not"
        and isinstance(condition_item.left, str)
    )
    if not is_negated_param:
        for child in (condition_item.left, condition_item.right):
            if isinstance(child, ConditionItem):
                suppress_inputs_based_on_condition(child, input_data)
        return
    for name, spec in input_data.items():
        if name.split("/")[-1] == condition_item.left:
            spec["suppressed"] = True