File indexing completed on 2026-05-27 07:23:54
0001 from collections import namedtuple
0002
0003 import numpy as np
0004
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.utilities.iterables import numbered_symbols
0008 from sympy.codegen.ast import Assignment
0009 from sympy.printing.cxx import CXX17CodePrinter
0010
# Pair of (name, expr). ``name_expr`` builds these with a sympy ``Symbol`` or
# ``MatrixSymbol`` as the name and the bound expression as ``expr``.
NamedExpr = namedtuple("NamedExpr", ("name", "expr"))
0012
0013
def make_vector(name, dim, **kwargs):
    """Return a ``dim`` x 1 sympy Matrix of symbols ``name[0]`` .. ``name[dim-1]``.

    Any keyword arguments are forwarded unchanged to every ``Symbol``
    constructor (sympy uses them as assumptions, e.g. ``real=True``).
    """
    column = []
    for index in range(dim):
        column.append([Symbol(f"{name}[{index}]", **kwargs)])
    return Matrix(column)
0016
0017
def make_matrix(name, rows, cols, **kwargs):
    """Return a ``rows`` x ``cols`` sympy Matrix of symbols ``name[i,j]``.

    Any keyword arguments are forwarded unchanged to every ``Symbol``
    constructor.
    """

    def entry(row, col):
        return Symbol(f"{name}[{row},{col}]", **kwargs)

    grid = [[entry(row, col) for col in range(cols)] for row in range(rows)]
    return Matrix(grid)
0025
0026
def name_expr(name, expr):
    """Bind ``expr`` to a new symbol called ``name``.

    If ``expr`` has a ``shape`` attribute (e.g. a sympy Matrix), the symbol is
    a ``MatrixSymbol`` of that shape; otherwise it is a scalar ``Symbol``.

    Returns:
        NamedExpr(symbol, expr)
    """
    is_shaped = hasattr(expr, "shape")
    symbol = MatrixSymbol(name, *expr.shape) if is_shaped else Symbol(name)
    return NamedExpr(symbol, expr)
0033
0034
def find_by_name(name_exprs, name):
    """Return the first pair whose first element stringifies to ``name``.

    Returns ``None`` when no pair matches. Iteration stops at the first match.
    """
    for candidate in name_exprs:
        if str(candidate[0]) == name:
            return candidate
    return None
0039
0040
class MyCXXCodePrinter(CXX17CodePrinter):
    """``CXX17CodePrinter`` with project-specific matrix and power printing.

    * Matrix index traversal is column by column (the row index varies
      fastest).
    * Matrix elements print as ``getter::element<i>(M)`` when ``M`` is a
      column vector with at most three rows, and as
      ``getter::element<i, j>(M)`` otherwise.
    * ``x**-0.5`` prints as a reciprocal square root,
      ``<1.0>/<ns>sqrt<suffix>(x)``, using the printer's float literal,
      namespace and real-type function suffix. All other powers fall back to
      the base-class printer.
    """

    def _traverse_matrix_indices(self, mat):
        n_rows, n_cols = mat.shape
        # Column-major walk: the row index varies fastest.
        return ((row, col) for col in range(n_cols) for row in range(n_rows))

    def _print_MatrixElement(self, expr):
        from sympy.printing.precedence import PRECEDENCE

        n_rows, n_cols = expr.parent.shape
        target = self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True)
        # Short column vectors are addressed with a single index.
        if n_cols == 1 and n_rows <= 3:
            return f"getter::element<{expr.i}>({target})"
        return f"getter::element<{expr.i}, {expr.j}>({target})"

    def _print_Pow(self, expr):
        from sympy.core.numbers import equal_valued, Float
        from sympy.codegen.ast import real

        if not equal_valued(expr.exp, -0.5):
            return super()._print_Pow(expr)

        one = self._print_Float(Float(1.0))
        suffix = self._get_func_suffix(real)
        base = self._print(expr.base)
        return f"{one}/{self._ns}sqrt{suffix}({base})"
0075
0076
class MyCXXCodePrinterWithoutKnownAssignment(MyCXXCodePrinter):
    """Variant of ``MyCXXCodePrinter`` that prints nothing for assignments
    whose right-hand side equals 0 or 1.

    NOTE(review): presumably the target storage already holds these values
    (e.g. pre-initialised output) — confirm with callers.
    """

    def _print_Assignment(self, expr):
        rhs = expr.rhs
        is_known = rhs == 0 or rhs == 1
        return "" if is_known else super()._print_Assignment(expr)
0082
0083
# Shared printer instances: one prints every assignment, the other skips
# assignments whose right-hand side is 0 or 1.
cxx_printer, cxx_printer_wo_known = (
    MyCXXCodePrinter(),
    MyCXXCodePrinterWithoutKnownAssignment(),
)
0086
0087
def inflate_expr(name_expr):
    """Split one ``(name, expr)`` pair into per-element pairs.

    For a shaped ``expr``, each element ``expr[idx]`` is paired with
    ``name[idx]``. Indices come from ``np.ndindex``, i.e. C / row-major order.
    A reference ``(name, shape, idx)`` is recorded so ``deflate_exprs`` can
    put the element back. A scalar pair is passed through with a ``None``
    reference.

    Returns:
        (pairs, references): two lists of equal length.
    """
    name, expr = name_expr

    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    shape = expr.shape
    pairs = [(name[idx], expr[idx]) for idx in np.ndindex(shape)]
    references = [(name, shape, idx) for idx in np.ndindex(shape)]
    return pairs, references
