File indexing completed on 2025-12-16 09:23:15
0001 import sys
0002
0003 import numpy as np
0004
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.codegen.ast import Assignment
0008
0009 from codegen.sympy_common import (
0010 make_vector,
0011 NamedExpr,
0012 name_expr,
0013 find_by_name,
0014 my_subs,
0015 cxx_printer,
0016 my_expression_print,
0017 )
0018
0019
# Destination of the generated C++: the file named by the first command-line
# argument when one is given, otherwise standard output.
output = open(sys.argv[1], "w") if len(sys.argv) > 1 else sys.stdout
0023
0024
0025
# Symbolic inputs shared by all stepper builders below. The physical meaning
# noted next to each group is inferred from how the symbols are used later and
# should be read as a hint, not a specification.

# Scalar step quantities: `l` (presumably q/p), step size `h`, path length `s`
# and time `t`.
l, h, s, t = sym.symbols("l h s t", real=True)

# Position and direction 3-vectors.
p = make_vector("p", 3, real=True)
d = make_vector("d", 3, real=True)

# Mass and absolute momentum; both are declared positive, which lets sympy
# simplify sqrt(1 + m**2 / p_abs**2).
m, p_abs = sym.symbols("m p_abs", real=True, positive=True)

# Scalars of the dense-environment equations (presumably energy-loss rate and
# charge -- confirm against rk4_dense_tunedexpr).
g, q = sym.symbols("g q", real=True)

# Generic field vector and the field sampled at the three evaluation points.
B = make_vector("B", 3, real=True)
B1 = make_vector("B1", 3, real=True)
B2 = make_vector("B2", 3, real=True)
B3 = make_vector("B3", 3, real=True)

