File indexing completed on 2025-01-18 09:10:44
0001 from collections import namedtuple
0002
0003 import numpy as np
0004
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.utilities.iterables import numbered_symbols
0008 from sympy.codegen.ast import Assignment
0009 from sympy.printing.cxx import CXX17CodePrinter
0010
0011
# Lightweight (name, expr) pair: ``name`` is the symbol an expression is bound
# to, ``expr`` is the expression itself.
NamedExpr = namedtuple("NamedExpr", ["name", "expr"])
0013
0014
def name_expr(name, expr):
    """Bind ``expr`` to a newly created symbol called ``name``.

    Expressions exposing a ``shape`` attribute are paired with a
    ``MatrixSymbol`` of the same shape; anything else is paired with a plain
    ``Symbol``.

    Returns:
        A ``NamedExpr`` whose ``name`` field is the created symbol.
    """
    if not hasattr(expr, "shape"):
        return NamedExpr(Symbol(name), expr)
    return NamedExpr(sym.MatrixSymbol(name, *expr.shape), expr)
0021
0022
def find_by_name(name_exprs, name):
    """Return the first pair whose name prints as ``name``, or ``None``.

    The comparison is made on ``str(pair[0])``, so ``name`` is expected to be
    a string.
    """
    for candidate in name_exprs:
        if str(candidate[0]) == name:
            return candidate
    return None
0027
0028
class MyCXXCodePrinter(CXX17CodePrinter):
    """C++17 printer that addresses matrix elements through a flat index.

    Elements are indexed as ``i + j * rows`` (column-major), and powers with
    exponent ``-0.5`` are printed as a reciprocal square root.
    """

    def _traverse_matrix_indices(self, mat):
        # Column-major walk: the row index varies fastest.
        n_rows, n_cols = mat.shape
        return ((row, col) for col in range(n_cols) for row in range(n_rows))

    def _print_MatrixElement(self, expr):
        from sympy.printing.precedence import PRECEDENCE

        parent = self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True)
        # Flat offset into the parent, matching the column-major traversal.
        offset = expr.i + expr.j * expr.parent.shape[0]
        return "{}[{}]".format(parent, offset)

    def _print_Pow(self, expr):
        from sympy.core.numbers import equal_valued, Float
        from sympy.codegen.ast import real

        suffix = self._get_func_suffix(real)
        if not equal_valued(expr.exp, -0.5):
            return super()._print_Pow(expr)

        # base**(-1/2) is emitted as 1.0/sqrt(base).
        one = self._print_Float(Float(1.0))
        base = self._print(expr.base)
        return "%s/%ssqrt%s(%s)" % (one, self._ns, suffix, base)
0055
0056
# Module-level printer instance created with default settings.
cxx_printer = MyCXXCodePrinter()
0058
0059
def inflate_expr(name_expr):
    """Split one named expression into per-element pairs.

    Args:
        name_expr: A ``(name, expr)`` pair.

    Returns:
        A ``(pairs, references)`` tuple of equal-length lists.  When ``expr``
        has a ``shape``, there is one ``(name[idx], expr[idx])`` pair per index
        and each reference is ``(name, shape, idx)``.  Otherwise the pair is
        returned unchanged and its reference is ``None``.
    """
    name, expr = name_expr

    if not hasattr(expr, "shape"):
        return [(name, expr)], [None]

    shape = expr.shape
    all_indices = list(np.ndindex(shape))
    pairs = [(name[idx], expr[idx]) for idx in all_indices]
    refs = [(name, shape, idx) for idx in all_indices]
    return pairs, refs
0075
0076
def inflate_exprs(name_exprs):
    """Inflate every named expression and concatenate the results.

    Returns:
        A ``(pairs, references)`` tuple holding, in input order, the
        concatenated outputs of ``inflate_expr`` for each item.
    """
    pairs, refs = [], []
    for item in name_exprs:
        item_pairs, item_refs = inflate_expr(item)
        pairs += item_pairs
        refs += item_refs
    return pairs, refs
0085
0086
def deflate_exprs(name_exprs, references):
    """Reassemble per-element pairs into matrix-valued named expressions.

    Args:
        name_exprs: ``(name, expr)`` pairs, aligned with ``references``.
        references: For each pair either ``None`` (the pair is passed through
            unchanged) or a ``(name, shape, indices)`` triple saying which
            element of which matrix the expression belongs to.

    Returns:
        A list of named expressions.  Each matrix appears once, at the
        position of its first element, as a ``NamedExpr`` holding an
        ``ImmutableMatrix``.
    """
    result = []
    matrices = {}

    for name_expr, reference in zip(name_exprs, references):
        if reference is None:
            result.append(name_expr)
            continue

        _, expr = name_expr
        name, shape, indices = reference
        if name not in matrices:
            matrices[name] = Matrix(np.zeros(shape))
            result.append(NamedExpr(name, matrices[name]))
        # Plain tuple index instead of ``m[*indices]``: same element, but the
        # star form is a SyntaxError on interpreters older than Python 3.11.
        matrices[name][tuple(indices)] = expr

    # Freeze the matrices only after every element has been written.
    frozen = []
    for entry in result:
        name, expr = entry
        if isinstance(expr, Matrix):
            frozen.append(NamedExpr(name, ImmutableMatrix(expr)))
        else:
            frozen.append(entry)
    return frozen