0103
0104
def inflate_exprs(name_exprs):
    """Apply ``inflate_expr`` to every pair and concatenate the results.

    Returns:
        (pairs, references): flat lists aligned index by index, in the
        format ``deflate_exprs`` consumes.
    """
    inflated = [inflate_expr(item) for item in name_exprs]
    pairs = [pair for chunk, _ in inflated for pair in chunk]
    references = [ref for _, refs in inflated for ref in refs]
    return pairs, references
0113
0114
def deflate_exprs(name_exprs, references):
    """Regroup element-wise pairs (as produced by ``inflate_exprs``) into matrices.

    Args:
        name_exprs: flat ``(name, expr)`` pairs, aligned with ``references``.
        references: one entry per pair. ``None`` marks a scalar pair, which is
            kept as is. ``(name, shape, indices)`` locates the pair's ``expr``
            inside the matrix called ``name``.

    Returns:
        Pairs in order of first appearance. Each regrouped matrix becomes
        ``NamedExpr(name, ImmutableMatrix)``. Other pairs are returned
        unchanged.
    """
    result = []
    deflated = {}  # matrix name -> mutable Matrix being filled in place

    for name_expr, reference in zip(name_exprs, references):
        if reference is None:
            result.append(name_expr)
            continue
        _, expr = name_expr
        name, shape, indices = reference
        if name not in deflated:
            matrix = Matrix(np.zeros(shape))
            # The same object is filled below and frozen at the end.
            result.append(NamedExpr(name, matrix))
            deflated[name] = matrix
        # Equivalent to ``m[*indices]``, which is Python 3.11+ syntax only
        # (PEP 646) and makes the whole module fail to import on older
        # interpreters.
        deflated[name][tuple(indices)] = expr

    frozen = []
    for name_expr in result:
        name, expr = name_expr
        if isinstance(expr, Matrix):
            frozen.append(NamedExpr(name, ImmutableMatrix(expr)))
        else:
            frozen.append(name_expr)
    return frozen
0140
0141
def my_subs(expr, sub_name_exprs):
    """Rewrite ``expr`` in terms of already-named sub-expressions.

    ``sub_name_exprs`` is first flattened with ``inflate_exprs``, so matrix
    entries are substituted element by element. ``expr`` is expanded, each
    sub-expression (as matched by sympy's ``subs``) is replaced by its name,
    and the result is simplified.
    """
    flat_pairs, _references = inflate_exprs(sub_name_exprs)
    expanded = expr.expand()
    replacements = [(sub_expr, sub_name) for sub_name, sub_expr in flat_pairs]
    return sym.simplify(expanded.subs(replacements))
0149
0150
def build_dependency_graph(name_exprs):
    """Map each name to the ``free_symbols`` of its expression.

    If a name occurs more than once, the last pair wins.
    """
    return {name: expr.free_symbols for name, expr in name_exprs}
0156
0157
def build_influence_graph(name_exprs):
    """Map each free symbol to the set of names whose expressions use it.

    This is the inverse of ``build_dependency_graph``. Symbols used by no
    expression do not appear as keys.
    """
    influence = {}
    for target, expression in name_exprs:
        for symbol in expression.free_symbols:
            if symbol not in influence:
                influence[symbol] = set()
            influence[symbol].add(target)
    return influence