# Per-stage samples of `g` (stages 1..4).
g1, g2, g3, g4 = sym.symbols("g1 g2 g3 g4", real=True)
0068
0069
def rk4_subexpr(f, x, y, ydot, h):
    """Build one RK4 (Runge-Kutta-Nystrom) step for ``y'' = f(x, y, y')`` from named stages.

    Every stage derivative ``k_i`` is wrapped in ``name_expr`` and later stages
    reference it through ``k_i.name`` instead of its full expression, so the
    resulting expressions refer to previously named values.

    Args:
        f: callable ``f(i, x, y, ydot)`` returning the second derivative for
            stage ``i`` (1 to 4) as a column vector.
        x: independent variable at the start of the step.
        y: state column vector at the start of the step.
        ydot: first derivative of ``y`` at the start of the step.
        h: step size.

    Returns:
        A 3-tuple of ``NamedExpr`` groups:

        * ``((new_y, new_ydot), (k1, k2, k3, k4))`` -- step result and stages;
        * ``((dydyydot, dydotdyydot), (dk1dyydot, dk2dyydot, dk3dyydot, dk4dyydot))``
          -- Jacobians with respect to ``(y, ydot)``, obtained with the chain
          rule through the named stages;
        * ``(x2, y2, ydot2, ydot3, x3, y3, ydot4)`` -- intermediate stage inputs.
    """
    # Variables all Jacobians are taken with respect to.
    state = [y, ydot]

    # Stage 1 at the start of the step.
    k1 = name_expr("k1", f(1, x, y, ydot))

    # Stages 2 and 3 at the midpoint x + h/2 (they share the position y2).
    x2 = name_expr("x2", x + h / 2)
    y2 = name_expr("y2", y + h / 2 * ydot + h**2 / 8 * k1.name)
    ydot2 = name_expr("ydot2", ydot + h / 2 * k1.name)
    y2_explicit = y2.expr.as_explicit()

    k2 = name_expr("k2", f(2, x2.expr, y2_explicit, ydot2.expr.as_explicit()))
    ydot3 = name_expr("ydot3", ydot + h / 2 * k2.name)

    k3 = name_expr("k3", f(3, x2.expr, y2_explicit, ydot3.expr.as_explicit()))

    # Stage 4 at the end point x + h.
    x3 = name_expr("x3", x + h)
    y3 = name_expr("y3", y + h * ydot + h**2 / 2 * k3.name)
    ydot4 = name_expr("ydot4", ydot + h * k3.name)

    k4 = name_expr("k4", f(4, x3.expr, y3.expr.as_explicit(), ydot4.expr.as_explicit()))

    # Combine the stages; the y update only uses k1..k3.
    new_y = name_expr("new_y", y + h * ydot + h**2 / 6 * (k1.name + k2.name + k3.name))
    new_ydot = name_expr(
        "new_ydot", ydot + h / 6 * (k1.name + 2 * (k2.name + k3.name) + k4.name)
    )

    # Chain rule: d k_i / d(y, ydot) = explicit dependence
    #   + sum over every earlier stage k_j that k_i's inputs contain of
    #     (d k_i / d k_j) * (d k_j / d(y, ydot)).
    dk1dyydot = name_expr("dk1dyydot", k1.expr.jacobian(state))
    # k2's inputs (y2, ydot2) contain k1.
    dk2dyydot = name_expr(
        "dk2dyydot",
        k2.expr.jacobian(state) + k2.expr.jacobian(k1.name) * dk1dyydot.name,
    )
    # k3's inputs contain k1 (through y2) and k2 (through ydot3). The k1 term is
    # zero when f does not depend on y, but is required for general f.
    dk3dyydot = name_expr(
        "dk3dyydot",
        k3.expr.jacobian(state)
        + k3.expr.jacobian(k1.name) * dk1dyydot.name
        + k3.expr.jacobian(k2.name) * dk2dyydot.name,
    )
    # k4's inputs (y3, ydot4) contain only k3.
    dk4dyydot = name_expr(
        "dk4dyydot",
        k4.expr.jacobian(state) + k4.expr.jacobian(k3.name) * dk3dyydot.name,
    )

    new_y_explicit = new_y.expr.as_explicit()
    dydyydot = name_expr(
        "dydyydot",
        new_y_explicit.jacobian(state)
        + new_y_explicit.jacobian(k1.name) * dk1dyydot.name
        + new_y_explicit.jacobian(k2.name) * dk2dyydot.name
        + new_y_explicit.jacobian(k3.name) * dk3dyydot.name,
    )
    new_ydot_explicit = new_ydot.expr.as_explicit()
    dydotdyydot = name_expr(
        "dydotdyydot",
        new_ydot_explicit.jacobian(state)
        + new_ydot_explicit.jacobian(k1.name) * dk1dyydot.name
        + new_ydot_explicit.jacobian(k2.name) * dk2dyydot.name
        + new_ydot_explicit.jacobian(k3.name) * dk3dyydot.name
        + new_ydot_explicit.jacobian(k4.name) * dk4dyydot.name,
    )

    return (
        ((new_y, new_ydot), (k1, k2, k3, k4)),
        ((dydyydot, dydotdyydot), (dk1dyydot, dk2dyydot, dk3dyydot, dk4dyydot)),
        (x2, y2, ydot2, ydot3, x3, y3, ydot4),
    )
