File indexing completed on 2025-04-04 07:58:04
0001 import sys
0002
0003 import numpy as np
0004
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.codegen.ast import Assignment
0008
0009 from codegen.sympy_common import (
0010 NamedExpr,
0011 name_expr,
0012 find_by_name,
0013 my_subs,
0014 cxx_printer,
0015 my_expression_print,
0016 )
0017
0018
# Generated C++ goes to stdout unless a target path is passed as the first
# command-line argument.
output = open(sys.argv[1], "w") if len(sys.argv) > 1 else sys.stdout
0022
0023
0024
# Scalar symbols of the step. Their roles, as used in the formulas below:
#   h      - step length (advances the position as ``p + h * d + ...``)
#   t      - time, advanced by ``h * dtds``
#   m, p_abs - enter ``dtds = sqrt(1 + m**2 / p_abs**2)`` (presumably mass
#            and absolute momentum - TODO confirm)
#   lambda - scales every field sample in the stage cross products
#            (presumably q/p - TODO confirm)
l, h, t, m = (Symbol(name, real=True) for name in ("lambda", "h", "t", "m"))
p_abs = Symbol("p_abs", real=True, positive=True)

# 3x1 column vectors: position p, direction d, and the field samples used by
# the RK stages (B1 for k1, B2 for k2 and k3, B3 for k4).
p, d, B1, B2, B3 = (
    MatrixSymbol(name, 3, 1) for name in ("p", "d", "B1", "B2", "B3")
)
0049
0050
def rk4_full_math():
    """Build the RK4 step with every intermediate expression fully inlined.

    No intermediate symbols are referenced: the stage slopes k1..k4, the
    unnormalized direction and dt/ds are substituted wherever they are used,
    so each returned expression is written purely in terms of the input
    symbols (p, d, t, h, lambda, m, p_abs, B1-B3 and the incoming J).

    Returns:
        list: named expressions (from ``name_expr``) for ``p2``, ``p3``,
        ``err``, ``new_p``, ``new_d``, ``new_time``, ``path_derivatives``
        and ``new_J``, in that order.
    """

    def slope(direction, field):
        # Stage slope: direction x (lambda * field sample).
        return direction.as_explicit().cross(l * field)

    # Stage slopes; p2 / p3 are the positions where B2 / B3 are sampled.
    k1 = name_expr("k1", slope(d, B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.expr)

    k2 = name_expr("k2", slope(d + h / 2 * k1.expr, B2))
    k3 = name_expr("k3", slope(d + h / 2 * k2.expr, B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.expr)

    k4 = name_expr("k4", slope(d + h * k3.expr, B3))

    s1, s2, s3, s4 = k1.expr, k2.expr, k3.expr, k4.expr

    # Error estimate: h^2 times the 1-norm of (k1 - k2 - k3 + k4).
    err = name_expr("err", h**2 * (s1 - s2 - s3 + s4).norm(1))

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (s1 + s2 + s3))
    new_d_tmp = name_expr("new_d_tmp", d + h / 6 * (s1 + 2 * (s2 + s3) + s4))
    raw_dir = new_d_tmp.expr
    # The new direction is the updated direction rescaled to unit norm.
    new_d = name_expr("new_d", raw_dir / raw_dir.as_explicit().norm())

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.expr)

    # 8x1 derivative column filled in place; entry 7 stays zero.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    column = path_derivatives.expr
    column[0:3, 0] = new_d.expr.as_explicit()
    column[3, 0] = dtds.expr
    column[4:7, 0] = s4.as_explicit()

    # Jacobian of the step w.r.t. the 8-component state (p, t, d, lambda):
    # identity, with rows 0-2 (position) and 4-6 (direction) replaced by
    # full Jacobians, plus the d(time)/d(lambda) entry at (3, 7).
    state = [p, t, d, l]
    transport = sym.eye(8)
    transport[0:3, :] = new_p.expr.as_explicit().jacobian(state)
    transport[4:7, :] = new_d_tmp.expr.as_explicit().jacobian(state)
    transport[3, 7] = h * m**2 * l / dtds.expr

    # Incoming Jacobian: symbolic J[i, j] everywhere except where the step
    # Jacobian entry is exactly 0 or 1, which is copied over (presumably so
    # the product below simplifies - TODO confirm the structural assumption).
    seeded = MatrixSymbol("J", 8, 8).as_explicit().as_mutable()
    for row, col in np.ndindex(*seeded.shape):
        entry = transport[row, col]
        if entry in [0, 1]:
            seeded[row, col] = entry

    new_J = name_expr("new_J", ImmutableMatrix(seeded) * transport)

    return [p2, p3, err, new_p, new_d, new_time, path_derivatives, new_J]
