Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:23:54

0001 from collections import namedtuple
0002 
0003 import numpy as np
0004 
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.utilities.iterables import numbered_symbols
0008 from sympy.codegen.ast import Assignment
0009 from sympy.printing.cxx import CXX17CodePrinter
0010 
# A (name, expr) pair: ``name`` is the symbol an expression is bound to
# (``name_expr`` creates a ``Symbol`` or ``MatrixSymbol``), ``expr`` the
# scalar or matrix expression assigned to it.
NamedExpr = namedtuple("NamedExpr", "name expr")
0012 
0013 
def make_vector(name, dim, **kwargs):
    """Return a column ``Matrix`` of the symbols ``name[0]`` .. ``name[dim - 1]``.

    One row is created per index. Keyword arguments are forwarded unchanged to
    every ``Symbol`` constructor.
    """
    column = []
    for index in range(dim):
        column.append([Symbol(f"{name}[{index}]", **kwargs)])
    return Matrix(column)
0016 
0017 
def make_matrix(name, rows, cols, **kwargs):
    """Return a ``rows`` x ``cols`` ``Matrix`` whose entries are the symbols ``name[i,j]``.

    Keyword arguments are forwarded unchanged to every ``Symbol`` constructor.
    """

    def element(i, j):
        return Symbol(f"{name}[{i},{j}]", **kwargs)

    grid = [[element(i, j) for j in range(cols)] for i in range(rows)]
    return Matrix(grid)
0025 
0026 
def name_expr(name, expr):
    """Bind ``expr`` to a fresh symbol called ``name`` and return the ``NamedExpr``.

    Anything with a ``shape`` attribute gets a ``MatrixSymbol`` of that shape;
    every other expression gets a plain ``Symbol``.
    """
    is_matrix_like = hasattr(expr, "shape")
    symbol = sym.MatrixSymbol(name, *expr.shape) if is_matrix_like else Symbol(name)
    return NamedExpr(symbol, expr)
0033 
0034 
def find_by_name(name_exprs, name):
    """Return the first pair whose name, converted with ``str``, equals ``name``.

    Returns ``None`` when no pair matches.
    """
    for candidate in name_exprs:
        if str(candidate[0]) == name:
            return candidate
    return None
0039 
0040 
class MyCXXCodePrinter(CXX17CodePrinter):
    """``CXX17CodePrinter`` with the conventions the generated code relies on.

    * Matrices are traversed column by column (the row index varies fastest).
    * Matrix elements are printed as ``getter::element<...>(...)`` calls.
    * A power with exponent -1/2 is printed as ``1.0`` divided by a ``sqrt``
      call (using the printer's namespace prefix and function suffix) instead
      of being delegated to the base printer.
    """

    def _traverse_matrix_indices(self, mat):
        # Column-major order: (0, 0), (1, 0), ..., (rows - 1, 0), (0, 1), ...
        n_rows, n_cols = mat.shape
        return ((row, col) for col in range(n_cols) for row in range(n_rows))

    def _print_MatrixElement(self, expr):
        from sympy.printing.precedence import PRECEDENCE

        parent = expr.parent
        n_rows, n_cols = parent.shape
        var = self.parenthesize(parent, PRECEDENCE["Atom"], strict=True)

        # HACK: only short column vectors (n x 1 with n <= 3) use the
        # single-index accessor. Every other matrix goes through the
        # two-index accessor, including the 8x1 path-to-free matrix that this
        # special case was written around.
        is_short_column = n_cols == 1 and n_rows <= 3
        if is_short_column:
            return f"getter::element<{expr.i}>({var})"
        return f"getter::element<{expr.i}, {expr.j}>({var})"

    def _print_Pow(self, expr):
        from sympy.core.numbers import equal_valued, Float
        from sympy.codegen.ast import real

        func_suffix = self._get_func_suffix(real)
        if not equal_valued(expr.exp, -0.5):
            # Every other exponent is handled by the stock printer.
            return super()._print_Pow(expr)

        numerator = self._print_Float(Float(1.0))
        denominator = f"{self._ns}sqrt{func_suffix}({self._print(expr.base)})"
        return f"{numerator}/{denominator}"