0126
0127
def rk4_fullexpr(f, x, y, ydot, h):
    """Build one RK4 (Runge-Kutta-Nystrom) step for ``y'' = f(x, y, y')`` with inlined stages.

    Same scheme as ``rk4_subexpr``, but later stages embed the full
    expressions (``k_i.expr``) of earlier stages instead of their names, and no
    Jacobians are built.

    Args:
        f: callable ``f(i, x, y, ydot)`` returning the second derivative for
            stage ``i`` (1 to 4).
        x: independent variable at the start of the step.
        y: state at the start of the step.
        ydot: first derivative of ``y`` at the start of the step.
        h: step size.

    Returns:
        ``((new_y, new_ydot), (k1, k2, k3, k4), (x2, y2, ydot2, ydot3, x3, y3, ydot4))``
        as ``NamedExpr`` objects.
    """
    half_h = h / 2

    # Stage 1 at the start of the step.
    k1 = name_expr("k1", f(1, x, y, ydot))

    # Stages 2 and 3 at the midpoint x + h/2.
    x2 = name_expr("x2", x + half_h)
    y2 = name_expr("y2", y + half_h * ydot + h**2 / 8 * k1.expr)
    ydot2 = name_expr("ydot2", ydot + half_h * k1.expr)
    k2 = name_expr("k2", f(2, x2.expr, y2.expr, ydot2.expr))

    ydot3 = name_expr("ydot3", ydot + half_h * k2.expr)
    k3 = name_expr("k3", f(3, x2.expr, y2.expr, ydot3.expr))

    # Stage 4 at the end point x + h.
    x3 = name_expr("x3", x + h)
    y3 = name_expr("y3", y + h * ydot + h**2 / 2 * k3.expr)
    ydot4 = name_expr("ydot4", ydot + h * k3.expr)
    k4 = name_expr("k4", f(4, x3.expr, y3.expr, ydot4.expr))

    # Combine the stages; the y update only uses k1..k3.
    step_result = (
        name_expr("new_y", y + h * ydot + h**2 / 6 * (k1.expr + k2.expr + k3.expr)),
        name_expr(
            "new_ydot", ydot + h / 6 * (k1.expr + 2 * (k2.expr + k3.expr) + k4.expr)
        ),
    )
    stages = (k1, k2, k3, k4)
    intermediates = (x2, y2, ydot2, ydot3, x3, y3, ydot4)
    return step_result, stages, intermediates
0150
0151
def rk4_vacuum_fullexpr2():
    """Build the vacuum RK4 step through the generic ``rk4_fullexpr`` helper.

    Alternative to ``rk4_vacuum_fullexpr``: the stage expressions come from
    ``rk4_fullexpr``, with position ``p`` as the state and direction ``d`` as its
    derivative. Stage ``i`` uses the field sample ``[B1, B2, B2, B3][i - 1]``.

    Returns:
        ``[p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]`` as
        ``NamedExpr`` objects. These are the same names ``rk4_vacuum_fullexpr``
        returns.
    """

    def f(i, x, y, ydot):
        # rk4_fullexpr calls f(i, x, y, ydot). Select the field sample for stage
        # i here; stages 2 and 3 share the midpoint sample B2.
        stage_field = [B1, B2, B2, B3][i - 1]
        direction = ydot[0:3, 0]
        return direction.cross(l * stage_field)

    # rk4_fullexpr takes exactly (f, x, y, ydot, h) and returns three groups.
    (
        (new_y, new_ydot),
        (k1, k2, k3, k4),
        (_x2, y2, _ydot2, _ydot3, _x3, y3, _ydot4),
    ) = rk4_fullexpr(f, s, p, d, h)

    p2 = name_expr("p2", y2.expr[0:3, 0])
    p3 = name_expr("p3", y3.expr[0:3, 0])
    new_p = name_expr("new_p", new_y.expr[0:3, 0])
    new_d_tmp = name_expr("new_d_tmp", new_ydot.expr[0:3, 0])
    # The direction is renormalised to unit length.
    new_d = name_expr("new_d", new_d_tmp.expr / new_d_tmp.expr.norm())
    # Error estimate: h^2 times the 1-norm of k1 - k2 - k3 + k4.
    err = name_expr(
        "err",
        h**2
        * (k1.expr[0:3, 0] - k2.expr[0:3, 0] - k3.expr[0:3, 0] + k4.expr[0:3, 0]).norm(
            1
        ),
    )

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_t = name_expr("new_t", t + h * dtds.expr)

    # Entry 7 stays 0 (no change of l in vacuum).
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.expr
    path_derivatives.expr[3, 0] = dtds.expr
    path_derivatives.expr[4:7, 0] = k4.expr

    # Step transport matrix, starting from identity.
    D = sym.eye(8)
    D[0:3, :] = new_p.expr.jacobian([p, t, d, l])
    D[4:7, :] = new_d_tmp.expr.jacobian([p, t, d, l])
    D[3, 7] = h * m**2 * l / dtds.expr

    # Entries where D is exactly 0 or 1 are copied into J; the rest stay
    # symbolic J[i, j]. NOTE(review): this presumes the incoming J shares those
    # entries with D -- confirm.
    J = MatrixSymbol("J", 8, 8).as_explicit().as_mutable()
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)

    new_J = name_expr("new_J", J * D)

    return [p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]
