Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:22:13

0001 from collections import namedtuple
0002 
0003 import numpy as np
0004 
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.utilities.iterables import numbered_symbols
0008 from sympy.codegen.ast import Assignment
0009 from sympy.printing.cxx import CXX17CodePrinter
0010 
0011 
# A (target, expression) pair: ``name`` is the symbol that ``expr`` is assigned
# to in the generated code.
NamedExpr = namedtuple("NamedExpr", "name expr")
0013 
0014 
def make_vector(name, dim, **kwargs):
    """Return a ``dim`` x 1 column Matrix of symbols ``name[0]`` ... ``name[dim-1]``.

    Any extra keyword arguments are forwarded unchanged to every ``Symbol``.
    """
    rows = []
    for index in range(dim):
        rows.append([Symbol(f"{name}[{index}]", **kwargs)])
    return Matrix(rows)
0017 
0018 
def make_matrix(name, rows, cols, **kwargs):
    """Return a ``rows`` x ``cols`` Matrix whose (i, j) entry is the symbol ``name[i,j]``.

    Any extra keyword arguments are forwarded unchanged to every ``Symbol``.
    """

    def entry(row, col):
        return Symbol(f"{name}[{row},{col}]", **kwargs)

    return Matrix([[entry(row, col) for col in range(cols)] for row in range(rows)])
0026 
0027 
def name_expr(name, expr):
    """Pair ``expr`` with a new symbol called ``name`` and return it as a NamedExpr.

    If ``expr`` has a ``shape`` attribute (e.g. a Matrix), the symbol is a
    ``MatrixSymbol`` of that same shape; otherwise it is a scalar ``Symbol``.
    """
    if not hasattr(expr, "shape"):
        return NamedExpr(Symbol(name), expr)
    return NamedExpr(MatrixSymbol(name, *expr.shape), expr)
0034 
0035 
def find_by_name(name_exprs, name):
    """Return the first pair in ``name_exprs`` whose name prints as ``name``.

    Names are compared via ``str()`` of each pair's first element, so a plain
    string can be used to look up a symbol. Returns ``None`` when nothing matches.
    """
    for candidate in name_exprs:
        if str(candidate[0]) == name:
            return candidate
    return None
0040 
0041 
class MyCXXCodePrinter(CXX17CodePrinter):
    """C++17 printer that treats matrices as flat, column-major arrays.

    Compared with ``CXX17CodePrinter`` it:

    * traverses matrix elements column by column when printing matrix
      assignments;
    * prints a matrix element ``M[i, j]`` as ``M[i + j*rows]``;
    * prints ``x**-0.5`` as the float literal 1.0 divided by ``sqrt(x)``
      (prefixed with ``self._ns`` and the real-type function suffix); all
      other powers are delegated to the base class.
    """

    def _traverse_matrix_indices(self, mat):
        """Yield ``(row, col)`` index pairs column by column (column-major order)."""
        n_rows, n_cols = mat.shape
        for col in range(n_cols):
            for row in range(n_rows):
                yield row, col

    def _print_MatrixElement(self, expr):
        """Print ``parent[i, j]`` as ``parent[i + j * rows]`` (column-major flat index)."""
        from sympy.printing.precedence import PRECEDENCE

        parent = self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True)
        flat_index = expr.i + expr.j * expr.parent.shape[0]
        return f"{parent}[{flat_index}]"

    def _print_Pow(self, expr):
        """Print an exponent of exactly -0.5 as ``1.0/sqrt(base)``; defer otherwise."""
        from sympy.codegen.ast import real
        from sympy.core.numbers import Float, equal_valued

        suffix = self._get_func_suffix(real)
        if not equal_valued(expr.exp, -0.5):
            return super()._print_Pow(expr)

        one = self._print_Float(Float(1.0))
        namespace = self._ns
        base = self._print(expr.base)
        return f"{one}/{namespace}sqrt{suffix}({base})"
0068 
0069 
# Default-configured instance of the column-major C++ printer defined above.
cxx_printer = MyCXXCodePrinter(settings=None)
0071 
0072 
def inflate_expr(name_expr):
    """Split one (name, expr) pair into element-level pairs.

    For an ``expr`` with a ``shape``, each index produced by
    ``np.ndindex(expr.shape)`` yields a pair ``(name[idx], expr[idx])``. It also
    yields a reference ``(name, expr.shape, idx)``, which ``deflate_exprs``
    uses to rebuild the matrix. A scalar ``expr`` is returned unchanged, with
    a ``None`` reference.

    Returns:
        ``(pairs, references)``, two lists of equal length.
    """
    name, expr = name_expr
    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    shape = expr.shape
    all_indices = list(np.ndindex(shape))
    pairs = [(name[idx], expr[idx]) for idx in all_indices]
    references = [(name, shape, idx) for idx in all_indices]
    return pairs, references
0088 
0089 
def inflate_exprs(name_exprs):
    """Apply ``inflate_expr`` to every pair and concatenate the results.

    Returns:
        ``(pairs, references)`` where ``references[k]`` describes ``pairs[k]``.
    """
    pairs, references = [], []
    for entry in name_exprs:
        entry_pairs, entry_refs = inflate_expr(entry)
        pairs += entry_pairs
        references += entry_refs
    return pairs, references
