Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:10:44

0001 from collections import namedtuple
0002 
0003 import numpy as np
0004 
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.utilities.iterables import numbered_symbols
0008 from sympy.codegen.ast import Assignment
0009 from sympy.printing.cxx import CXX17CodePrinter
0010 
0011 
# A (name, expr) pair. ``name`` is typically the Symbol/MatrixSymbol that the
# generated code assigns to (see name_expr); ``expr`` is the value assigned.
NamedExpr = namedtuple("NamedExpr", "name expr")
0013 
0014 
def name_expr(name, expr):
    """Pair ``expr`` with a freshly created symbol called ``name``.

    If ``expr`` has a ``shape`` (matrix-like), the symbol is a
    ``MatrixSymbol`` of that shape; otherwise it is a scalar ``Symbol``.

    Returns:
        ``NamedExpr(symbol, expr)``.
    """
    symbol = (
        sym.MatrixSymbol(name, *expr.shape) if hasattr(expr, "shape") else Symbol(name)
    )
    return NamedExpr(symbol, expr)
0021 
0022 
def find_by_name(name_exprs, name):
    """Return the first ``(name, expr)`` entry whose name stringifies to ``name``.

    Comparison is done on ``str(entry[0])`` so symbols can be looked up by
    their textual name. Returns ``None`` when no entry matches.
    """
    for entry in name_exprs:
        if str(entry[0]) == name:
            return entry
    return None
0027 
0028 
class MyCXXCodePrinter(CXX17CodePrinter):
    """C++17 printer that treats matrices as flat, column-major arrays.

    * Element ``M[i, j]`` is printed as ``M[i + j*rows]``.
    * Matrix assignments walk the indices column by column, matching that
      layout.
    * ``x**-0.5`` is printed as ``<1 as float>/<ns>sqrt<suffix>(x)``; every
      other power falls back to the base printer.
    """

    def _traverse_matrix_indices(self, mat):
        # Column-major walk: the row index varies fastest, consistent with the
        # flat index produced by _print_MatrixElement below.
        n_rows, n_cols = mat.shape
        return ((row, col) for col in range(n_cols) for row in range(n_rows))

    def _print_MatrixElement(self, expr):
        from sympy.printing.precedence import PRECEDENCE

        matrix = self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True)
        flat_index = expr.i + expr.j * expr.parent.shape[0]
        return f"{matrix}[{flat_index}]"

    def _print_Pow(self, expr):
        from sympy.core.numbers import equal_valued, Float
        from sympy.codegen.ast import real

        suffix = self._get_func_suffix(real)
        if not equal_valued(expr.exp, -0.5):
            return super()._print_Pow(expr)
        # Inverse square root: emit 1/sqrt(x) instead of a generic pow call.
        one = self._print_Float(Float(1.0))
        return f"{one}/{self._ns}sqrt{suffix}({self._print(expr.base)})"
0055 
0056 
# Module-level default printer instance (column-major flat matrices,
# 1/sqrt(x) for x**-0.5), for passing as ``printer`` to the print helpers.
cxx_printer = MyCXXCodePrinter()
0058 
0059 
def inflate_expr(name_expr):
    """Flatten one ``(name, expr)`` pair into scalar ``(name, expr)`` pairs.

    If ``expr`` has a ``shape``, yields one ``(name[idx], expr[idx])`` pair per
    element in ``np.ndindex`` (C / row-major) order. Each pair gets a matching
    reference ``(name, shape, idx)`` so that deflate_exprs can reassemble the
    matrix. A scalar ``expr`` is returned as-is with a ``None`` reference.

    Returns:
        ``(pairs, references)``, two lists of equal length.
    """
    name, expr = name_expr

    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    positions = list(np.ndindex(expr.shape))
    pairs = [(name[pos], expr[pos]) for pos in positions]
    refs = [(name, expr.shape, pos) for pos in positions]
    return pairs, refs