0112
0113
def my_subs(expr, sub_name_exprs):
    """Replace known sub-expressions in ``expr`` by their names.

    The sub-expressions are inflated to per-element pairs first; ``expr`` is
    expanded, each sub-expression is substituted by its name, and the result
    is simplified.
    """
    inflated, _ = inflate_exprs(sub_name_exprs)
    replacements = [(sub_expr, sub_name) for sub_name, sub_expr in inflated]
    return sym.simplify(expr.expand().subs(replacements))
0121
0122
def build_dependency_graph(name_exprs):
    """Map each name to the free symbols of its expression.

    When a name occurs more than once, the last occurrence wins.
    """
    return {name: expr.free_symbols for name, expr in name_exprs}
0128
0129
def build_influence_graph(name_exprs):
    """Map each free symbol to the set of names whose expression uses it.

    Expressions without free symbols contribute nothing to the graph.
    """
    graph = {}
    for name, expr in name_exprs:
        for symbol in expr.free_symbols:
            users = graph.get(symbol)
            if users is None:
                users = graph[symbol] = set()
            users.add(name)
    return graph
0136
0137
def order_exprs_by_input(name_exprs):
    """Sort named expressions so each one follows everything it depends on.

    Symbols that are used but not defined by any pair are treated as inputs
    with level 0.  Every expression gets the level ``1 + max(level of its free
    symbols)``; an expression without free symbols gets level 1.

    Args:
        name_exprs: ``(name, expr)`` pairs; each ``expr`` must expose
            ``free_symbols`` and ``args``.

    Returns:
        A new list sorted by level, then by number of free symbols, then by
        number of ``args``.  The sort is stable.

    Raises:
        ValueError: If some expressions can never be resolved because they
            depend on each other cyclically.
    """
    name_exprs = list(name_exprs)
    defined_names = {name for name, _ in name_exprs}
    used_symbols = set().union(*(expr.free_symbols for _, expr in name_exprs))

    level = {symbol: 0 for symbol in used_symbols - defined_names}

    pending = name_exprs
    while pending:
        unresolved = []
        for name, expr in pending:
            dep_levels = [level.get(s) for s in expr.free_symbols]
            if None in dep_levels:
                unresolved.append((name, expr))
                continue
            # ``default`` covers constant expressions, where ``max`` would
            # otherwise raise on the empty list.
            level[name] = max(dep_levels, default=0) + 1
        if len(unresolved) == len(pending):
            # A full pass resolved nothing: further passes cannot either.
            stuck = ", ".join(str(name) for name, _ in unresolved)
            raise ValueError(f"cyclic dependency between expressions: {stuck}")
        pending = unresolved

    # One stable sort on a tuple key gives the same order as sorting by the
    # least significant key first and the most significant key last.
    return sorted(
        name_exprs,
        key=lambda n_e: (level[n_e[0]], len(n_e[1].free_symbols), len(n_e[1].args)),
    )
0159
0160
def order_exprs_by_output(name_exprs, outputs):
    """Order expressions so each output is preceded by what it depends on.

    Outputs are processed in the given order.  For each one, the not yet
    emitted expressions it depends on (directly or transitively) are ordered
    with ``order_exprs_by_input`` and emitted, followed by the output itself.

    Args:
        name_exprs: ``(name, expr)`` pairs defining the available expressions.
        outputs: Names of the expressions to produce, in the desired order.

    Returns:
        A list of ``(name, expr)`` pairs in which every name appears at most
        once.

    Raises:
        KeyError: If an output is not defined by ``name_exprs``.
    """
    name_expr_by_name = {entry[0]: entry for entry in name_exprs}
    closure_cache = {}

    def get_inputs(output):
        # Transitive set of symbols ``output`` depends on, computed once per
        # name so shared sub-expressions are not walked repeatedly.
        cached = closure_cache.get(output)
        if cached is not None:
            return cached
        entry = name_expr_by_name.get(output)
        if entry is None:
            return set()
        inputs = set(entry[1].free_symbols)
        for symbol in tuple(inputs):
            inputs |= get_inputs(symbol)
        closure_cache[output] = inputs
        return inputs

    result = []
    emitted = set()

    for output in outputs:
        if output in emitted:
            # Already emitted as a dependency of an earlier output, or listed
            # twice; emitting it again would duplicate the assignment.
            continue
        needed = get_inputs(output) - emitted
        pending = [entry for entry in name_exprs if entry[0] in needed]
        result.extend(order_exprs_by_input(pending))
        result.append(name_expr_by_name[output])
        emitted.update(entry[0] for entry in pending)
        emitted.add(output)

    return result