0075 
0076 
class MyCXXCodePrinterWithoutKnownAssignment(MyCXXCodePrinter):
    """``MyCXXCodePrinter`` variant that drops assignments of the constants 0 and 1.

    Such an assignment is printed as the empty string. For matrix assignments
    this applies element by element, so only the remaining statements appear.
    NOTE(review): presumably the assignment target already holds these values;
    confirm with the callers.
    """

    def _print_Assignment(self, expr):
        value = expr.rhs
        is_known_constant = value == 0 or value == 1
        if is_known_constant:
            return ""
        return super()._print_Assignment(expr)
0082 
0083 
# Shared printer instances. ``cxx_printer_wo_known`` omits assignments whose
# right-hand side is the constant 0 or 1; ``cxx_printer`` prints everything.
cxx_printer_wo_known = MyCXXCodePrinterWithoutKnownAssignment()
cxx_printer = MyCXXCodePrinter()
0086 
0087 
def inflate_expr(name_expr):
    """Split one ``(name, expr)`` pair into scalar pairs plus rebuild references.

    If ``expr`` has a ``shape``, every element ``expr[idx]`` is paired with
    ``name[idx]``. Indices are visited in the order produced by ``np.ndindex``,
    i.e. row-major. Each element gets the reference ``(name, shape, idx)``.
    Any other ``expr`` is returned as the single pair ``(name, expr)`` with
    reference ``None``.

    Returns:
        ``(pairs, references)``: two lists of equal length.
    """
    name, expr = name_expr

    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    shape = expr.shape
    pairs = [(name[idx], expr[idx]) for idx in np.ndindex(shape)]
    references = [(name, shape, idx) for idx in np.ndindex(shape)]
    return pairs, references
0103 
0104 
def inflate_exprs(name_exprs):
    """Apply ``inflate_expr`` to every pair and concatenate the outputs.

    Returns:
        ``(pairs, references)``: the scalar pairs of all inputs in order, and
        one reference per scalar pair.
    """
    inflated = [inflate_expr(item) for item in name_exprs]
    pairs = [pair for item_pairs, _ in inflated for pair in item_pairs]
    references = [ref for _, item_refs in inflated for ref in item_refs]
    return pairs, references
0113 
0114 
def deflate_exprs(name_exprs, references):
    """Reassemble element-wise ``(name, expr)`` pairs into matrix expressions.

    This undoes ``inflate_exprs``. ``references`` runs parallel to
    ``name_exprs``; extra entries in either are ignored, as with ``zip``.

    Args:
        name_exprs: ``(name, expr)`` pairs.
        references: ``None`` for a pair that is passed through unchanged, or
            ``(matrix_name, shape, indices)`` for a pair holding element
            ``indices`` of the matrix ``matrix_name``.

    Returns:
        A list in which pass-through pairs keep their position, and all
        elements of one matrix are merged into a single
        ``NamedExpr(matrix_name, ImmutableMatrix)`` placed where its first
        element appeared. Any pair whose expression is a sympy ``Matrix`` is
        converted to an ``ImmutableMatrix``.
    """
    assembled = []
    matrices = {}

    for name_expr, reference in zip(name_exprs, references):
        if reference is None:
            assembled.append(name_expr)
            continue

        _, element = name_expr
        matrix_name, shape, indices = reference
        if matrix_name not in matrices:
            matrix = Matrix(np.zeros(shape))
            assembled.append(NamedExpr(matrix_name, matrix))
            matrices[matrix_name] = matrix
        # ``indices`` is already a tuple, so subscripting with it directly is
        # equivalent to ``[*indices]``. Unlike that form, it is valid syntax
        # before Python 3.11.
        matrices[matrix_name][indices] = element

    result = []
    for name_expr in assembled:
        name, expr = name_expr
        if isinstance(expr, Matrix):
            result.append(NamedExpr(name, ImmutableMatrix(expr)))
        else:
            result.append(name_expr)
    return result
