File indexing completed on 2025-01-18 09:10:44
0001 import numpy as np
0002
0003 import sympy as sym
0004 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0005 from sympy.codegen.ast import Assignment
0006
0007 from sympy_common import (
0008 NamedExpr,
0009 name_expr,
0010 find_by_name,
0011 my_subs,
0012 cxx_printer,
0013 my_expression_print,
0014 )
0015
0016
0017
0018 l = Symbol("lambda", real=True)
0019
0020
0021 h = Symbol("h", real=True)
0022
0023
0024 p = MatrixSymbol("p", 3, 1)
0025
0026
0027 d = MatrixSymbol("d", 3, 1)
0028
0029
0030 t = Symbol("t", real=True)
0031
0032
0033 m = Symbol("m", real=True)
0034
0035
0036 p_abs = Symbol("p_abs", real=True, positive=True)
0037
0038
0039 B1 = MatrixSymbol("B1", 3, 1)
0040 B2 = MatrixSymbol("B2", 3, 1)
0041 B3 = MatrixSymbol("B3", 3, 1)
0042
0043
def rk4_full_math():
    """Derive the RK4 step with every intermediate expanded in place.

    Each stage substitutes the full expression of the previous stage
    (``k1.expr`` etc.) instead of referring to it by name. The transport matrix
    is the direct Jacobian of the new position and un-normalized new direction
    with respect to ``[p, t, d, lambda]``.

    Free-parameter layout used for the 8x8 matrices: rows/columns 0-2 position,
    3 time, 4-6 direction, 7 lambda.

    Returns:
        The named outputs ``p2``, ``p3``, ``err``, ``new_p``, ``new_d``,
        ``new_time``, ``path_derivatives`` and ``new_J``, in that order.
    """
    # Stage 1 uses B1; p2 is the point at which the generated code fetches B2.
    k1 = name_expr("k1", d.as_explicit().cross(l * B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.expr)

    # Stages 2 and 3 both use B2; p3 is where the generated code fetches B3.
    k2 = name_expr("k2", (d + h / 2 * k1.expr).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.expr).as_explicit().cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.expr)

    # Stage 4 uses B3.
    k4 = name_expr("k4", (d + h * k3.expr).as_explicit().cross(l * B3))

    # Error estimate: h^2 times the 1-norm of (k1 - k2 - k3 + k4).
    stage_spread = k1.expr - k2.expr - k3.expr + k4.expr
    err = name_expr("err", h**2 * stage_spread.norm(1))

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.expr + k2.expr + k3.expr))
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.expr + 2 * (k2.expr + k3.expr) + k4.expr)
    )
    # Rescale the updated direction to unit length.
    unnormalized = new_d_tmp.expr
    new_d = name_expr("new_d", unnormalized / unnormalized.as_explicit().norm())

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.expr)

    # Rows 0-2: new direction, row 3: dt/ds, rows 4-6: k4, row 7 stays zero.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.expr.as_explicit()
    path_derivatives.expr[3, 0] = dtds.expr
    path_derivatives.expr[4:7, 0] = k4.expr.as_explicit()

    # Transport matrix: identity, with the position and direction rows replaced
    # by their Jacobians and the d(time)/d(lambda) entry filled in.
    free_params = [p, t, d, l]
    D = sym.eye(8)
    D[0:3, :] = new_p.expr.as_explicit().jacobian(free_params)
    D[4:7, :] = new_d_tmp.expr.as_explicit().jacobian(free_params)
    D[3, 7] = h * m**2 * l / dtds.expr

    # Incoming Jacobian J as symbols, except that every position where D holds
    # a literal 0 or 1 takes that constant instead.
    # NOTE(review): this assumes J shares D's structural zeros/ones -- confirm.
    J_entries = MatrixSymbol("J", 8, 8).as_explicit()
    J = ImmutableMatrix(
        [
            [
                D[row, col] if D[row, col] in (0, 1) else J_entries[row, col]
                for col in range(8)
            ]
            for row in range(8)
        ]
    )

    new_J = name_expr("new_J", J * D)

    return [p2, p3, err, new_p, new_d, new_time, path_derivatives, new_J]