0187
0188
def my_cse(name_exprs, inflate_deflate=True, simplify=True):
    """Run common-subexpression elimination over named expressions.

    Args:
        name_exprs: ``(name, expr)`` pairs.
        inflate_deflate: When true, the pairs are split into per-element pairs
            before elimination and reassembled afterwards.
        simplify: When true, ``sym.simplify`` is applied to every extracted
            sub-expression and every reduced expression.

    Returns:
        A list with the extracted ``(symbol, sub_expr)`` pairs first, followed
        by the reduced named expressions.
    """
    if inflate_deflate:
        name_exprs, references = inflate_exprs(name_exprs)

    names = [pair[0] for pair in name_exprs]
    exprs = [pair[1] for pair in name_exprs]

    replacements, reduced = sym.cse(exprs, symbols=numbered_symbols())

    if simplify:
        replacements = [(s, sym.simplify(sub)) for s, sub in replacements]
        reduced = [sym.simplify(r) for r in reduced]

    reduced_named = list(zip(names, reduced))
    if inflate_deflate:
        reduced_named = deflate_exprs(reduced_named, references)

    return [*replacements, *reduced_named]
0213
0214
def my_expression_print(
    printer, name_exprs, outputs, run_cse=True, pre_expr_hook=None, post_expr_hook=None
):
    """Print named expressions as a sequence of assignment statements.

    Args:
        printer: Code printer whose ``doprint`` renders each assignment.
        name_exprs: ``(name, expr)`` pairs to print.
        outputs: Names treated as outputs.  Scalar outputs are written as
            ``*name = ...``; other scalars as ``const auto name = ...``.
            Matrix expressions that are not outputs are preceded by a
            ``T name[size];`` declaration.
        run_cse: When true, ``my_cse`` is applied first.
        pre_expr_hook: Optional callable taking the name; a non-``None`` return
            value is inserted before that expression's code.
        post_expr_hook: Same as ``pre_expr_hook``, inserted after the code.

    Returns:
        The generated code as one newline-joined string.
    """
    if run_cse:
        name_exprs = my_cse(name_exprs, inflate_deflate=True)
    name_exprs = order_exprs_by_output(name_exprs, outputs)

    lines = []

    def emit_hook(hook, var):
        # Hooks are optional and may return None to emit nothing.
        if hook is None:
            return
        extra = hook(var)
        if extra is not None:
            lines.extend(extra.split("\n"))

    for var, expr in name_exprs:
        emit_hook(pre_expr_hook, var)

        code = printer.doprint(Assignment(var, expr))
        is_output = var in outputs
        if hasattr(expr, "shape"):
            if not is_output:
                lines.append(f"T {var}[{np.prod(expr.shape)}];")
            lines.extend(code.split("\n"))
        elif is_output:
            lines.append("*" + code)
        else:
            lines.append("const auto " + code)

        emit_hook(post_expr_hook, var)

    return "\n".join(lines)
0249
0250
def my_function_print(printer, name, inputs, name_exprs, outputs, run_cse=True):
    """Print the expressions wrapped in a templated function definition.

    Args:
        printer: Code printer forwarded to ``my_expression_print``.
        name: Name of the generated function.
        inputs: Input symbols; ``MatrixSymbol`` inputs become ``const T*``
            parameters, all others ``const T``.
        name_exprs: ``(name, expr)`` pairs forming the function body.
        outputs: Output names; each becomes a ``T*`` parameter.
        run_cse: Forwarded to ``my_expression_print``.

    Returns:
        The function source as one newline-joined string.
    """

    def input_param(symbol):
        if isinstance(symbol, MatrixSymbol):
            return f"const T* {symbol.name}"
        return f"const T {symbol.name}"

    def output_param(symbol):
        return f"T* {symbol}"

    params = [input_param(i) for i in inputs] + [output_param(o) for o in outputs]
    # Join outside the f-string: reusing the enclosing quote character inside a
    # replacement field is a SyntaxError before Python 3.12.
    signature = ", ".join(params)
    head = f"template<typename T> void {name}({signature}) {{"

    body = my_expression_print(printer, name_exprs, outputs, run_cse=run_cse)

    lines = [head]
    lines.extend(f" {line}" for line in body.split("\n"))
    lines.append("}")
    return "\n".join(lines)
0274 return "\n".join(lines)