0075 
0076 
def inflate_exprs(name_exprs):
    """Apply inflate_expr to every pair and concatenate the results.

    Returns:
        ``(pairs, references)``: the flattened scalar pairs and, position by
        position, the reference needed to rebuild matrices (``None`` for
        entries that were already scalar).
    """
    flat_pairs, flat_refs = [], []
    for item in name_exprs:
        pairs, refs = inflate_expr(item)
        flat_pairs += pairs
        flat_refs += refs
    return flat_pairs, flat_refs
0085 
0086 
def deflate_exprs(name_exprs, references):
    """Inverse of inflate_exprs: reassemble per-element pairs into matrices.

    Args:
        name_exprs: scalar ``(name, expr)`` pairs, positionally aligned with
            ``references``.
        references: for each pair, ``None`` (the pair was scalar and is passed
            through unchanged) or ``(matrix_name, shape, indices)`` telling
            where the element's ``expr`` belongs.

    Returns:
        A list in first-appearance order. Scalar pairs are kept as given. Each
        matrix appears once, as ``NamedExpr(matrix_name, ImmutableMatrix)``.
    """
    result = []
    deflated = {}

    for name_expr, reference in zip(name_exprs, references):
        if reference is None:
            result.append(name_expr)
            continue
        _, expr = name_expr
        name, shape, indices = reference
        if name not in deflated:
            # Placeholder matrix; every element is overwritten below when the
            # references come from inflate_exprs.
            matrix = Matrix(np.zeros(shape))
            result.append(NamedExpr(name, matrix))
            deflated[name] = matrix
        # An explicit tuple index is equivalent to ``[*indices]``, which is
        # Python 3.11+ syntax (PEP 646); this keeps the module importable on
        # older interpreters.
        deflated[name][tuple(indices)] = expr

    # Freeze the mutable matrices built above so callers receive immutable
    # (hashable) expressions.
    frozen = []
    for name_expr in result:
        name, expr = name_expr
        if isinstance(expr, Matrix):
            name_expr = NamedExpr(name, ImmutableMatrix(expr))
        frozen.append(name_expr)
    return frozen
0112 
0113 
def my_subs(expr, sub_name_exprs):
    """Rewrite ``expr`` in terms of named sub-expressions.

    ``expr`` is expanded. Every occurrence of each sub-expression in
    ``sub_name_exprs`` is then replaced by its name (matrix-valued entries are
    flattened element-wise first via inflate_exprs). The result is passed
    through ``sympy.simplify``.
    """
    flat, _ = inflate_exprs(sub_name_exprs)
    replacements = [(value, symbol) for symbol, value in flat]
    return sym.simplify(expr.expand().subs(replacements))
0121 
0122 
def build_dependency_graph(name_exprs):
    """Map each name to the ``free_symbols`` of its expression.

    If a name occurs more than once, the last occurrence wins.
    """
    return {name: expr.free_symbols for name, expr in name_exprs}
0128 
0129 
def build_influence_graph(name_exprs):
    """Invert the dependency relation: map each symbol to the names reading it.

    Only symbols that occur in at least one expression appear as keys.
    """
    influenced = {}
    for target, expression in name_exprs:
        for symbol in expression.free_symbols:
            if symbol not in influenced:
                influenced[symbol] = set()
            influenced[symbol].add(target)
    return influenced