0091
0092
def rk4_short_math():
    """Build the RK4 step in terms of named intermediate quantities.

    Later expressions refer to earlier ones through their symbols
    (``k1.name``, ``dk2dTL.name``, ...) instead of inlining them. The
    intermediates are returned alongside the final outputs, presumably so
    the printer emits each of them once as its own variable.

    The ``d?d TL`` Jacobians are taken with respect to (d, lambda) and follow
    the chain rule through the stage slopes k1..k4.

    Returns:
        list: named expressions (from ``name_expr``) for k1, p2, k2, k3, p3,
        k4, err, new_p, new_d_tmp, new_d, dtds, new_time, path_derivatives,
        dk2dTL, dk3dTL, dk4dTL, dFdTL, dGdTL and new_J, in that order.
        ``dk1dTL`` is not returned; its expression is inlined where used.
    """
    free_vars = [d, l]

    def slope(direction, field):
        # Stage slope: direction x (lambda * field sample).
        return direction.as_explicit().cross(l * field)

    def total_jacobian(explicit, links):
        # Jacobian of ``explicit`` w.r.t. (d, lambda): the direct part plus,
        # for each (stage symbol, stage Jacobian) pair, the contribution
        # through that stage. Terms are summed left to right.
        result = explicit.jacobian(free_vars)
        for stage_symbol, stage_jacobian in links:
            result = result + explicit.jacobian(stage_symbol) * stage_jacobian
        return result

    # Stage slopes; p2 / p3 are the positions where B2 / B3 are sampled.
    k1 = name_expr("k1", slope(d, B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.name)

    k2 = name_expr("k2", slope(d + h / 2 * k1.name, B2))
    k3 = name_expr("k3", slope(d + h / 2 * k2.name, B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.name)

    k4 = name_expr("k4", slope(d + h * k3.name, B3))

    s1, s2, s3, s4 = k1.name, k2.name, k3.name, k4.name

    # Error estimate: h^2 times the 1-norm of (k1 - k2 - k3 + k4).
    err = name_expr("err", h**2 * (s1 - s2 - s3 + s4).as_explicit().norm(1))

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (s1 + s2 + s3))
    new_d_tmp = name_expr("new_d_tmp", d + h / 6 * (s1 + 2 * (s2 + s3) + s4))
    raw_dir = new_d_tmp.name
    # The new direction is the updated direction rescaled to unit norm.
    new_d = name_expr("new_d", raw_dir / raw_dir.as_explicit().norm())

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.name)

    # 8x1 derivative column filled in place; entry 7 stays zero.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    column = path_derivatives.expr
    column[0:3, 0] = new_d.name.as_explicit()
    column[3, 0] = dtds.name
    column[4:7, 0] = k4.name.as_explicit()

    # Stage Jacobians, each chained through the previous stage.
    dk1dTL = name_expr("dk1dTL", total_jacobian(k1.expr, []))
    dk2dTL = name_expr("dk2dTL", total_jacobian(k2.expr, [(s1, dk1dTL.expr)]))
    dk3dTL = name_expr(
        "dk3dTL", total_jacobian(k3.expr, [(s2, dk2dTL.name.as_explicit())])
    )
    dk4dTL = name_expr(
        "dk4dTL", total_jacobian(k4.expr, [(s3, dk3dTL.name.as_explicit())])
    )

    stage_links = [
        (s1, dk1dTL.expr),
        (s2, dk2dTL.name.as_explicit()),
        (s3, dk3dTL.name.as_explicit()),
        (s4, dk4dTL.name.as_explicit()),
    ]
    # new_p does not contain k4, so its chain stops at k3.
    dFdTL = name_expr(
        "dFdTL", total_jacobian(new_p.expr.as_explicit(), stage_links[:3])
    )
    dGdTL = name_expr(
        "dGdTL", total_jacobian(new_d_tmp.expr.as_explicit(), stage_links)
    )

    # Jacobian of the step w.r.t. the 8-component state: identity, with the
    # (d, lambda) columns 4-7 of the position and direction rows replaced,
    # plus the d(time)/d(lambda) entry at (3, 7).
    transport = sym.eye(8)
    transport[0:3, 4:8] = dFdTL.name.as_explicit()
    transport[4:7, 4:8] = dGdTL.name.as_explicit()
    transport[3, 7] = h * m**2 * l / dtds.name

    # Incoming Jacobian: symbolic J[i, j] everywhere except where the step
    # Jacobian entry is exactly 0 or 1, which is copied over (presumably so
    # the product below simplifies - TODO confirm the structural assumption).
    seeded = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for row, col in np.ndindex(*seeded.shape):
        entry = transport[row, col]
        if entry in [0, 1]:
            seeded[row, col] = entry
    new_J = name_expr("new_J", ImmutableMatrix(seeded) * transport)

    return [
        k1,
        p2,
        k2,
        k3,
        p3,
        k4,
        err,
        new_p,
        new_d_tmp,
        new_d,
        dtds,
        new_time,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0185
0186
def my_step_function_print(name_exprs, run_cse=True):
    """Render the RK4 step as the C++ function template ``rk4``.

    The signature and the initial field lookup at ``p`` are fixed text. The
    body comes from ``my_expression_print`` for the eight output expressions,
    with hooks that:

    - declare local buffers for ``p2``, ``p3`` and ``new_J``;
    - fetch the field at ``p2`` and ``p3`` (as ``B2`` / ``B3``);
    - return ``success(false)`` when ``*err > errTol``;
    - return ``success(true)`` early when ``J`` is ``nullptr``;
    - copy ``new_J`` into ``J``.

    Args:
        name_exprs: Named expressions to print (e.g. the list returned by
            ``rk4_short_math``). Each output name listed below must be
            findable in it via ``find_by_name``.
        run_cse: Forwarded to ``my_expression_print``.

    Returns:
        str: The C++ source of the function, lines joined by newlines.
    """
    output_names = (
        "p2",
        "p3",
        "err",
        "new_p",
        "new_d",
        "new_time",
        "path_derivatives",
        "new_J",
    )
    outputs = [find_by_name(name_exprs, name)[0] for name in output_names]

    signature = "template <typename T, typename GetB> Acts::Result<bool> rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {"

    # Local buffers for results that are not parameters of the C++ function.
    declarations = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "new_J": "T new_J[64];",
    }

    # Fixed C++ snippets emitted after the named results.
    follow_ups = {
        "p2": "const auto B2res = getB(p2);\n if (!B2res.ok()) {\n return Acts::Result<bool>::failure(B2res.error());\n }\n const auto B2 = *B2res;",
        "p3": "const auto B3res = getB(p3);\n if (!B3res.ok()) {\n return Acts::Result<bool>::failure(B3res.error());\n }\n const auto B3 = *B3res;",
        "err": "if (*err > errTol) {\n return Acts::Result<bool>::success(false);\n}",
        "new_time": "if (J == nullptr) {\n return Acts::Result<bool>::success(true);\n}",
    }

    def pre_expr_hook(var):
        return declarations.get(str(var))

    def post_expr_hook(var):
        key = str(var)
        if key == "new_J":
            # Copy the updated Jacobian into the caller's J buffer.
            return cxx_printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return follow_ups.get(key)

    body = my_expression_print(
        cxx_printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )

    prologue = [
        signature,
        " const auto B1res = getB(p);",
        " if (!B1res.ok()) {\n return Acts::Result<bool>::failure(B1res.error());\n }",
        " const auto B1 = *B1res;",
    ]
    epilogue = [" return Acts::Result<bool>::success(true);", "}"]
    indented_body = [" " + code_line for code_line in body.split("\n")]

    return "\n".join(prologue + indented_body + epilogue)
0253
0254
# Fixed preamble written ahead of the generated function.
_HEADER = """// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_stepper.py
// Do not modify it manually.

#pragma once

#include <Acts/Utilities/Result.hpp>

#include <cmath>
"""

# Build the step expressions (the variant with named intermediates), render
# them as the C++ ``rk4`` function and write header + function to ``output``.
try:
    all_name_exprs = rk4_short_math()

    code = my_step_function_print(
        all_name_exprs,
        run_cse=True,
    )

    output.write(_HEADER)
    output.write(code + "\n")
finally:
    # Close a file opened from sys.argv[1] even when expression building or
    # printing raises; the process-wide stdout is never closed.
    if output is not sys.stdout:
        output.close()
0303 output.close()