0140 
0141 
def my_subs(expr, sub_name_exprs):
    """Rewrite ``expr`` in terms of the names bound in ``sub_name_exprs``.

    ``sub_name_exprs`` is first inflated to scalar pairs, which splits matrix
    expressions element-wise. ``expr`` is then expanded, occurrences of each
    sub-expression are replaced by its name via ``subs``, and the result is
    passed through ``sym.simplify``.
    """
    scalar_pairs = inflate_exprs(sub_name_exprs)[0]
    replacements = [(value, symbol) for symbol, value in scalar_pairs]
    expanded = expr.expand()
    return sym.simplify(expanded.subs(replacements))
0149 
0150 
def build_dependency_graph(name_exprs):
    """Map each name to the ``free_symbols`` of its expression.

    If a name appears more than once, the last pair wins.
    """
    return {name: expr.free_symbols for name, expr in name_exprs}
0156 
0157 
def build_influence_graph(name_exprs):
    """Map every free symbol to the set of names whose expressions contain it.

    This inverts the name -> free-symbols relation. Symbols that appear in no
    expression are absent from the result.
    """
    edges = [
        (symbol, name)
        for name, expr in name_exprs
        for symbol in expr.free_symbols
    ]
    influenced = {}
    for symbol, name in edges:
        if symbol not in influenced:
            influenced[symbol] = set()
        influenced[symbol].add(name)
    return influenced
0164 
0165 
def order_exprs_by_input(name_exprs):
    """Sort ``(name, expr)`` pairs so that every name comes after its dependencies.

    Each pair is assigned a level:

    * symbols not defined by any pair (the inputs) have level 0;
    * an expression without free symbols has level 0;
    * any other expression has one more than the highest level among its
      free symbols.

    The pairs are returned sorted by level, then by number of free symbols,
    then by number of ``args``. Ties keep their input order.

    Raises:
        ValueError: if some levels can never be resolved, i.e. the pairs
            contain a dependency cycle. Previously this looped forever.
    """
    defined_names = {name for name, _ in name_exprs}
    all_symbols = set().union(*(expr.free_symbols for _, expr in name_exprs))
    level = dict.fromkeys(all_symbols - defined_names, 0)

    pending = list(name_exprs)
    while pending:
        unresolved = []
        for name, expr in pending:
            dep_levels = [level.get(s) for s in expr.free_symbols]
            if None in dep_levels:
                unresolved.append((name, expr))
            else:
                level[name] = max(dep_levels, default=-1) + 1
        if len(unresolved) == len(pending):
            # A full pass made no progress: the remaining names depend on
            # each other (or on themselves), so they can never be ordered.
            stuck = ", ".join(str(name) for name, _ in pending)
            raise ValueError(
                f"cannot order expressions: unresolved dependencies (cycle?) for {stuck}"
            )
        pending = unresolved

    # One stable sort on a composite key is equivalent to three successive
    # stable sorts by #args, then #free symbols, then level.
    return sorted(
        name_exprs,
        key=lambda pair: (level[pair[0]], len(pair[1].free_symbols), len(pair[1].args)),
    )
0190 
0191 
def order_exprs_by_output(name_exprs, outputs):
    """Select and order the ``(name, expr)`` pairs needed to compute ``outputs``.

    Outputs are processed in the given order. For each output, every pair it
    transitively depends on that is not yet covered is added first, ordered
    with ``order_exprs_by_input``. The output's own pair follows. An output
    that was already emitted (as a dependency of an earlier output, or because
    it is listed twice) is not emitted again.

    Raises:
        KeyError: if an output to be emitted has no pair in ``name_exprs``.
    """
    name_expr_by_name = {name_expr[0]: name_expr for name_expr in name_exprs}

    def get_inputs(output):
        # All symbols ``output`` transitively depends on. This is an iterative
        # DFS with a visited set: linear in graph size even when many
        # expressions share sub-expressions, and not limited by recursion depth.
        name_expr = name_expr_by_name.get(output)
        if name_expr is None:
            return set()
        inputs = set()
        stack = list(name_expr[1].free_symbols)
        while stack:
            symbol = stack.pop()
            if symbol in inputs:
                continue
            inputs.add(symbol)
            dependency = name_expr_by_name.get(symbol)
            if dependency is not None:
                stack.extend(dependency[1].free_symbols)
        return inputs

    result = []
    done = set()
    emitted = set()

    for output in outputs:
        if output in emitted:
            # Already written out; emitting it again would duplicate the assignment.
            continue
        inputs = get_inputs(output) - done
        dependencies = order_exprs_by_input(
            [pair for pair in name_exprs if pair[0] in inputs]
        )
        result.extend(dependencies)
        result.append(name_expr_by_name[output])
        emitted.update(pair[0] for pair in dependencies)
        emitted.add(output)
        done.update(inputs)
        done.add(output)

    return result