0084
0085
def rk4_short_math():
    """Derive the RK4 step with intermediates referenced by name.

    Later stages refer to earlier ones through their named symbols
    (``k1.name`` etc.) rather than substituting full expressions. The
    transport matrix starts from the 8x8 identity; only its columns 4-7
    (direction and lambda) in the position and direction rows are filled. They
    come from a chain rule through the stages ``k1``..``k4``.

    Returns:
        All named expressions the generated code must compute. Each entry
        appears after every named entry it references.
    """
    # Runge-Kutta stages: B1 at the start, B2 for k2 and k3, B3 for k4.
    k1 = name_expr("k1", d.as_explicit().cross(l * B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.name)

    k2 = name_expr("k2", (d + h / 2 * k1.name).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.name).as_explicit().cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.name)

    k4 = name_expr("k4", (d + h * k3.name).as_explicit().cross(l * B3))

    # Error estimate: h^2 times the 1-norm of (k1 - k2 - k3 + k4).
    stage_spread = k1.name - k2.name - k3.name + k4.name
    err = name_expr("err", h**2 * stage_spread.as_explicit().norm(1))

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.name + k2.name + k3.name))
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.name + 2 * (k2.name + k3.name) + k4.name)
    )
    # Rescale the updated direction to unit length.
    unnormalized = new_d_tmp.name
    new_d = name_expr("new_d", unnormalized / unnormalized.as_explicit().norm())

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.name)

    # Rows 0-2: new direction, row 3: dt/ds, rows 4-6: k4, row 7 stays zero.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.name.as_explicit()
    path_derivatives.expr[3, 0] = dtds.name
    path_derivatives.expr[4:7, 0] = k4.name.as_explicit()

    def chain_rule(target, stage_links):
        """Return d(target)/d(d, lambda) including terms through named stages.

        ``stage_links`` holds ``(stage, d_stage)`` pairs. For each pair, in
        order, ``target.jacobian(stage.name) * d_stage`` is added to the direct
        Jacobian ``target.jacobian([d, l])``.
        """
        total = target.jacobian([d, l])
        for stage, d_stage in stage_links:
            total = total + target.jacobian(stage.name) * d_stage
        return total

    # Stage derivatives w.r.t. (d, lambda). dk1dTL is substituted via `.expr`
    # below because it is not among the returned named outputs.
    dk1dTL = name_expr("dk1dTL", chain_rule(k1.expr, ()))
    dk2dTL = name_expr("dk2dTL", chain_rule(k2.expr, [(k1, dk1dTL.expr)]))
    dk3dTL = name_expr(
        "dk3dTL", chain_rule(k3.expr, [(k2, dk2dTL.name.as_explicit())])
    )
    dk4dTL = name_expr(
        "dk4dTL", chain_rule(k4.expr, [(k3, dk3dTL.name.as_explicit())])
    )

    # Derivatives of the new position (F) and un-normalized direction (G).
    dFdTL = name_expr(
        "dFdTL",
        chain_rule(
            new_p.expr.as_explicit(),
            [
                (k1, dk1dTL.expr),
                (k2, dk2dTL.name.as_explicit()),
                (k3, dk3dTL.name.as_explicit()),
            ],
        ),
    )
    dGdTL = name_expr(
        "dGdTL",
        chain_rule(
            new_d_tmp.expr.as_explicit(),
            [
                (k1, dk1dTL.expr),
                (k2, dk2dTL.name.as_explicit()),
                (k3, dk3dTL.name.as_explicit()),
                (k4, dk4dTL.name.as_explicit()),
            ],
        ),
    )

    # Transport matrix: identity plus the (d, lambda) columns 4-7 of the
    # position and direction rows, and d(time)/d(lambda).
    D = sym.eye(8)
    D[0:3, 4:8] = dFdTL.name.as_explicit()
    D[4:7, 4:8] = dGdTL.name.as_explicit()
    D[3, 7] = h * m**2 * l / dtds.name

    # Incoming Jacobian J as symbols, except that every position where D holds
    # a literal 0 or 1 takes that constant instead.
    # NOTE(review): this assumes J shares D's structural zeros/ones -- confirm.
    J_entries = MatrixSymbol("J", 8, 8).as_explicit()
    J = ImmutableMatrix(
        [
            [
                D[row, col] if D[row, col] in (0, 1) else J_entries[row, col]
                for col in range(8)
            ]
            for row in range(8)
        ]
    )
    new_J = name_expr("new_J", J * D)

    return [
        k1,
        p2,
        k2,
        k3,
        p3,
        k4,
        err,
        new_p,
        new_d_tmp,
        new_d,
        dtds,
        new_time,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0178
0179
def my_step_function_print(name_exprs, run_cse=True):
    """Render the C++ ``rk4`` step function template as source text.

    The emitted function fetches the field at ``p`` and at the intermediate
    points ``p2``/``p3`` through ``getB``, propagating any failure. It returns
    ``success(false)`` when ``*err > errTol`` and returns ``success(true)``
    early when ``J == nullptr``. Otherwise it writes the updated Jacobian back
    into ``J``.

    Args:
        name_exprs: Named expressions to print, e.g. the list produced by
            ``rk4_short_math``. Must contain every name in ``output_names``.
        run_cse: Forwarded unchanged to ``my_expression_print``.

    Returns:
        The C++ source of the function, as newline-joined lines.
    """
    printer = cxx_printer

    output_names = (
        "p2",
        "p3",
        "err",
        "new_p",
        "new_d",
        "new_time",
        "path_derivatives",
        "new_J",
    )
    outputs = []
    for output_name in output_names:
        outputs.append(find_by_name(name_exprs, output_name)[0])

    # Signature, then the field lookup at the start point p.
    lines = [
        "template <typename T, typename GetB> Acts::Result<bool> rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {",
        " const auto B1res = getB(p);",
        " if (!B1res.ok()) {\n return Acts::Result<bool>::failure(B1res.error());\n }",
        " const auto B1 = *B1res;",
    ]

    # Local buffers declared just before the matching expression is emitted.
    declarations = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "new_J": "T new_J[64];",
    }

    # Code emitted just after the matching expression: field lookups at
    # p2/p3, the error-tolerance exit, and the exit when no Jacobian is wanted.
    # NOTE(review): if my_expression_print emits outputs in list order, the
    # J == nullptr exit also skips path_derivatives -- confirm this is intended.
    follow_ups = {
        "p2": "const auto B2res = getB(p2);\n if (!B2res.ok()) {\n return Acts::Result<bool>::failure(B2res.error());\n }\n const auto B2 = *B2res;",
        "p3": "const auto B3res = getB(p3);\n if (!B3res.ok()) {\n return Acts::Result<bool>::failure(B3res.error());\n }\n const auto B3 = *B3res;",
        "err": "if (*err > errTol) {\n return Acts::Result<bool>::success(false);\n}",
        "new_time": "if (J == nullptr) {\n return Acts::Result<bool>::success(true);\n}",
    }

    def pre_expr_hook(var):
        return declarations.get(str(var))

    def post_expr_hook(var):
        key = str(var)
        if key == "new_J":
            # Copy the freshly computed Jacobian into the caller's J buffer.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return follow_ups.get(key)

    body = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )
    lines += [" " + body_line for body_line in body.split("\n")]

    lines.append(" return Acts::Result<bool>::success(true);")
    lines.append("}")

    return "\n".join(lines)
0246
0247
0248 all_name_exprs = rk4_short_math()
0249
0250
0251
0252
0253
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263
0264
0265
0266
0267
0268
0269 code = my_step_function_print(
0270 all_name_exprs,
0271 run_cse=True,
0272 )
0273
0274 print(
0275 """// This file is part of the ACTS project.
0276 //
0277 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0278 //
0279 // This Source Code Form is subject to the terms of the Mozilla Public
0280 // License, v. 2.0. If a copy of the MPL was not distributed with this
0281 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0282
0283 // Note: This file is generated by generate_sympy_stepper.py
0284 // Do not modify it manually.
0285
0286 #pragma once
0287
0288 #include <Acts/Utilities/Result.hpp>
0289
0290 #include <cmath>
0291 """
0292 )
0293 print(code)