0098 
0099 
def deflate_exprs(name_exprs, references):
    """Regroup element-level pairs (as produced by ``inflate_exprs``) into matrices.

    ``name_exprs`` and ``references`` are walked in lockstep, as with ``zip``:

    * A pair whose reference is ``None`` is kept as-is.
    * A pair with reference ``(name, shape, indices)`` has its expression
      written at ``indices`` into a matrix for ``name``. That matrix is
      created, filled with zeros, when ``name`` is first seen. It appears in
      the output at that first position, as ``NamedExpr(name, matrix)``.

    Every mutable ``Matrix`` value in the output is finally converted to
    ``ImmutableMatrix``.
    """
    collected = []
    matrices = {}

    for entry, reference in zip(name_exprs, references):
        if reference is None:
            collected.append(entry)
            continue

        _, element = entry
        target, shape, indices = reference
        matrix = matrices.get(target)
        if matrix is None:
            matrix = Matrix(np.zeros(shape))
            collected.append(NamedExpr(target, matrix))
            matrices[target] = matrix
        matrix[tuple(indices)] = element

    frozen = []
    for item in collected:
        target, value = item
        if isinstance(value, Matrix):
            item = NamedExpr(target, ImmutableMatrix(value))
        frozen.append(item)
    return frozen
0125 
0126 
def my_subs(expr, sub_name_exprs):
    """Rewrite ``expr`` in terms of already-named subexpressions.

    ``sub_name_exprs`` is first inflated to element level. ``expr`` is then
    expanded, every occurrence of a sub-expression is replaced by its name,
    and the result is simplified.
    """
    element_pairs, _ = inflate_exprs(sub_name_exprs)

    expanded = expr.expand()
    replaced = expanded.subs([(value, target) for target, value in element_pairs])
    return sym.simplify(replaced)
0134 
0135 
def build_dependency_graph(name_exprs):
    """Map each name to the set of free symbols its expression depends on.

    If a name occurs more than once, its last expression wins.
    """
    return {target: expr.free_symbols for target, expr in name_exprs}
0141 
0142 
def build_influence_graph(name_exprs):
    """Invert the dependency relation: map each free symbol to the names using it."""
    influence = {}
    for target, expr in name_exprs:
        for symbol in expr.free_symbols:
            if symbol not in influence:
                influence[symbol] = set()
            influence[symbol].add(target)
    return influence