0207
0208
def rk4_vacuum_fullexpr():
    """Build the vacuum RK4 step with every stage written out as a full expression.

    The field enters through the symbols ``B1`` (stage 1), ``B2`` (stages 2
    and 3) and ``B3`` (stage 4). Later stages embed the complete expressions of
    earlier ones (``.expr``) rather than their names.

    Returns:
        ``[p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]`` as
        ``NamedExpr`` objects.
    """
    # Stage derivatives of the direction, with the positions where B2 and B3
    # are sampled.
    k1 = name_expr("k1", d.cross(l * B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.expr)
    k2 = name_expr("k2", (d + h / 2 * k1.expr).cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.expr).cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.expr)
    k4 = name_expr("k4", (d + h * k3.expr).cross(l * B3))

    # Error estimate: h^2 times the 1-norm of k1 - k2 - k3 + k4.
    err = name_expr("err", h**2 * (k1.expr - k2.expr - k3.expr + k4.expr).norm(1))

    # Update of position and unnormalised direction, then normalisation.
    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.expr + k2.expr + k3.expr))
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.expr + 2 * (k2.expr + k3.expr) + k4.expr)
    )
    new_d = name_expr("new_d", new_d_tmp.expr / new_d_tmp.expr.norm())

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_t = name_expr("new_t", t + h * dtds.expr)

    # Derivatives with respect to path length; entry 7 stays 0.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    derivative_vector = path_derivatives.expr
    derivative_vector[0:3, 0] = new_d.expr
    derivative_vector[3, 0] = dtds.expr
    derivative_vector[4:7, 0] = k4.expr

    # Step transport matrix: identity with position and direction rows
    # replaced by their Jacobians.
    parameters = [p, t, d, l]
    D = sym.eye(8)
    D[0:3, :] = new_p.expr.jacobian(parameters)
    D[4:7, :] = new_d_tmp.expr.jacobian(parameters)
    D[3, 7] = h * m**2 * l / dtds.expr

    # Copy the exact 0/1 entries of D into the symbolic J (row-major order).
    J = MatrixSymbol("J", 8, 8).as_explicit().as_mutable()
    for row in range(8):
        for col in range(8):
            if D[row, col] in (0, 1):
                J[row, col] = D[row, col]
    J = ImmutableMatrix(J)

    new_J = name_expr("new_J", J * D)

    return [p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]