0136 
0137 
def order_exprs_by_input(name_exprs):
    """Order ``(name, expr)`` pairs by their distance from the inputs.

    Symbols that appear in some expression but are not themselves named are
    inputs, at level 0. Each named expression gets level
    ``1 + max(level of its free symbols)``. An expression with no free symbols
    (a constant) gets level 0. The result is sorted by level, then by number
    of free symbols, then by number of ``args``; ties keep their input order.

    Args:
        name_exprs: sequence of ``(name, expr)`` pairs; each ``expr`` must
            expose ``free_symbols`` and ``args`` (e.g. sympy expressions).

    Returns:
        A new list containing the same pairs, reordered.

    Raises:
        ValueError: if the named expressions depend on each other cyclically,
            so that no order exists.
    """
    all_expr_names = {name for name, _ in name_exprs}
    all_expr_symbols = set().union(*[expr.free_symbols for _, expr in name_exprs])
    inputs = all_expr_symbols - all_expr_names

    order = dict.fromkeys(inputs, 0)
    # Count distinct names: a duplicated name occupies a single slot in
    # ``order``, so comparing against len(name_exprs) would never terminate.
    target = len(inputs) + len(all_expr_names)

    while len(order) < target:
        assigned_before = len(order)
        for name, expr in name_exprs:
            symbols_order = [order.get(s, None) for s in expr.free_symbols]
            if None in symbols_order:
                continue
            # default=-1 places constant expressions at level 0 instead of
            # raising from max() on an empty list.
            order[name] = max(symbols_order, default=-1) + 1
        if len(order) == assigned_before:
            # A full pass resolved nothing new: the remaining names wait on
            # each other, so further passes would loop forever.
            unresolved = sorted(str(n) for n in all_expr_names if n not in order)
            raise ValueError(
                "cyclic dependency between expressions: " + ", ".join(unresolved)
            )

    # One stable sort on a tuple key is equivalent to three successive stable
    # sorts (by args count, then free-symbol count, then level).
    return sorted(
        name_exprs,
        key=lambda n_e: (order[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )
0159 
0160 
def order_exprs_by_output(name_exprs, outputs):
    """Order expressions so that each output is preceded by what it needs.

    For each entry of ``outputs`` (in the given order), this emits the
    not-yet-emitted expressions that the output transitively depends on,
    ordered by order_exprs_by_input, followed by the output's own expression.

    Args:
        name_exprs: sequence of ``(name, expr)`` pairs; each ``expr`` must
            expose ``free_symbols``.
        outputs: names to emit, in order; each must name an entry of
            ``name_exprs``.

    Returns:
        A list of ``(name, expr)`` pairs. An output that was already emitted
        as a dependency of an earlier output is not emitted a second time.

    Raises:
        KeyError: if an output has no matching entry in ``name_exprs``.
    """
    name_expr_by_name = {name_expr[0]: name_expr for name_expr in name_exprs}

    def get_inputs(output):
        # Transitive closure of the symbols ``output`` reads. The walk is
        # iterative with a visited set, so shared sub-expressions are visited
        # once rather than once per path. It also cannot hit the recursion
        # limit on long substitution chains or recurse forever on a cycle.
        reached = set()
        pending = [output]
        while pending:
            entry = name_expr_by_name.get(pending.pop())
            if entry is None:
                continue  # a free input, not a defined expression
            for symbol in entry[1].free_symbols:
                if symbol not in reached:
                    reached.add(symbol)
                    pending.append(symbol)
        return reached

    result = []
    done = set()

    for output in outputs:
        output_name_expr = name_expr_by_name[output]
        if output in done:
            # Already emitted while resolving an earlier output's inputs;
            # emitting it again would duplicate its generated statements.
            continue
        inputs = get_inputs(output) - done
        result.extend(
            order_exprs_by_input(
                [name_expr for name_expr in name_exprs if name_expr[0] in inputs]
            )
        )
        result.append(output_name_expr)
        done.update(inputs)
        done.add(output)

    return result
0187 
0188 
def my_cse(name_exprs, inflate_deflate=True, simplify=True):
    """Run common-subexpression elimination over named expressions.

    Args:
        name_exprs: ``(name, expr)`` pairs.
        inflate_deflate: if true, matrix-valued entries are split into
            per-element scalar entries before ``sympy.cse`` (inflate_exprs)
            and reassembled into ``ImmutableMatrix`` values afterwards
            (deflate_exprs).
        simplify: if true, apply ``sympy.simplify`` to every extracted
            sub-expression and to every reduced expression.

    Returns:
        A list with the replacement pairs produced by ``sympy.cse`` (symbols
        drawn from ``numbered_symbols()``) first, followed by the reduced
        ``(name, expr)`` pairs.
    """
    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [item[0] for item in name_exprs]
    exprs = [item[1] for item in name_exprs]

    replacements, reduced = sym.cse(exprs, symbols=numbered_symbols())

    if simplify:
        replacements = [(sub, sym.simplify(value)) for sub, value in replacements]
        reduced = [sym.simplify(value) for value in reduced]

    named = list(zip(names, reduced))
    if inflate_deflate:
        named = deflate_exprs(named, references)

    return [*replacements, *named]
0213 
0214 
def my_expression_print(
    printer, name_exprs, outputs, run_cse=True, pre_expr_hook=None, post_expr_hook=None
):
    """Emit C++ statements that compute ``outputs`` from ``name_exprs``.

    Args:
        printer: code printer whose ``doprint`` renders an ``Assignment``.
        name_exprs: ``(name, expr)`` pairs.
        outputs: names written through the caller's output pointers/arrays;
            also determines emission order (via order_exprs_by_output).
        run_cse: if true, run my_cse (with matrix inflate/deflate) first.
        pre_expr_hook, post_expr_hook: optional callables taking the variable
            name. A non-``None`` string they return is inserted, split into
            lines, before/after that variable's statements.

    Returns:
        The statements joined by newlines. Intermediate scalars become
        ``const auto name = ...;``. Intermediate matrices are declared as
        ``T name[size];`` followed by the printed assignment. Scalar outputs
        are written as ``*name = ...;``. Matrix outputs use the printed
        assignment as-is.
    """
    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)
    name_exprs = order_exprs_by_output(name_exprs, outputs)

    def run_hook(hook, var, sink):
        if hook is None:
            return
        snippet = hook(var)
        if snippet is not None:
            sink.extend(snippet.split("\n"))

    out = []

    for var, expr in name_exprs:
        run_hook(pre_expr_hook, var, out)

        assignment = printer.doprint(Assignment(var, expr))
        is_matrix = hasattr(expr, "shape")
        if var in outputs:
            if is_matrix:
                out.extend(assignment.split("\n"))
            else:
                out.append("*" + assignment)
        elif is_matrix:
            out.append(f"T {var}[{np.prod(expr.shape)}];")
            out.extend(assignment.split("\n"))
        else:
            out.append("const auto " + assignment)

        run_hook(post_expr_hook, var, out)

    return "\n".join(out)
0249 
0250 
def my_function_print(printer, name, inputs, name_exprs, outputs, run_cse=True):
    """Render a C++ function template computing ``outputs`` from ``inputs``.

    The signature is ``template<typename T> void name(...)``. ``MatrixSymbol``
    inputs are passed as ``const T*``, other inputs as ``const T``, and every
    output as ``T*``. The body is produced by my_expression_print and indented
    by two spaces.

    Returns:
        The complete function source as a single string.
    """

    def input_param(symbol):
        if isinstance(symbol, MatrixSymbol):
            return f"const T* {symbol.name}"
        return f"const T {symbol.name}"

    def output_param(out_name):
        return f"T* {out_name}"

    params = [input_param(symbol) for symbol in inputs] + [
        output_param(output) for output in outputs
    ]
    # Join outside the f-string: reusing the enclosing quote character inside
    # a replacement field is only valid from Python 3.12 (PEP 701).
    param_list = ", ".join(params)

    lines = [f"template<typename T> void {name}({param_list}) {{"]

    body = my_expression_print(printer, name_exprs, outputs, run_cse=run_cse)
    lines.extend(f"  {line}" for line in body.split("\n"))

    lines.append("}")

    return "\n".join(lines)
0263         output_param(output) for output in outputs
0264     ]
0265     head = f"template<typename T> void {name}({", ".join(params)}) {{"
0266 
0267     lines.append(head)
0268 
0269     code = my_expression_print(printer, name_exprs, outputs, run_cse=run_cse)
0270     lines.extend([f"  {l}" for l in code.split("\n")])
0271 
0272     lines.append("}")
0273 
0274     return "\n".join(lines)