0149 
0150 
def order_exprs_by_input(name_exprs):
    """Sort (name, expr) pairs so every expression comes after the names it uses.

    Levels are assigned as follows:

    * Every free symbol that is not itself one of the names is an input, at
      level 0.
    * A name is at level 0 if its expression has no free symbols.
    * Otherwise, a name is one level above the highest level among its free
      symbols.

    The result is sorted by (level, number of free symbols, number of
    ``args``). Ties keep their original relative order.

    Raises:
        ValueError: if a name is listed more than once, or if the names form a
            dependency cycle. Either case would otherwise prevent the levels
            from ever being resolved.
    """
    names = [name for name, _ in name_exprs]
    all_expr_names = set(names)
    if len(all_expr_names) != len(names):
        raise ValueError("order_exprs_by_input: duplicate names in name_exprs")

    all_expr_symbols = set().union(*[expr.free_symbols for _, expr in name_exprs])
    inputs = all_expr_symbols - all_expr_names

    order = dict.fromkeys(inputs, 0)

    # Resolve levels pass by pass. Once assigned, a level never changes (it is
    # computed from levels that were already final), so resolved names are
    # dropped from later passes. A pass that resolves nothing means the
    # remaining names depend on each other.
    pending = [(name, expr) for name, expr in name_exprs]
    while pending:
        unresolved = []
        for name, expr in pending:
            levels = [order.get(s) for s in expr.free_symbols]
            if None in levels:
                unresolved.append((name, expr))
            else:
                order[name] = max(levels, default=-1) + 1
        if len(unresolved) == len(pending):
            stuck = sorted(str(name) for name, _ in unresolved)
            raise ValueError(
                f"order_exprs_by_input: cyclic dependencies among {stuck}"
            )
        pending = unresolved

    # One stable sort on a tuple key gives the same result as successive
    # stable sorts by #args, then #free symbols, then level.
    return sorted(
        name_exprs,
        key=lambda n_e: (order[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )
0175 
0176 
def order_exprs_by_output(name_exprs, outputs):
    """Return the (name, expr) pairs needed for ``outputs``, in a computable order.

    Outputs are processed in the given order. For each output:

    * The definitions it transitively depends on that have not been emitted
      yet are ordered with ``order_exprs_by_input`` and emitted.
    * The output's own pair is emitted next.

    An output that was already emitted, as an earlier output or as a
    dependency of one, is not emitted again.

    Raises:
        KeyError: if an output has no entry in ``name_exprs``.
    """
    name_expr_by_name = {entry[0]: entry for entry in name_exprs}

    def transitive_inputs(output):
        # Iterative walk with a visited set. Each definition is expanded once,
        # so shared subexpressions (common after CSE) are not re-traversed, and
        # long dependency chains cannot hit the recursion limit.
        found = set()
        stack = [output]
        while stack:
            entry = name_expr_by_name.get(stack.pop())
            if entry is None:
                continue  # a raw input symbol, nothing further to expand
            for symbol in entry[1].free_symbols:
                if symbol not in found:
                    found.add(symbol)
                    stack.append(symbol)
        return found

    result = []
    done = set()

    for output in outputs:
        output_entry = name_expr_by_name[output]  # unknown outputs raise KeyError
        if output in done:
            # A defined name in ``done`` has already been appended to ``result``.
            continue

        inputs = transitive_inputs(output) - done
        result.extend(
            order_exprs_by_input(
                [entry for entry in name_exprs if entry[0] in inputs]
            )
        )
        result.append(output_entry)
        done.update(inputs)
        done.add(output)

    return result
0203 
0204 
def my_cse(name_exprs, inflate_deflate=True, simplify=True):
    """Run sympy common-subexpression elimination over named expressions.

    Args:
        name_exprs: (name, expr) pairs.
        inflate_deflate: if true, matrices are split into element-level pairs
            before CSE. Afterwards they are regrouped with ``deflate_exprs``.
        simplify: if true, every extracted subexpression and every reduced
            expression is passed through ``sym.simplify``.

    Returns:
        A list of the extracted ``(symbol, subexpr)`` pairs, whose symbols are
        drawn from ``numbered_symbols()``, followed by the reduced
        (name, expr) pairs.
    """
    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [entry[0] for entry in name_exprs]
    exprs = [entry[1] for entry in name_exprs]

    replacements, reduced = sym.cse(exprs, symbols=numbered_symbols())
    if simplify:
        replacements = [(symbol, sym.simplify(sub)) for symbol, sub in replacements]
        reduced = [sym.simplify(e) for e in reduced]

    named_reduced = list(zip(names, reduced))
    if inflate_deflate:
        named_reduced = deflate_exprs(named_reduced, references)

    return [*replacements, *named_reduced]
0229 
0230 
def my_expression_print(
    printer, name_exprs, outputs, run_cse=True, pre_expr_hook=None, post_expr_hook=None
):
    """Render (name, expr) pairs as C++ statements that compute ``outputs``.

    Args:
        printer: code printer whose ``doprint`` renders an ``Assignment``.
        name_exprs: (name, expr) pairs defining the computation.
        outputs: names whose values are written out (not declared locally).
        run_cse: if true, ``my_cse`` (with inflate/deflate) is applied first.
        pre_expr_hook, post_expr_hook: optional callables, each called with
            the assigned name before/after its statement. A non-``None``
            string they return is inserted at that point, split on newlines.

    Statements follow the order produced by ``order_exprs_by_output``:

    * non-output matrix: a ``T name[size];`` declaration, then its element
      assignments;
    * non-output scalar: the assignment prefixed with ``const auto ``;
    * output matrix: the element assignments only;
    * output scalar: the assignment prefixed with ``*`` (written through a
      pointer).

    Returns:
        The generated lines joined with newlines.
    """
    lines = []

    def run_hook(hook, var):
        if hook is None:
            return
        extra = hook(var)
        if extra is not None:
            lines.extend(extra.split("\n"))

    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)

    for var, expr in order_exprs_by_output(name_exprs, outputs):
        run_hook(pre_expr_hook, var)

        code = printer.doprint(Assignment(var, expr))
        is_output = var in outputs
        if hasattr(expr, "shape"):
            if not is_output:
                lines.append(f"T {var}[{np.prod(expr.shape)}];")
            lines.extend(code.split("\n"))
        elif is_output:
            lines.append("*" + code)
        else:
            lines.append("const auto " + code)

        run_hook(post_expr_hook, var)

    return "\n".join(lines)
0265 
0266 
def my_function_print(printer, name, inputs, name_exprs, outputs, run_cse=True):
    """Emit a C++ function template that computes ``outputs`` from ``inputs``.

    The signature is ``template<typename T> void <name>(...)``, with these
    parameters:

    * each ``MatrixSymbol`` input becomes ``const T* <input>``;
    * each other input becomes ``const T <input>``;
    * each output becomes ``T* <output>``.

    The body is produced by ``my_expression_print`` and indented by two
    spaces.
    """

    def input_param(symbol):
        if isinstance(symbol, MatrixSymbol):
            return f"const T* {symbol.name}"
        return f"const T {symbol.name}"

    def output_param(output):
        return f"T* {output}"

    params = [input_param(symbol) for symbol in inputs] + [
        output_param(output) for output in outputs
    ]
    # Use a different quote type inside the f-string. Reusing the enclosing
    # quote is a SyntaxError before Python 3.12 (PEP 701).
    head = f"template<typename T> void {name}({', '.join(params)}) {{"

    code = my_expression_print(printer, name_exprs, outputs, run_cse=run_cse)

    lines = [head]
    lines.extend(f"  {line}" for line in code.split("\n"))
    lines.append("}")
    return "\n".join(lines)