File indexing completed on 2025-12-16 09:22:13
0001 from collections import namedtuple
0002
0003 import numpy as np
0004
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.utilities.iterables import numbered_symbols
0008 from sympy.codegen.ast import Assignment
0009 from sympy.printing.cxx import CXX17CodePrinter
0010
0011
# Lightweight (name, expr) pair; ``name`` is usually a sympy Symbol/MatrixSymbol.
NamedExpr = namedtuple("NamedExpr", "name expr")
0013
0014
def make_vector(name, dim, **kwargs):
    """Return a column Matrix of Symbols named ``name[0]``, ``name[1]``, ...

    ``dim`` is the number of rows.  Extra keyword arguments are forwarded
    unchanged to every ``Symbol`` constructor.
    """
    column = []
    for row in range(dim):
        column.append([Symbol(f"{name}[{row}]", **kwargs)])
    return Matrix(column)
0017
0018
def make_matrix(name, rows, cols, **kwargs):
    """Return a ``rows`` x ``cols`` Matrix of Symbols named ``name[i,j]``.

    Extra keyword arguments are forwarded unchanged to every ``Symbol``.

    NOTE(review): these are plain Symbols whose names contain a comma.  If one
    of them is ever printed into C++ verbatim, ``name[i,j]`` is parsed as a
    comma expression (effectively ``name[j]``) — confirm these symbols never
    reach generated code, or name them by flat index instead.
    """

    def entry(i, j):
        return Symbol(f"{name}[{i},{j}]", **kwargs)

    return Matrix([[entry(i, j) for j in range(cols)] for i in range(rows)])
0026
0027
def name_expr(name, expr):
    """Pair ``expr`` with a freshly created sympy symbol called ``name``.

    Values that have a ``shape`` attribute get a ``MatrixSymbol`` of that
    shape; everything else gets a plain ``Symbol``.

    Returns:
        ``NamedExpr(symbol, expr)``.
    """
    symbol = (
        sym.MatrixSymbol(name, *expr.shape) if hasattr(expr, "shape") else Symbol(name)
    )
    return NamedExpr(symbol, expr)
0034
0035
def find_by_name(name_exprs, name):
    """Return the first pair whose name stringifies to ``name``, else ``None``.

    The comparison is ``str(pair[0]) == name``, so symbols can be looked up by
    their printed name.
    """
    for candidate in name_exprs:
        if str(candidate[0]) == name:
            return candidate
    return None
0040
0041
class MyCXXCodePrinter(CXX17CodePrinter):
    """C++17 code printer with column-major matrices and an rsqrt shortcut.

    Differences from ``CXX17CodePrinter``:

    * matrix entries are traversed column by column (row index fastest);
    * a matrix element ``M[i, j]`` is printed with the flat offset
      ``i + j * rows``;
    * a power whose exponent equals -0.5 is printed as ``1.0`` (via
      ``_print_Float``) divided by the printer's ``sqrt`` function, instead of
      being handed to the base class.
    """

    def _traverse_matrix_indices(self, mat):
        # Column-major order: for each column, walk every row.
        n_rows, n_cols = mat.shape
        return ((row, col) for col in range(n_cols) for row in range(n_rows))

    def _print_MatrixElement(self, expr):
        from sympy.printing.precedence import PRECEDENCE

        base = self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True)
        # Flat column-major offset, consistent with _traverse_matrix_indices.
        flat_index = expr.i + expr.j * expr.parent.shape[0]
        return f"{base}[{flat_index}]"

    def _print_Pow(self, expr):
        from sympy.core.numbers import equal_valued, Float
        from sympy.codegen.ast import real

        func_suffix = self._get_func_suffix(real)
        if not equal_valued(expr.exp, -0.5):
            return super()._print_Pow(expr)
        # x**(-1/2) -> 1.0/sqrt(x)
        one = self._print_Float(Float(1.0))
        base = self._print(expr.base)
        return f"{one}/{self._ns}sqrt{func_suffix}({base})"