0249
0250
def rk4_vacuum_tunedexpr():
    """Build the vacuum RK4 step with named intermediates for C++ generation.

    Later expressions reference earlier results through their ``.name``
    symbols (e.g. ``k1.name``) instead of their full expressions, so the printed
    code can refer to values computed earlier. The field enters only through
    the symbols ``B1`` (stage 1), ``B2`` (stages 2 and 3) and ``B3`` (stage 4),
    so no derivative of the field with respect to position is formed.

    Returns:
        A list of ``NamedExpr`` objects. ``print_rk4_vacuum`` looks its outputs
        up in this list by name. NOTE(review): the list order may also be the
        emission order used by ``my_expression_print`` -- confirm before
        reordering.
    """
    # Stage 1 direction derivative and the midpoint position where B2 is sampled.
    k1 = name_expr("k1", d.cross(l * B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.name)

    # Stages 2 and 3 both use the midpoint field B2.
    k2 = name_expr("k2", (d + h / 2 * k1.name).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.name).as_explicit().cross(l * B2))
    # End-point position where B3 is sampled.
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.name)

    k4 = name_expr("k4", (d + h * k3.name).as_explicit().cross(l * B3))

    # Error estimate: h^2 times the 1-norm of k1 - k2 - k3 + k4.
    err = name_expr(
        "err", h**2 * (k1.name - k2.name - k3.name + k4.name).as_explicit().norm(1)
    )

    # Position update uses k1..k3; the unnormalised direction update uses k1..k4.
    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.name + k2.name + k3.name))
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.name + 2 * (k2.name + k3.name) + k4.name)
    )
    # The direction is renormalised to unit length.
    new_d = name_expr("new_d", new_d_tmp.name / new_d_tmp.name.as_explicit().norm())

    # Time advances by h * sqrt(1 + m^2 / p_abs^2).
    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_t = name_expr("new_t", t + h * dtds.name)

    # Path-length derivatives: new direction, dtds, k4; entry 7 stays 0.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.name.as_explicit()
    path_derivatives.expr[3, 0] = dtds.name
    path_derivatives.expr[4:7, 0] = k4.name.as_explicit()

    # Jacobians of each stage with respect to (d, l), propagated with the chain
    # rule through the named previous stage. dk1dTL is not in the returned
    # list, and its full expression (.expr) is used below instead of its name.
    dk1dTL = name_expr("dk1dTL", k1.expr.jacobian([d, l]))
    dk2dTL = name_expr(
        "dk2dTL", k2.expr.jacobian([d, l]) + k2.expr.jacobian(k1.name) * dk1dTL.expr
    )
    dk3dTL = name_expr(
        "dk3dTL",
        k3.expr.jacobian([d, l]) + k3.expr.jacobian(k2.name) * dk2dTL.name,
    )
    dk4dTL = name_expr(
        "dk4dTL",
        k4.expr.jacobian([d, l]) + k4.expr.jacobian(k3.name) * dk3dTL.name,
    )

    # Jacobians of the new position (F) and unnormalised direction (G) with
    # respect to (d, l).
    dFdTL = name_expr(
        "dFdTL",
        new_p.expr.as_explicit().jacobian([d, l])
        + new_p.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_p.expr.as_explicit().jacobian(k2.name) * dk2dTL.name
        + new_p.expr.as_explicit().jacobian(k3.name) * dk3dTL.name,
    )
    dGdTL = name_expr(
        "dGdTL",
        new_d_tmp.expr.as_explicit().jacobian([d, l])
        + new_d_tmp.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_d_tmp.expr.as_explicit().jacobian(k2.name) * dk2dTL.name
        + new_d_tmp.expr.as_explicit().jacobian(k3.name) * dk3dTL.name
        + new_d_tmp.expr.as_explicit().jacobian(k4.name) * dk4dTL.name,
    )

    # Step transport matrix: identity, with rows 0-2 and 4-6 filled in over
    # columns 4-7. D[3, 7] is presumably d(new_t)/dl assuming p_abs = 1/|l|
    # (unit charge) -- confirm.
    D = sym.eye(8)
    D[0:3, 4:8] = dFdTL.name.as_explicit()
    D[4:7, 4:8] = dGdTL.name.as_explicit()
    D[3, 7] = h * m**2 * l / dtds.name

    # Entries where D is exactly 0 or 1 are copied into J; all others stay
    # symbolic J[i, j]. NOTE(review): presumes the incoming J shares those
    # entries with D -- confirm.
    J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)
    new_J = name_expr("new_J", J * D)

    return [
        k1,
        p2,
        k2,
        k3,
        p3,
        k4,
        err,
        new_p,
        dtds,
        new_t,
        new_d_tmp,
        new_d,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0341
0342
def rk4_dense_tunedexpr():
    """Build the dense-environment RK4 step with named intermediates.

    The position-like state is ``y = (p, t, big_l)`` and its derivative is
    ``ydot = (d, dtds, l)``. For stage ``i`` the second derivative is
    ``(d x (l * B_i), g_i * m**2 * l**3 / q**3, dtds * l**2 * g_i / q)``, i.e.
    the change of the direction, of ``dtds`` and of ``l`` along the path.
    ``big_l`` only fills the state slot paired with ``l``: ``f`` never reads
    ``y`` and ``new_y[4]`` is not output.

    Returns:
        A list of ``NamedExpr`` objects. ``print_rk4_dense`` looks its outputs
        up in this list by name. NOTE(review): the list order may also be the
        emission order used by ``my_expression_print`` -- confirm before
        reordering.
    """

    def f(i, x, y, ydot):
        # Stage-specific field and g samples (stages 2 and 3 share the midpoint).
        B = [B1, B2, B2, B3][i - 1]
        g = [g1, g2, g3, g4][i - 1]

        d = ydot[0:3, 0]
        dtds = ydot[3, 0]
        l = ydot[4, 0]
        return Matrix.vstack(
            d.cross(l * B),
            Matrix([g * m**2 * l**3 / q**3]),
            Matrix([dtds * l**2 * g / q]),
        )

    big_l = Symbol("big_l", real=True)
    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))

    # The generic Jacobians of rk4_subexpr are not used; the (t, d, l)
    # Jacobians are built below.
    (
        ((new_y, new_ydot), (k1, k2, k3, k4)),
        _subexpr_jacobians,
        (_x2, y2, ydot2, ydot3, _x3, y3, ydot4),
    ) = rk4_subexpr(
        f,
        s,
        Matrix.vstack(p, Matrix([t, big_l])),
        Matrix.vstack(d, Matrix([dtds.name, l])),
        h,
    )

    # Field sampling positions and the l values passed to getG per stage.
    p2 = name_expr("p2", y2.expr[0:3, 0])
    l2 = name_expr("l2", ydot2.expr[4, 0])
    p3 = name_expr("p3", y3.expr[0:3, 0])
    l3 = name_expr("l3", ydot3.expr[4, 0])
    l4 = name_expr("l4", ydot4.expr[4, 0])
    new_p = name_expr("new_p", new_y.expr[0:3, 0])
    new_t = name_expr("new_t", new_y.expr[3, 0])
    new_d_tmp = name_expr("new_d_tmp", new_ydot.expr[0:3, 0])
    # The direction is renormalised to unit length.
    new_d = name_expr("new_d", new_d_tmp.name / new_d_tmp.name.as_explicit().norm())
    new_l = name_expr("new_l", new_ydot.expr[4, 0])
    # Error estimate on the direction components only.
    err = name_expr(
        "err",
        h**2
        * (k1.name[0:3, 0] - k2.name[0:3, 0] - k3.name[0:3, 0] + k4.name[0:3, 0])
        .as_explicit()
        .norm(1),
    )

    # Derivatives of the free parameters with respect to path length.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.name.as_explicit()
    path_derivatives.expr[3, 0] = new_ydot.name[3, 0]
    path_derivatives.expr[4:7, 0] = k4.name[0:3, 0].as_explicit()
    # d(l)/ds is the l component of the stage-4 second derivative. This mirrors
    # entries 4-6, which take the direction change from k4. new_ydot[4] would be
    # the new value of l itself, which is already output as new_l.
    path_derivatives.expr[7, 0] = k4.name[4, 0]

    # Jacobians of each stage with respect to (t, d, l) via the chain rule.
    dk1dTL = name_expr("dk1dTL", k1.expr.jacobian([t, d, l]))
    dk2dTL = name_expr(
        "dk2dTL", k2.expr.jacobian([t, d, l]) + k2.expr.jacobian(k1.name) * dk1dTL.expr
    )
    dk3dTL = name_expr(
        "dk3dTL",
        k3.expr.jacobian([t, d, l]) + k3.expr.jacobian(k2.name) * dk2dTL.name,
    )
    dk4dTL = name_expr(
        "dk4dTL",
        k4.expr.jacobian([t, d, l]) + k4.expr.jacobian(k3.name) * dk3dTL.name,
    )

    # F = (new position, new time), G = (unnormalised new direction, new l).
    F = Matrix.vstack(new_p.expr.as_explicit(), Matrix([new_t.expr]))
    dFdTL = name_expr(
        "dFdTL",
        F.jacobian([t, d, l])
        + F.jacobian(k1.name) * dk1dTL.expr
        + F.jacobian(k2.name) * dk2dTL.name
        + F.jacobian(k3.name) * dk3dTL.name,
    )
    G = Matrix.vstack(new_d_tmp.expr.as_explicit(), Matrix([new_l.expr]))
    dGdTL = name_expr(
        "dGdTL",
        G.jacobian([t, d, l])
        + G.jacobian(k1.name) * dk1dTL.expr
        + G.jacobian(k2.name) * dk2dTL.name
        + G.jacobian(k3.name) * dk3dTL.name
        + G.jacobian(k4.name) * dk4dTL.name,
    )

    # Step transport matrix: identity with rows 0-3 and 4-7 filled over
    # columns 3-7 (t, direction, l).
    D = sym.eye(8)
    D[0:4, 3:8] = dFdTL.name.as_explicit()
    D[4:8, 3:8] = dGdTL.name.as_explicit()

    # Entries where D is exactly 0 or 1 are copied into J; all others stay
    # symbolic J[i, j]. NOTE(review): presumes the incoming J shares those
    # entries with D -- confirm.
    J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)
    new_J = name_expr("new_J", J * D)

    return [
        dtds,
        k1,
        y2,
        ydot2,
        ydot3,
        p2,
        l2,
        k2,
        l3,
        k3,
        y3,
        ydot4,
        p3,
        l4,
        k4,
        err,
        new_y,
        new_ydot,
        new_p,
        new_t,
        new_d_tmp,
        new_d,
        new_l,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0470
0471
def print_rk4_vacuum(name_exprs, run_cse=True):
    """Render the C++ ``rk4_vacuum`` template function as a string.

    Args:
        name_exprs: ``NamedExpr`` list, e.g. from ``rk4_vacuum_tunedexpr``. The
            outputs below are looked up in it by name.
        run_cse: forwarded to ``my_expression_print`` (presumably toggles
            common-subexpression elimination).

    Returns:
        The full C++ function definition as a newline-joined string.

    The printer hooks inject extra C++ around specific outputs:
    stack declarations before ``p2``, ``p3`` and ``new_J``; field lookups after
    ``p2`` and ``p3``; an early ``success(false)`` return when ``*err > errTol``;
    an early ``success(true)`` return after ``new_d`` when ``J == nullptr``; and
    the assignment of ``new_J`` to ``J``.
    """
    printer = cxx_printer
    wanted = (
        "p2",
        "p3",
        "err",
        "new_p",
        "new_t",
        "new_d",
        "path_derivatives",
        "new_J",
    )
    outputs = [find_by_name(name_exprs, name)[0] for name in wanted]

    # C++ emitted before an output is computed.
    declarations = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "new_J": "T new_J[64];",
    }
    # C++ emitted after an output is computed (except new_J, handled below).
    follow_ups = {
        "p2": "const auto B2res = getB(p2);\n if (!B2res.ok()) {\n return Acts::Result<bool>::failure(B2res.error());\n }\n const auto B2 = *B2res;",
        "p3": "const auto B3res = getB(p3);\n if (!B3res.ok()) {\n return Acts::Result<bool>::failure(B3res.error());\n }\n const auto B3 = *B3res;",
        "err": "if (*err > errTol) {\n return Acts::Result<bool>::success(false);\n}",
        "new_d": "if (J == nullptr) {\n return Acts::Result<bool>::success(true);\n}",
    }

    def pre_expr_hook(var):
        return declarations.get(str(var))

    def post_expr_hook(var):
        key = str(var)
        if key == "new_J":
            # Copy the computed matrix into the caller's J buffer.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return follow_ups.get(key)

    code = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )

    prologue = [
        "template <typename T, typename GetB> Acts::Result<bool> rk4_vacuum(const T* p, const T* d, const T t, const T h, const T l, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_t, T* new_d, T* path_derivatives, T* J) {",
        " const auto B1res = getB(p);",
        " if (!B1res.ok()) {\n return Acts::Result<bool>::failure(B1res.error());\n }",
        " const auto B1 = *B1res;",
    ]
    body = [f" {line}" for line in code.split("\n")]
    epilogue = [" return Acts::Result<bool>::success(true);", "}"]
    return "\n".join(prologue + body + epilogue)