0218 
0219 
def my_cse(name_exprs, inflate_deflate=True, simplify=True):
    """Run sympy common-subexpression elimination over named expressions.

    Args:
        name_exprs: ``(name, expr)`` pairs.
        inflate_deflate: if true, matrix expressions are split into scalar
            elements before CSE (``inflate_exprs``) and reassembled afterwards
            (``deflate_exprs``).
        simplify: if true, ``sym.simplify`` is applied to every extracted
            sub-expression and to every reduced expression.

    Returns:
        The extracted ``(symbol, subexpr)`` pairs, with symbols drawn from
        ``numbered_symbols()``, followed by the reduced ``(name, expr)`` pairs.
    """
    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [pair[0] for pair in name_exprs]
    exprs = [pair[1] for pair in name_exprs]

    replacements, reduced = sym.cse(exprs, symbols=numbered_symbols())

    if simplify:
        replacements = [(symbol, sym.simplify(sub)) for symbol, sub in replacements]
        reduced = [sym.simplify(e) for e in reduced]

    named_reduced = list(zip(names, reduced))
    if inflate_deflate:
        named_reduced = deflate_exprs(named_reduced, references)

    return [*replacements, *named_reduced]
0244 
0245 
def my_expression_print(
    printer, name_exprs, outputs, run_cse=True, pre_expr_hook=None, post_expr_hook=None
):
    """Print named expressions as C++ statements for the requested outputs.

    Args:
        printer: code printer used for each ``Assignment(var, expr)``.
        name_exprs: ``(name, expr)`` pairs to print.
        outputs: result names; they are emitted in this order, each preceded
            by the pairs it depends on (see ``order_exprs_by_output``).
        run_cse: if true, run ``my_cse`` with matrix inflation first.
        pre_expr_hook, post_expr_hook: optional callables taking the assigned
            name. Any non-``None`` string they return is inserted before or
            after that assignment.

    Returns:
        The generated code as one newline-joined string:

        * intermediate scalars become ``const auto <assignment>``;
        * intermediate matrices get a ``T name[size];`` declaration followed by
          element assignments (``T`` is presumably a scalar type supplied by
          the surrounding code);
        * scalar outputs are prefixed with ``*`` (presumably pointer
          out-parameters; confirm with the calling template);
        * matrix outputs are assigned element-wise.

        Assignments the printer suppresses by printing them as the empty
        string are omitted.
    """
    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)
    name_exprs = order_exprs_by_output(name_exprs, outputs)

    def run_hook(hook, var):
        # Hooks may return None (nothing to insert) or a possibly multi-line string.
        if hook is None:
            return []
        hook_code = hook(var)
        return [] if hook_code is None else hook_code.split("\n")

    lines = []

    for var, expr in name_exprs:
        lines.extend(run_hook(pre_expr_hook, var))

        code = printer.doprint(Assignment(var, expr))
        # A printer may suppress an assignment by printing it as "" (for
        # example MyCXXCodePrinterWithoutKnownAssignment for 0/1). Skip those
        # instead of emitting blank lines or a dangling "const auto " / "*".
        code_lines = [line for line in code.split("\n") if line.strip()]
        is_output = var in outputs

        if hasattr(expr, "shape"):
            if not is_output:
                lines.append(f"T {var}[{np.prod(expr.shape)}];")
            lines.extend(code_lines)
        elif code_lines:
            prefix = "*" if is_output else "const auto "
            lines.append(prefix + code)

        lines.extend(run_hook(post_expr_hook, var))

    return "\n".join(lines)