0068
0069
# Shared instance of the custom printer; presumably the ``printer`` argument
# callers pass to my_expression_print / my_function_print — confirm.
cxx_printer = MyCXXCodePrinter()
0071
0072
def inflate_expr(name_expr):
    """Flatten one ``(name, expr)`` pair into element-wise pairs.

    If ``expr`` has a ``shape``, one pair ``(name[idx], expr[idx])`` is produced
    for every ``idx`` yielded by ``np.ndindex(expr.shape)`` (row-major order),
    with reference ``(name, expr.shape, idx)`` so ``deflate_exprs`` can put the
    entry back.  Otherwise the pair is returned unchanged with reference
    ``None``.

    Returns:
        ``(pairs, references)`` — two parallel lists.
    """
    name, expr = name_expr

    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    all_indices = list(np.ndindex(expr.shape))
    flattened = [(name[idx], expr[idx]) for idx in all_indices]
    refs = [(name, expr.shape, idx) for idx in all_indices]
    return flattened, refs
0088
0089
def inflate_exprs(name_exprs):
    """Apply ``inflate_expr`` to every pair and concatenate the results.

    Returns:
        ``(pairs, references)`` — two parallel lists covering all entries of
        ``name_exprs``, in input order.
    """
    flat_pairs, flat_refs = [], []
    for item in name_exprs:
        pairs, refs = inflate_expr(item)
        flat_pairs += pairs
        flat_refs += refs
    return flat_pairs, flat_refs
0098
0099
def deflate_exprs(name_exprs, references):
    """Reassemble element-wise entries into matrices (inverse of ``inflate_exprs``).

    ``name_exprs`` and ``references`` are parallel sequences, e.g. the pair
    returned by ``inflate_exprs`` with the expressions transformed in between;
    ``zip`` stops at the shorter of the two.

    * An entry whose reference is ``None`` is passed through as-is.
    * Otherwise the reference ``(name, shape, indices)`` says where the entry's
      expression goes.  The first entry seen for a given ``name`` creates a
      zero-filled ``Matrix`` of ``shape`` and fixes its position in the
      output.  Every entry then writes its expression at ``indices``; the
      entry's own name is ignored.

    Mutable ``Matrix`` values in the result are converted to ``ImmutableMatrix``.

    Returns:
        List of ``(name, expr)`` pairs in first-appearance order.
    """
    collected = []
    matrices = {}

    for item, ref in zip(name_exprs, references):
        if ref is None:
            collected.append(item)
            continue
        _, value = item
        matrix_name, shape, idx = ref
        if matrix_name not in matrices:
            placeholder = Matrix(np.zeros(shape))
            collected.append(NamedExpr(matrix_name, placeholder))
            matrices[matrix_name] = placeholder
        matrices[matrix_name][tuple(idx)] = value

    def freeze(item):
        item_name, item_expr = item
        if isinstance(item_expr, Matrix):
            return NamedExpr(item_name, ImmutableMatrix(item_expr))
        return item

    return [freeze(item) for item in collected]
0125
0126
def my_subs(expr, sub_name_exprs):
    """Rewrite ``expr`` in terms of the names given in ``sub_name_exprs``.

    ``sub_name_exprs`` is first flattened with ``inflate_exprs``, so matrix
    entries are substituted element by element.  ``expr`` is then expanded,
    each sub-expression found in it is replaced by its name via ``subs``, and
    the result goes through ``sympy.simplify``.

    NOTE(review): only ``expr`` is expanded, not the sub-expressions, so a
    sub-expression given in non-expanded form may fail to match — confirm
    callers pass expanded sub-expressions.
    """
    flat_pairs, _ = inflate_exprs(sub_name_exprs)
    expanded = expr.expand()
    # subs() takes (old, new) pairs: the expression is replaced by its name.
    replacements = [(value, label) for label, value in flat_pairs]
    return sym.simplify(expanded.subs(replacements))
0134
0135
def build_dependency_graph(name_exprs):
    """Map each name to the ``free_symbols`` set of its expression.

    If a name occurs more than once, the last expression wins.
    """
    return {name: expr.free_symbols for name, expr in name_exprs}
0141
0142
def build_influence_graph(name_exprs):
    """Invert the dependency relation: map each free symbol to the names using it.

    Symbols that no expression uses do not appear in the result.
    """
    influence = {}
    for target, expr in name_exprs:
        for used in expr.free_symbols:
            if used not in influence:
                influence[used] = set()
            influence[used].add(target)
    return influence