0538
0539
def print_rk4_dense(name_exprs, run_cse=True):
    """Render the C++ ``rk4_dense`` template function as a string.

    Args:
        name_exprs: ``NamedExpr`` list, e.g. from ``rk4_dense_tunedexpr``. The
            outputs below are looked up in it by name.
        run_cse: forwarded to ``my_expression_print`` (presumably toggles
            common-subexpression elimination).

    Returns:
        The full C++ function definition as a newline-joined string.

    Compared with the vacuum printer, the hooks also declare ``l2``/``l3``/``l4``
    and call ``getG`` after each of them: stage 3 uses position ``p2`` and
    stage 4 uses ``p3``.
    """
    printer = cxx_printer
    wanted = (
        "p2",
        "l2",
        "l3",
        "p3",
        "l4",
        "err",
        "new_p",
        "new_t",
        "new_d",
        "new_l",
        "path_derivatives",
        "new_J",
    )
    outputs = [find_by_name(name_exprs, name)[0] for name in wanted]

    # C++ emitted before an output is computed.
    declarations = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "l2": "T l2[1];",
        "l3": "T l3[1];",
        "l4": "T l4[1];",
        "new_J": "T new_J[64];",
    }
    # C++ emitted after an output is computed (except new_J, handled below).
    follow_ups = {
        "p2": "const auto B2res = getB(p2);\n if (!B2res.ok()) {\n return Acts::Result<bool>::failure(B2res.error());\n }\n const auto B2 = *B2res;",
        "p3": "const auto B3res = getB(p3);\n if (!B3res.ok()) {\n return Acts::Result<bool>::failure(B3res.error());\n }\n const auto B3 = *B3res;",
        "l2": "const auto g2 = getG(p2, *l2);",
        "l3": "const auto g3 = getG(p2, *l3);",
        "l4": "const auto g4 = getG(p3, *l4);",
        "err": "if (*err > errTol) {\n return Acts::Result<bool>::success(false);\n}",
        "new_d": "if (J == nullptr) {\n return Acts::Result<bool>::success(true);\n}",
    }

    def pre_expr_hook(var):
        return declarations.get(str(var))

    def post_expr_hook(var):
        key = str(var)
        if key == "new_J":
            # Copy the computed matrix into the caller's J buffer.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return follow_ups.get(key)

    code = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )

    prologue = [
        "template <typename T, typename GetB, typename GetG> Acts::Result<bool> rk4_dense(const T* p, const T* d, const T t, const T h, const T l, const T m, const T q, const T p_abs, GetB getB, GetG getG, T* err, const T errTol, T* new_p, T* new_t, T* new_d, T* new_l, T* path_derivatives, T* J) {",
        " const auto B1res = getB(p);",
        " if (!B1res.ok()) {\n return Acts::Result<bool>::failure(B1res.error());\n }",
        " const auto B1 = *B1res;",
        " const auto g1 = getG(p, l);",
    ]
    body = [f" {line}" for line in code.split("\n")]
    epilogue = [" return Acts::Result<bool>::success(true);", "}"]
    return "\n".join(prologue + body + epilogue)
0623
0624
# Fixed preamble of the generated header: license, provenance note, includes.
_CXX_HEADER = """
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_stepper.py
// Do not modify it manually.

#pragma once

#include <Acts/Utilities/Result.hpp>

#include <cmath>
""".strip()

output.write(_CXX_HEADER)
output.write("\n\n")

# Each generated C++ function is built and printed in turn; consecutive
# functions are separated by one extra blank line.
_GENERATORS = (
    (rk4_vacuum_tunedexpr, print_rk4_vacuum),
    (rk4_dense_tunedexpr, print_rk4_dense),
)
for _position, (_build_exprs, _render) in enumerate(_GENERATORS):
    if _position:
        output.write("\n")
    all_name_exprs = _build_exprs()
    code = _render(all_name_exprs, run_cse=True)
    output.write(code + "\n")

# Only close a file this script opened; never close stdout.
if output is not sys.stdout:
    output.close()
0688 output.close()