0164
0165
def order_exprs_by_input(name_exprs):
    """Order ``(name, expr)`` pairs so that dependencies come first.

    Every free symbol that is not itself one of the names is an input with
    depth 0. A named expression gets depth ``1 + max(depth of its free
    symbols)``, or 0 if it has none. Pairs are sorted stably by depth, then
    by number of free symbols, then by number of top-level ``args``.

    Raises:
        ValueError: if the names depend on each other cyclically (including
            self-reference), so that no ordering exists.
    """
    all_expr_names = {name for name, _ in name_exprs}
    all_expr_symbols = set().union(*[expr.free_symbols for _, expr in name_exprs])
    inputs = all_expr_symbols - all_expr_names

    order = dict.fromkeys(inputs, 0)
    # Count *distinct* names: a duplicated name shares one ``order`` entry,
    # so counting ``len(name_exprs)`` could never be reached.
    target = len(inputs) + len(all_expr_names)

    while len(order) < target:
        known_before = len(order)
        for name, expr in name_exprs:
            depths = [order.get(s) for s in expr.free_symbols]
            if None in depths:
                continue  # some dependency is not resolved yet
            order[name] = max(depths, default=-1) + 1
        if len(order) == known_before:
            # No new name could be resolved; further passes would spin forever.
            unresolved = sorted(str(n) for n in all_expr_names if n not in order)
            raise ValueError(
                "cyclic dependency between expressions: " + ", ".join(unresolved)
            )

    # A single stable sort on a composite key is equivalent to successive
    # stable sorts by the least significant key first.
    return sorted(
        name_exprs,
        key=lambda n_e: (order[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )
0190
0191
def order_exprs_by_output(name_exprs, outputs):
    """Select and order the pairs needed to compute ``outputs``.

    Outputs are processed in the given order. For each one, the pairs it
    transitively depends on that have not been emitted yet come first,
    ordered by ``order_exprs_by_input``, followed by the output's own pair.
    Each pair is emitted at most once: an output that was already emitted as
    a dependency of an earlier output, or that is listed twice, is skipped.

    Raises:
        KeyError: if an output has no pair in ``name_exprs``.
    """
    name_expr_by_name = {name_expr[0]: name_expr for name_expr in name_exprs}
    # Memoised transitive closures. CSE output shares sub-expressions heavily,
    # and recomputing each closure from scratch is exponential in chain depth.
    closure_cache = {}

    def get_inputs(output):
        """Return every symbol ``output`` transitively depends on."""
        cached = closure_cache.get(output)
        if cached is not None:
            return cached
        name_expr = name_expr_by_name.get(output)
        if name_expr is None:
            closure = set()  # a raw input symbol has no dependencies
        else:
            direct = set(name_expr[1].free_symbols)
            closure = set(direct)
            for dependency in direct:
                closure |= get_inputs(dependency)
        closure_cache[output] = closure
        return closure

    result = []
    done = set()

    for output in outputs:
        if output in done and output in name_expr_by_name:
            # Already emitted; emitting it again would duplicate its assignment.
            continue
        inputs = get_inputs(output) - done
        result.extend(
            order_exprs_by_input(
                [name_expr for name_expr in name_exprs if name_expr[0] in inputs]
            )
        )
        result.append(name_expr_by_name[output])
        done.update(inputs)
        done.add(output)

    return result
0218
0219
def my_cse(name_exprs, inflate_deflate=True, simplify=True):
    """Run sympy common-subexpression elimination over ``(name, expr)`` pairs.

    Args:
        name_exprs: pairs to optimise.
        inflate_deflate: if true, shaped expressions are split element-wise
            with ``inflate_exprs`` before ``sym.cse`` and regrouped with
            ``deflate_exprs`` afterwards.
        simplify: if true, ``sym.simplify`` is applied to every extracted
            sub-expression and every rewritten expression.

    Returns:
        A list of the extracted sub-expression pairs, named by
        ``numbered_symbols()``, followed by the rewritten input pairs.
    """
    references = None
    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [pair[0] for pair in name_exprs]
    exprs = [pair[1] for pair in name_exprs]

    replacements, reduced = sym.cse(exprs, symbols=numbered_symbols())

    if simplify:
        replacements = [(symbol, sym.simplify(value)) for symbol, value in replacements]
        reduced = [sym.simplify(value) for value in reduced]

    rewritten = list(zip(names, reduced))
    if inflate_deflate:
        rewritten = deflate_exprs(rewritten, references)

    return [*replacements, *rewritten]
0244
0245
def my_expression_print(
    printer, name_exprs, outputs, run_cse=True, pre_expr_hook=None, post_expr_hook=None
):
    """Generate C++ statements that compute ``outputs``.

    Args:
        printer: sympy code printer used to print each ``Assignment``.
        name_exprs: ``(name, expr)`` pairs describing the computation.
        outputs: names to produce. Every other emitted name becomes a local
            temporary.
        run_cse: apply ``my_cse`` (with inflate/deflate) before printing.
        pre_expr_hook, post_expr_hook: optional callables ``hook(name)``
            returning extra code (or ``None``) that is emitted before or
            after the statement for ``name``.

    Emission rules:
        * Shaped temporaries are declared as ``T name[N];`` and followed by
          their element assignments.
        * Scalar temporaries become ``const auto`` declarations.
        * Shaped outputs print only their element assignments.
        * Scalar outputs are assigned through ``*``.

    If the printer returns an empty string for a scalar assignment, no line
    is emitted for it.

    Returns:
        The generated code joined with newlines.
    """
    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)
    name_exprs = order_exprs_by_output(name_exprs, outputs)

    lines = []

    def run_hook(hook, var):
        # Hooks may return multi-line code, or None for nothing.
        if hook is None:
            return
        code = hook(var)
        if code is not None:
            lines.extend(code.split("\n"))

    for var, expr in name_exprs:
        run_hook(pre_expr_hook, var)

        code = printer.doprint(Assignment(var, expr))
        is_output = var in outputs
        if hasattr(expr, "shape"):
            if not is_output:
                lines.append(f"T {var}[{np.prod(expr.shape)}];")
            lines.extend(code.split("\n"))
        elif code:
            # A printer may print an assignment as nothing at all (e.g.
            # MyCXXCodePrinterWithoutKnownAssignment for 0/1). Without this
            # guard a dangling "const auto " / "*" line would merge with the
            # next statement and produce invalid C++.
            lines.append(("*" if is_output else "const auto ") + code)

        run_hook(post_expr_hook, var)

    return "\n".join(lines)
0279 return "\n".join(lines)