0149
0150
def order_exprs_by_input(name_exprs):
    """Sort named expressions so that every definition follows its dependencies.

    Each name gets a *level*:

    * free symbols not defined by any entry (the inputs) are level 0;
    * an expression with no free symbols is level 0;
    * any other expression is one more than the highest level among its free
      symbols.

    The result is sorted by level, then by number of free symbols, then by
    number of ``args``.  Ties keep their original relative order.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs; every ``expr`` must
            provide ``free_symbols`` and ``args``.

    Returns:
        A new list containing the same pairs in dependency order.

    Raises:
        ValueError: if some definitions can never be resolved because they
            depend on each other cyclically.
    """
    # Materialise once: the input is traversed several times below.
    name_exprs = list(name_exprs)

    all_expr_names = {name for name, _ in name_exprs}
    all_expr_symbols = set().union(*(expr.free_symbols for _, expr in name_exprs))
    inputs = all_expr_symbols - all_expr_names

    # Level of every resolved input/name; inputs start resolved at level 0.
    order = dict.fromkeys(inputs, 0)
    # Count distinct names: a duplicated name must not keep the loop spinning.
    target = len(inputs) + len(all_expr_names)

    while len(order) < target:
        resolved_before = len(order)
        for name, expr in name_exprs:
            levels = [order.get(s) for s in expr.free_symbols]
            if None in levels:
                continue  # some dependency is not resolved yet
            order[name] = max(levels, default=-1) + 1
        if len(order) == resolved_before:
            # A full pass resolved nothing new, so the rest never will.
            unresolved = sorted(str(n) for n in all_expr_names if n not in order)
            raise ValueError(
                "cannot order expressions with cyclic dependencies: "
                + ", ".join(unresolved)
            )

    # One stable sort with a tuple key equals the original three chained sorts.
    return sorted(
        name_exprs,
        key=lambda n_e: (order[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )
0175
0176
def order_exprs_by_output(name_exprs, outputs):
    """Order definitions so each output is preceded by everything it needs.

    For every entry of ``outputs``, in turn:

    1. collect the output's transitive dependencies that are defined in
       ``name_exprs`` and were not emitted for an earlier output;
    2. order them with ``order_exprs_by_input`` and append them;
    3. append the output's own definition.

    An output that has already been emitted (as an earlier output or as a
    dependency of one) is not emitted a second time.

    Args:
        name_exprs: iterable of ``(name, expr)`` pairs.
        outputs: names to compute, in the order they should appear.

    Returns:
        List of ``(name, expr)`` pairs in emission order.

    Raises:
        KeyError: if an output has no definition in ``name_exprs``.
    """
    # Materialise once: the input is traversed once per output below.
    name_exprs = list(name_exprs)
    name_expr_by_name = {name_expr[0]: name_expr for name_expr in name_exprs}

    def get_inputs(output):
        """Return every symbol ``output`` depends on, directly or transitively."""
        # Iterative DFS with a visited set.  The previous recursive version
        # re-expanded shared sub-expressions on every path (exponential on CSE
        # DAGs) and hit the recursion limit on long dependency chains.
        found = set()
        pending = [output]
        while pending:
            definition = name_expr_by_name.get(pending.pop())
            if definition is None:
                continue  # not defined here: a true input, nothing to expand
            for symbol in definition[1].free_symbols:
                if symbol not in found:
                    found.add(symbol)
                    pending.append(symbol)
        return found

    result = []
    done = set()

    for output in outputs:
        if output in done and output in name_expr_by_name:
            # Every defined name in ``done`` has already been appended.
            continue
        inputs = get_inputs(output) - done
        result.extend(
            order_exprs_by_input(
                [name_expr for name_expr in name_exprs if name_expr[0] in inputs]
            )
        )
        result.append(name_expr_by_name[output])
        done.update(inputs)
        done.add(output)

    return result
0203
0204
def my_cse(name_exprs, inflate_deflate=True, simplify=True):
    """Run common-subexpression elimination over named expressions.

    Args:
        name_exprs: ``(name, expr)`` pairs.
        inflate_deflate: when true, matrix-valued pairs are flattened with
            ``inflate_exprs`` before CSE, so that entries can share
            sub-expressions.  They are reassembled with ``deflate_exprs``
            afterwards, which yields ``ImmutableMatrix`` values.
        simplify: when true, ``sympy.simplify`` is applied to every extracted
            sub-expression and to every reduced expression.

    Returns:
        The extracted ``(symbol, subexpr)`` pairs (symbols ``x0``, ``x1``, ...
        from ``numbered_symbols``), followed by the reduced named expressions.

    NOTE(review): the left-hand names are not passed to ``sympy.cse``.  A name
    spelled like a generated symbol (``x0``, ``x1``, ...) could therefore
    collide with one — confirm callers avoid such names.
    """
    symbol_stream = numbered_symbols()

    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [item[0] for item in name_exprs]
    exprs = [item[1] for item in name_exprs]

    replacements, reduced = sym.cse(exprs, symbols=symbol_stream)

    if simplify:
        replacements = [(sub, sym.simplify(value)) for sub, value in replacements]
        reduced = [sym.simplify(value) for value in reduced]

    outputs = list(zip(names, reduced))
    if inflate_deflate:
        outputs = deflate_exprs(outputs, references)

    return [*replacements, *outputs]
0229
0230
def my_expression_print(
    printer, name_exprs, outputs, run_cse=True, pre_expr_hook=None, post_expr_hook=None
):
    """Print named expressions as C++ statements that compute ``outputs``.

    If ``run_cse`` is true the expressions first go through ``my_cse``.  They
    are then ordered with ``order_exprs_by_output``.  Each
    ``Assignment(var, expr)`` printed by ``printer`` is emitted as follows:

    * non-output scalar: prefixed with ``const auto ``;
    * non-output matrix: a ``T var[size];`` declaration, then the assignment;
    * output scalar: prefixed with ``*`` (written through the output pointer);
    * output matrix: the assignment as printed.

    ``pre_expr_hook`` / ``post_expr_hook`` are optional callables taking the
    variable.  If one returns a string other than ``None``, its lines are
    inserted before / after that variable's code.

    Returns:
        The generated lines joined with ``"\\n"``.
    """
    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)
    ordered = order_exprs_by_output(name_exprs, outputs)

    def run_hook(hook, var, sink):
        if hook is None:
            return
        snippet = hook(var)
        if snippet is not None:
            sink.extend(snippet.split("\n"))

    out_lines = []
    for var, expr in ordered:
        run_hook(pre_expr_hook, var, out_lines)

        assignment = printer.doprint(Assignment(var, expr))
        is_array = hasattr(expr, "shape")
        is_output = var in outputs

        if is_output and is_array:
            out_lines.extend(assignment.split("\n"))
        elif is_output:
            out_lines.append("*" + assignment)
        elif is_array:
            out_lines.append(f"T {var}[{np.prod(expr.shape)}];")
            out_lines.extend(assignment.split("\n"))
        else:
            out_lines.append("const auto " + assignment)

        run_hook(post_expr_hook, var, out_lines)

    return "\n".join(out_lines)
0265
0266
def my_function_print(printer, name, inputs, name_exprs, outputs, run_cse=True):
    """Render a C++ function template that computes ``outputs`` from ``inputs``.

    The signature is ``template<typename T> void name(...)``.  It takes one
    ``const T*`` parameter per ``MatrixSymbol`` input, one ``const T``
    parameter per other input, then one ``T*`` parameter per output.  The body
    is produced by ``my_expression_print`` and indented by one space.

    Args:
        printer: code printer forwarded to ``my_expression_print``.
        name: name of the generated C++ function.
        inputs: input symbols in parameter order; each must have ``.name``.
        name_exprs: ``(name, expr)`` pairs describing the computation.
        outputs: output names, placed after the inputs in the parameter list.
        run_cse: forwarded to ``my_expression_print``.

    Returns:
        The C++ source text as a single string.
    """

    def input_param(arg):
        # Matrix inputs are passed as pointers (indexed with flat offsets).
        if isinstance(arg, MatrixSymbol):
            return f"const T* {arg.name}"
        return f"const T {arg.name}"

    def output_param(output):
        return f"T* {output}"

    params = [input_param(arg) for arg in inputs] + [
        output_param(output) for output in outputs
    ]
    # Join outside the f-string.  Reusing the enclosing quote character inside
    # a replacement field is only valid from Python 3.12 on (PEP 701); before
    # that it makes the whole module fail to import.
    param_list = ", ".join(params)

    lines = [f"template<typename T> void {name}({param_list}) {{"]
    code = my_expression_print(printer, name_exprs, outputs, run_cse=run_cse)
    lines.extend(f" {line}" for line in code.split("\n"))
    lines.append("}")
    return "\n".join(lines)
0290 return "\n".join(lines)