Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:23:15

0001 import sys
0002 
0003 import numpy as np
0004 
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.codegen.ast import Assignment
0008 
0009 from codegen.sympy_common import (
0010     make_vector,
0011     NamedExpr,
0012     name_expr,
0013     find_by_name,
0014     my_subs,
0015     cxx_printer,
0016     my_expression_print,
0017 )
0018 
0019 
# Destination of the generated header: the path given as the first command
# line argument, or standard output when no argument is supplied.
output = open(sys.argv[1], "w") if len(sys.argv) > 1 else sys.stdout
0023 
0024 
# Scalar inputs of the stepper.
l = Symbol("l", real=True)  # q/p
h = Symbol("h", real=True)  # step length
s = Symbol("s", real=True)  # path length

# Position and direction as 3-component symbolic vectors.
p = make_vector("p", 3, real=True)
d = make_vector("d", 3, real=True)

t = Symbol("t", real=True)  # time
m = Symbol("m", real=True, positive=True)  # mass
p_abs = Symbol("p_abs", real=True, positive=True)  # absolute momentum
g = Symbol("g", real=True)  # energy loss per distance
q = Symbol("q", real=True)  # charge

# Generic magnetic field vector.
B = make_vector("B", 3, real=True)

# Magnetic field values sampled for the individual RK4 stages.
B1, B2, B3 = (make_vector(f"B{i}", 3, real=True) for i in (1, 2, 3))

# Energy loss per distance for each of the four RK4 stages ("g1:5" expands
# to g1, g2, g3, g4).
g1, g2, g3, g4 = sym.symbols("g1:5", real=True)
0068 
0069 
def rk4_subexpr(f, x, y, ydot, h):
    """Build one Runge-Kutta-Nystroem (RK4) step from named sub-expressions.

    Later stages reference earlier ones through their symbolic names
    (``k1.name`` ...) instead of substituting their expressions, so generated
    code can evaluate every stage once. The Jacobians of the new state with
    respect to the initial ``(y, ydot)`` are assembled with the chain rule
    through those named stages.

    Args:
        f: callable ``f(stage, x, y, ydot)`` returning the second derivative
            for ``stage`` in ``1..4`` as a column ``Matrix``.
        x: independent variable of the integration.
        y: state column vector.
        ydot: derivative of ``y``; must have the same shape as ``y`` since the
            two are added together.
        h: step size.

    Returns:
        ``(((new_y, new_ydot), (k1, k2, k3, k4)),
        ((dydyydot, dydotdyydot), (dk1dyydot, dk2dyydot, dk3dyydot, dk4dyydot)),
        (x2, y2, ydot2, ydot3, x3, y3, ydot4))`` as named expressions.
    """
    k1 = name_expr("k1", f(1, x, y, ydot))
    x2 = name_expr("x2", x + h / 2)
    y2 = name_expr("y2", y + h / 2 * ydot + h**2 / 8 * k1.name)
    ydot2 = name_expr("ydot2", ydot + h / 2 * k1.name)

    k2 = name_expr("k2", f(2, x2.expr, y2.expr.as_explicit(), ydot2.expr.as_explicit()))
    ydot3 = name_expr("ydot3", ydot + h / 2 * k2.name)

    k3 = name_expr("k3", f(3, x2.expr, y2.expr.as_explicit(), ydot3.expr.as_explicit()))
    x3 = name_expr("x3", x + h)
    y3 = name_expr("y3", y + h * ydot + h**2 / 2 * k3.name)
    ydot4 = name_expr("ydot4", ydot + h * k3.name)

    k4 = name_expr("k4", f(4, x3.expr, y3.expr.as_explicit(), ydot4.expr.as_explicit()))

    new_y = name_expr("new_y", y + h * ydot + h**2 / 6 * (k1.name + k2.name + k3.name))
    new_ydot = name_expr(
        "new_ydot", ydot + h / 6 * (k1.name + 2 * (k2.name + k3.name) + k4.name)
    )

    dk1dyydot = name_expr("dk1dyydot", k1.expr.jacobian([y, ydot]))
    dk2dyydot = name_expr(
        "dk2dyydot",
        k2.expr.jacobian([y, ydot]) + k2.expr.jacobian(k1.name) * dk1dyydot.name,
    )

    # k3 is evaluated at y2, which itself contains k1, so besides the path
    # k3 <- ydot3 <- k2 the chain rule also has k3 <- y2 <- k1. That term is
    # only added when it is non-zero: for stage functions that ignore y it
    # vanishes and the generated expression stays exactly as before.
    dk3dyydot_expr = (
        k3.expr.jacobian([y, ydot]) + k3.expr.jacobian(k2.name) * dk2dyydot.name
    )
    dk3dk1 = k3.expr.jacobian(k1.name)
    if not dk3dk1.is_zero_matrix:
        dk3dyydot_expr = dk3dyydot_expr + dk3dk1 * dk1dyydot.name
    dk3dyydot = name_expr("dk3dyydot", dk3dyydot_expr)

    # k4 depends on earlier stages only via y3 and ydot4, both built from k3.
    dk4dyydot = name_expr(
        "dk4dyydot",
        k4.expr.jacobian([y, ydot]) + k4.expr.jacobian(k3.name) * dk3dyydot.name,
    )

    dydyydot = name_expr(
        "dydyydot",
        new_y.expr.as_explicit().jacobian([y, ydot])
        + new_y.expr.as_explicit().jacobian(k1.name) * dk1dyydot.name
        + new_y.expr.as_explicit().jacobian(k2.name) * dk2dyydot.name
        + new_y.expr.as_explicit().jacobian(k3.name) * dk3dyydot.name,
    )
    dydotdyydot = name_expr(
        "dydotdyydot",
        new_ydot.expr.as_explicit().jacobian([y, ydot])
        + new_ydot.expr.as_explicit().jacobian(k1.name) * dk1dyydot.name
        + new_ydot.expr.as_explicit().jacobian(k2.name) * dk2dyydot.name
        + new_ydot.expr.as_explicit().jacobian(k3.name) * dk3dyydot.name
        + new_ydot.expr.as_explicit().jacobian(k4.name) * dk4dyydot.name,
    )

    return (
        ((new_y, new_ydot), (k1, k2, k3, k4)),
        ((dydyydot, dydotdyydot), (dk1dyydot, dk2dyydot, dk3dyydot, dk4dyydot)),
        (x2, y2, ydot2, ydot3, x3, y3, ydot4),
    )
0126 
0127 
def rk4_fullexpr(f, x, y, ydot, h):
    """Expand one Runge-Kutta-Nystroem (RK4) step into full expressions.

    Same stage layout as ``rk4_subexpr``, but every stage is substituted by
    its expression (``.expr``) rather than referenced by name, so the results
    are written directly in terms of the inputs.

    Args:
        f: callable ``f(stage, x, y, ydot)`` returning the second derivative
            for ``stage`` in ``1..4``.
        x: independent variable of the integration.
        y: state vector.
        ydot: derivative of ``y``.
        h: step size.

    Returns:
        ``((new_y, new_ydot), (k1, k2, k3, k4),
        (x2, y2, ydot2, ydot3, x3, y3, ydot4))`` as named expressions.
    """
    half_h = h / 2

    # Stage 1 at the start of the step.
    k1 = name_expr("k1", f(1, x, y, ydot))

    # Stages 2 and 3 share the midpoint x2 / y2.
    x_mid = name_expr("x2", x + half_h)
    y_mid = name_expr("y2", y + half_h * ydot + h**2 / 8 * k1.expr)
    ydot_stage2 = name_expr("ydot2", ydot + half_h * k1.expr)
    k2 = name_expr("k2", f(2, x_mid.expr, y_mid.expr, ydot_stage2.expr))

    ydot_stage3 = name_expr("ydot3", ydot + half_h * k2.expr)
    k3 = name_expr("k3", f(3, x_mid.expr, y_mid.expr, ydot_stage3.expr))

    # Stage 4 at the end of the step.
    x_end = name_expr("x3", x + h)
    y_end = name_expr("y3", y + h * ydot + h**2 / 2 * k3.expr)
    ydot_stage4 = name_expr("ydot4", ydot + h * k3.expr)
    k4 = name_expr("k4", f(4, x_end.expr, y_end.expr, ydot_stage4.expr))

    new_y = name_expr("new_y", y + h * ydot + h**2 / 6 * (k1.expr + k2.expr + k3.expr))
    new_ydot = name_expr(
        "new_ydot", ydot + h / 6 * (k1.expr + 2 * (k2.expr + k3.expr) + k4.expr)
    )

    stage_inputs = (x_mid, y_mid, ydot_stage2, ydot_stage3, x_end, y_end, ydot_stage4)
    return (new_y, new_ydot), (k1, k2, k3, k4), stage_inputs
0150 
0151 
def rk4_vacuum_fullexpr2():
    """Vacuum RK4 step built through the generic ``rk4_fullexpr`` expander.

    Produces the same outputs as ``rk4_vacuum_fullexpr``. The four stages are
    generated by ``rk4_fullexpr`` from a per-stage acceleration function
    rather than being written out by hand.

    Returns:
        ``[p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]`` as
        named expressions.
    """

    def f(i, x, y, ydot):
        # rk4_fullexpr calls f(stage, x, y, ydot) with stage in 1..4. Select
        # the field sample for the stage: B1 for stage 1, B2 for stages 2 and 3,
        # B3 for stage 4. x and y are not used.
        stage_B = [B1, B2, B2, B3][i - 1]
        d = ydot[0:3, 0]
        return d.cross(l * stage_B)

    (
        (new_y, new_ydot),
        (k1, k2, k3, k4),
        (x2, y2, ydot2, ydot3, x3, y3, ydot4),
    ) = rk4_fullexpr(f, s, p, d, h)

    p2 = name_expr("p2", y2.expr[0:3, 0])
    p3 = name_expr("p3", y3.expr[0:3, 0])
    new_p = name_expr("new_p", new_y.expr[0:3, 0])
    new_d_tmp = name_expr("new_d_tmp", new_ydot.expr[0:3, 0])
    # Renormalise so the propagated direction stays a unit vector.
    new_d = name_expr("new_d", new_d_tmp.expr / new_d_tmp.expr.norm())
    err = name_expr(
        "err",
        h**2
        * (k1.expr[0:3, 0] - k2.expr[0:3, 0] - k3.expr[0:3, 0] + k4.expr[0:3, 0]).norm(
            1
        ),
    )

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_t = name_expr("new_t", t + h * dtds.expr)

    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.expr
    path_derivatives.expr[3, 0] = dtds.expr
    path_derivatives.expr[4:7, 0] = k4.expr

    D = sym.eye(8)
    D[0:3, :] = new_p.expr.jacobian([p, t, d, l])
    D[4:7, :] = new_d_tmp.expr.jacobian([p, t, d, l])
    D[3, 7] = h * m**2 * l / dtds.expr

    # Copy D's constant 0/1 entries into the symbolic incoming Jacobian.
    J = MatrixSymbol("J", 8, 8).as_explicit().as_mutable()
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)

    new_J = name_expr("new_J", J * D)

    return [p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]
0207 
0208 
def rk4_vacuum_fullexpr():
    """Vacuum RK4 step with every intermediate substituted in full.

    No stage is referenced by name: each returned expression is written
    directly in terms of the input symbols and the three field samples
    ``B1`` (stage 1), ``B2`` (stages 2 and 3) and ``B3`` (stage 4).

    Returns:
        ``[p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]`` as
        named expressions.
    """

    def bend(direction, field):
        # Stage derivative of the direction: direction x (l * field).
        return direction.cross(l * field)

    half_h = h / 2

    k1 = name_expr("k1", bend(d, B1))
    p2 = name_expr("p2", p + half_h * d + h**2 / 8 * k1.expr)

    k2 = name_expr("k2", bend(d + half_h * k1.expr, B2))
    k3 = name_expr("k3", bend(d + half_h * k2.expr, B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.expr)

    k4 = name_expr("k4", bend(d + h * k3.expr, B3))

    err = name_expr("err", h**2 * (k1.expr - k2.expr - k3.expr + k4.expr).norm(1))

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.expr + k2.expr + k3.expr))
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.expr + 2 * (k2.expr + k3.expr) + k4.expr)
    )
    # Renormalise so the propagated direction stays a unit vector.
    new_d = name_expr("new_d", new_d_tmp.expr / new_d_tmp.expr.norm())

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_t = name_expr("new_t", t + h * dtds.expr)

    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.expr
    path_derivatives.expr[3, 0] = dtds.expr
    path_derivatives.expr[4:7, 0] = k4.expr

    # Transport matrix of the step, starting from the identity.
    variables = [p, t, d, l]
    D = sym.eye(8)
    D[0:3, :] = new_p.expr.jacobian(variables)
    D[4:7, :] = new_d_tmp.expr.jacobian(variables)
    D[3, 7] = h * m**2 * l / dtds.expr

    # Symbolic incoming Jacobian, with D's constant 0/1 entries copied in.
    J_symbol = MatrixSymbol("J", 8, 8)
    J = ImmutableMatrix(
        [
            [D[row, col] if D[row, col] in (0, 1) else J_symbol[row, col] for col in range(8)]
            for row in range(8)
        ]
    )

    new_J = name_expr("new_J", J * D)

    return [p2, p3, err, new_p, new_t, new_d, path_derivatives, new_J]
0249 
0250 
def rk4_vacuum_tunedexpr():
    """Build the named expressions for one vacuum RK4 step, reusing sub-terms.

    Intermediate quantities (k1..k4, new_d_tmp, dtds, the dk*dTL Jacobians)
    are referenced through their names (``.name``) in later expressions
    instead of being substituted. The printed C++ can therefore compute each
    of them once. ``.as_explicit()`` converts matrix-symbol arithmetic back
    into explicit matrices wherever element access, ``cross`` or ``norm``
    is needed.

    Field samples: ``B1`` for stage 1, ``B2`` for stages 2 and 3, ``B3`` for
    stage 4. ``print_rk4_vacuum`` obtains them with ``getB(p)``,
    ``getB(p2)`` and ``getB(p3)``.

    Returns:
        list of named expressions. ``print_rk4_vacuum`` picks ``p2``, ``p3``,
        ``err``, ``new_p``, ``new_t``, ``new_d``, ``path_derivatives`` and
        ``new_J`` from it by name. NOTE(review): the list order presumably
        also drives the emitted statement order in ``my_expression_print``;
        confirm before reordering.
    """
    # Stage 1: d x (l * B1) at the start of the step.
    k1 = name_expr("k1", d.cross(l * B1))
    # Midpoint position, where B2 is sampled.
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.name)

    # Stages 2 and 3 both use the midpoint field B2.
    k2 = name_expr("k2", (d + h / 2 * k1.name).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.name).as_explicit().cross(l * B2))
    # End-point position, where B3 is sampled.
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.name)

    k4 = name_expr("k4", (d + h * k3.name).as_explicit().cross(l * B3))

    # Step error estimate: h^2 * L1 norm of (k1 - k2 - k3 + k4). The printed
    # code rejects the step when it exceeds errTol.
    err = name_expr(
        "err", h**2 * (k1.name - k2.name - k3.name + k4.name).as_explicit().norm(1)
    )

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.name + k2.name + k3.name))
    # Un-normalised new direction; the Jacobian below is taken of this one.
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.name + 2 * (k2.name + k3.name) + k4.name)
    )
    new_d = name_expr("new_d", new_d_tmp.name / new_d_tmp.name.as_explicit().norm())

    # sqrt(1 + m^2 / p_abs^2) == sqrt(p_abs^2 + m^2) / p_abs.
    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_t = name_expr("new_t", t + h * dtds.name)

    # 8-component derivative layout: 0:3 direction, 3 dtds, 4:7 k4;
    # index 7 is left at 0 (presumably since l does not change in vacuum).
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.name.as_explicit()
    path_derivatives.expr[3, 0] = dtds.name
    path_derivatives.expr[4:7, 0] = k4.name.as_explicit()

    # Chain-rule Jacobians of each stage w.r.t. (d, l). dk1dTL is not
    # returned; its expression is inlined via .expr where it is used.
    dk1dTL = name_expr("dk1dTL", k1.expr.jacobian([d, l]))
    dk2dTL = name_expr(
        "dk2dTL", k2.expr.jacobian([d, l]) + k2.expr.jacobian(k1.name) * dk1dTL.expr
    )
    dk3dTL = name_expr(
        "dk3dTL",
        k3.expr.jacobian([d, l]) + k3.expr.jacobian(k2.name) * dk2dTL.name,
    )
    dk4dTL = name_expr(
        "dk4dTL",
        k4.expr.jacobian([d, l]) + k4.expr.jacobian(k3.name) * dk3dTL.name,
    )

    # d(new_p)/d(d, l): direct part plus contributions through k1..k3.
    dFdTL = name_expr(
        "dFdTL",
        new_p.expr.as_explicit().jacobian([d, l])
        + new_p.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_p.expr.as_explicit().jacobian(k2.name) * dk2dTL.name
        + new_p.expr.as_explicit().jacobian(k3.name) * dk3dTL.name,
    )
    # d(new_d_tmp)/d(d, l): direct part plus contributions through k1..k4.
    dGdTL = name_expr(
        "dGdTL",
        new_d_tmp.expr.as_explicit().jacobian([d, l])
        + new_d_tmp.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_d_tmp.expr.as_explicit().jacobian(k2.name) * dk2dTL.name
        + new_d_tmp.expr.as_explicit().jacobian(k3.name) * dk3dTL.name
        + new_d_tmp.expr.as_explicit().jacobian(k4.name) * dk4dTL.name,
    )

    # Transport matrix of the step: identity plus the (d, l) columns 4:8.
    D = sym.eye(8)
    D[0:3, 4:8] = dFdTL.name.as_explicit()
    D[4:7, 4:8] = dGdTL.name.as_explicit()
    # d(new_t)/dl entered by hand, since new_t only contains the dtds symbol.
    # NOTE(review): equals d/dl of h*sqrt(1 + m^2 l^2), i.e. assumes
    # p_abs == 1/|l| (unit charge); confirm.
    D[3, 7] = h * m**2 * l / dtds.name

    # Symbolic incoming Jacobian J, with D's constant 0/1 entries copied in.
    # NOTE(review): this assumes J has the same fixed 0/1 entries as D;
    # confirm against the caller's Jacobian layout.
    J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)
    new_J = name_expr("new_J", J * D)

    return [
        k1,
        p2,
        k2,
        k3,
        p3,
        k4,
        err,
        new_p,
        dtds,
        new_t,
        new_d_tmp,
        new_d,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0341 
0342 
def rk4_dense_tunedexpr():
    """Build the named expressions for one RK4 step including energy loss.

    The first-derivative vector is ``ydot = (d, dtds, l)``. The stage
    function also returns derivatives of ``dtds`` and ``l``, driven by the
    per-stage energy-loss values ``g1..g4``. The step is expanded by
    ``rk4_subexpr``, so stages are referenced by name and computed once in
    the printed code.

    Stage inputs: ``B1`` for stage 1, ``B2`` for stages 2 and 3, ``B3`` for
    stage 4, and ``g1``..``g4`` for stages 1..4. ``print_rk4_dense`` obtains
    them via ``getB``/``getG`` at ``p``/``l``, ``p2``/``l2``,
    ``p2``/``l3`` and ``p3``/``l4``.

    Returns:
        list of named expressions. ``print_rk4_dense`` picks ``p2``, ``l2``,
        ``l3``, ``p3``, ``l4``, ``err``, ``new_p``, ``new_t``, ``new_d``,
        ``new_l``, ``path_derivatives`` and ``new_J`` from it by name.
        NOTE(review): the list order presumably also drives the emitted
        statement order in ``my_expression_print``; confirm before
        reordering.
    """

    def f(i, x, y, ydot):
        # Per-stage inputs for stage i in 1..4; stages 2 and 3 share B2.
        B = [B1, B2, B2, B3][i - 1]
        g = [g1, g2, g3, g4][i - 1]

        # Unpack ydot = (d, dtds, l); x and y are not used.
        d = ydot[0:3, 0]
        dtds = ydot[3, 0]
        l = ydot[4, 0]
        # Derivatives of (d, dtds, l) w.r.t. the integration variable s.
        return Matrix.vstack(
            d.cross(l * B),
            Matrix([g * m**2 * l**3 / q**3]),
            Matrix([dtds * l**2 * g / q]),
        )

    # Placeholder so y has the same length (5) as ydot; rk4_subexpr adds them.
    big_l = Symbol("big_l", real=True)
    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))

    # State y = (p, t, big_l), derivative ydot = (d, dtds, l). The Jacobians
    # returned by rk4_subexpr are unpacked but unused; the (t, d, l) Jacobians
    # are rebuilt below.
    # NOTE(review): unlike the vacuum variant (its D[3, 7]), no hand-added
    # d(new_t)/dl term enters D here, although dtds only appears as
    # dtds.name; confirm this is intended.
    (
        ((new_y, new_ydot), (k1, k2, k3, k4)),
        ((dydyydot, dydotdyydot), (dk1dyydot, dk2dyydot, dk3dyydot, dk4dyydot)),
        (x2, y2, ydot2, ydot3, x3, y3, ydot4),
    ) = rk4_subexpr(
        f,
        s,
        Matrix.vstack(p, Matrix([t, big_l])),
        Matrix.vstack(d, Matrix([dtds.name, l])),
        h,
    )

    # Stage positions and l values the printer needs to query B and g.
    p2 = name_expr("p2", y2.expr[0:3, 0])
    l2 = name_expr("l2", ydot2.expr[4, 0])
    p3 = name_expr("p3", y3.expr[0:3, 0])
    l3 = name_expr("l3", ydot3.expr[4, 0])
    l4 = name_expr("l4", ydot4.expr[4, 0])
    new_p = name_expr("new_p", new_y.expr[0:3, 0])
    new_t = name_expr("new_t", new_y.expr[3, 0])
    # Un-normalised new direction; the Jacobian below is taken of this one.
    new_d_tmp = name_expr("new_d_tmp", new_ydot.expr[0:3, 0])
    new_d = name_expr("new_d", new_d_tmp.name / new_d_tmp.name.as_explicit().norm())
    new_l = name_expr("new_l", new_ydot.expr[4, 0])
    # Step error estimate from the direction components of the stages only.
    err = name_expr(
        "err",
        h**2
        * (k1.name[0:3, 0] - k2.name[0:3, 0] - k3.name[0:3, 0] + k4.name[0:3, 0])
        .as_explicit()
        .norm(1),
    )

    # 8-component derivative layout: 0:3 direction, 3 new dtds,
    # 4:7 direction part of k4, 7 new l derivative.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.name.as_explicit()
    path_derivatives.expr[3, 0] = new_ydot.name[3, 0]
    path_derivatives.expr[4:7, 0] = k4.name[0:3, 0].as_explicit()
    path_derivatives.expr[7, 0] = new_ydot.name[4, 0]

    # Chain-rule Jacobians of each stage w.r.t. (t, d, l). dk1dTL is not
    # returned; its expression is inlined via .expr where it is used.
    dk1dTL = name_expr("dk1dTL", k1.expr.jacobian([t, d, l]))
    dk2dTL = name_expr(
        "dk2dTL", k2.expr.jacobian([t, d, l]) + k2.expr.jacobian(k1.name) * dk1dTL.expr
    )
    dk3dTL = name_expr(
        "dk3dTL",
        k3.expr.jacobian([t, d, l]) + k3.expr.jacobian(k2.name) * dk2dTL.name,
    )
    dk4dTL = name_expr(
        "dk4dTL",
        k4.expr.jacobian([t, d, l]) + k4.expr.jacobian(k3.name) * dk3dTL.name,
    )

    # F = (new_p, new_t): direct Jacobian plus contributions through k1..k3.
    F = Matrix.vstack(new_p.expr.as_explicit(), Matrix([new_t.expr]))
    dFdTL = name_expr(
        "dFdTL",
        F.jacobian([t, d, l])
        + F.jacobian(k1.name) * dk1dTL.expr
        + F.jacobian(k2.name) * dk2dTL.name
        + F.jacobian(k3.name) * dk3dTL.name,
    )
    # G = (new_d_tmp, new_l): direct Jacobian plus contributions through k1..k4.
    G = Matrix.vstack(new_d_tmp.expr.as_explicit(), Matrix([new_l.expr]))
    dGdTL = name_expr(
        "dGdTL",
        G.jacobian([t, d, l])
        + G.jacobian(k1.name) * dk1dTL.expr
        + G.jacobian(k2.name) * dk2dTL.name
        + G.jacobian(k3.name) * dk3dTL.name
        + G.jacobian(k4.name) * dk4dTL.name,
    )

    # Transport matrix: identity plus the (t, d, l) columns 3:8.
    D = sym.eye(8)
    D[0:4, 3:8] = dFdTL.name.as_explicit()
    D[4:8, 3:8] = dGdTL.name.as_explicit()

    # Symbolic incoming Jacobian J, with D's constant 0/1 entries copied in.
    # NOTE(review): this assumes J has the same fixed 0/1 entries as D;
    # confirm against the caller's Jacobian layout.
    J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)
    new_J = name_expr("new_J", J * D)

    return [
        dtds,
        k1,
        y2,
        ydot2,
        ydot3,
        p2,
        l2,
        k2,
        l3,
        k3,
        y3,
        ydot4,
        p3,
        l4,
        k4,
        err,
        new_y,
        new_ydot,
        new_p,
        new_t,
        new_d_tmp,
        new_d,
        new_l,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0470 
0471 
def print_rk4_vacuum(name_exprs, run_cse=True):
    """Render the vacuum RK4 step as the C++ function template ``rk4_vacuum``.

    The emitted function:
    - queries the field at ``p``, ``p2`` and ``p3`` via ``getB``, returning
      the lookup's failure when a query fails;
    - returns ``success(false)`` once ``*err`` exceeds ``errTol``;
    - returns ``success(true)`` before the Jacobian update when ``J`` is null.

    Args:
        name_exprs: named expressions, e.g. from ``rk4_vacuum_tunedexpr``.
        run_cse: forwarded to ``my_expression_print``.

    Returns:
        The C++ source of the function as one string.
    """
    printer = cxx_printer
    outputs = [
        find_by_name(name_exprs, name)[0]
        for name in [
            "p2",
            "p3",
            "err",
            "new_p",
            "new_t",
            "new_d",
            "path_derivatives",
            "new_J",
        ]
    ]

    def field_lookup(index, position):
        # Unindented C++ lines. Every line receives the same function-body
        # indent afterwards, so continuation lines must not carry their own
        # leading spaces (previously B2/B3 ended up double-indented).
        return "\n".join(
            [
                f"const auto B{index}res = getB({position});",
                f"if (!B{index}res.ok()) {{",
                f"  return Acts::Result<bool>::failure(B{index}res.error());",
                "}",
                f"const auto B{index} = *B{index}res;",
            ]
        )

    def indented(text):
        # Indent every line of text by one level of the function body.
        return [f"  {line}" for line in text.split("\n")]

    head = "template <typename T, typename GetB> Acts::Result<bool> rk4_vacuum(const T* p, const T* d, const T t, const T h, const T l, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_t, T* new_d, T* path_derivatives, T* J) {"
    lines = [head]
    lines.extend(indented(field_lookup(1, "p")))

    # Local buffers for outputs that are not function parameters.
    pre_snippets = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "new_J": "T new_J[64];",
    }
    # Code emitted right after an output has been computed.
    post_snippets = {
        "p2": field_lookup(2, "p2"),
        "p3": field_lookup(3, "p3"),
        "err": "if (*err > errTol) {\n  return Acts::Result<bool>::success(false);\n}",
        "new_d": "if (J == nullptr) {\n  return Acts::Result<bool>::success(true);\n}",
    }

    def pre_expr_hook(var):
        return pre_snippets.get(str(var))

    def post_expr_hook(var):
        if str(var) == "new_J":
            # Copy the updated Jacobian back into the caller's J.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return post_snippets.get(str(var))

    code = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )
    lines.extend(indented(code))

    lines.append("  return Acts::Result<bool>::success(true);")
    lines.append("}")

    return "\n".join(lines)
0538 
0539 
def print_rk4_dense(name_exprs, run_cse=True):
    """Render the dense-material RK4 step as the C++ function template ``rk4_dense``.

    The emitted function:
    - queries the field at ``p``, ``p2`` and ``p3`` via ``getB``, returning
      the lookup's failure when a query fails;
    - queries the energy loss via ``getG`` at ``(p, l)``, ``(p2, l2)``,
      ``(p2, l3)`` and ``(p3, l4)``;
    - returns ``success(false)`` once ``*err`` exceeds ``errTol``;
    - returns ``success(true)`` before the Jacobian update when ``J`` is null.

    Args:
        name_exprs: named expressions, e.g. from ``rk4_dense_tunedexpr``.
        run_cse: forwarded to ``my_expression_print``.

    Returns:
        The C++ source of the function as one string.
    """
    printer = cxx_printer
    outputs = [
        find_by_name(name_exprs, name)[0]
        for name in [
            "p2",
            "l2",
            "l3",
            "p3",
            "l4",
            "err",
            "new_p",
            "new_t",
            "new_d",
            "new_l",
            "path_derivatives",
            "new_J",
        ]
    ]

    def field_lookup(index, position):
        # Unindented C++ lines. Every line receives the same function-body
        # indent afterwards, so continuation lines must not carry their own
        # leading spaces (previously B2/B3 ended up double-indented).
        return "\n".join(
            [
                f"const auto B{index}res = getB({position});",
                f"if (!B{index}res.ok()) {{",
                f"  return Acts::Result<bool>::failure(B{index}res.error());",
                "}",
                f"const auto B{index} = *B{index}res;",
            ]
        )

    def indented(text):
        # Indent every line of text by one level of the function body.
        return [f"  {line}" for line in text.split("\n")]

    head = "template <typename T, typename GetB, typename GetG> Acts::Result<bool> rk4_dense(const T* p, const T* d, const T t, const T h, const T l, const T m, const T q, const T p_abs, GetB getB, GetG getG, T* err, const T errTol, T* new_p, T* new_t, T* new_d, T* new_l, T* path_derivatives, T* J) {"
    lines = [head]
    lines.extend(indented(field_lookup(1, "p")))
    lines.append("  const auto g1 = getG(p, l);")

    # Local buffers for outputs that are not function parameters.
    pre_snippets = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "l2": "T l2[1];",
        "l3": "T l3[1];",
        "l4": "T l4[1];",
        "new_J": "T new_J[64];",
    }
    # Code emitted right after an output has been computed.
    post_snippets = {
        "p2": field_lookup(2, "p2"),
        "p3": field_lookup(3, "p3"),
        "l2": "const auto g2 = getG(p2, *l2);",
        # Stage 3 is evaluated at the same midpoint position as stage 2.
        "l3": "const auto g3 = getG(p2, *l3);",
        "l4": "const auto g4 = getG(p3, *l4);",
        "err": "if (*err > errTol) {\n  return Acts::Result<bool>::success(false);\n}",
        "new_d": "if (J == nullptr) {\n  return Acts::Result<bool>::success(true);\n}",
    }

    def pre_expr_hook(var):
        return pre_snippets.get(str(var))

    def post_expr_hook(var):
        if str(var) == "new_J":
            # Copy the updated Jacobian back into the caller's J.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return post_snippets.get(str(var))

    code = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )
    lines.extend(indented(code))

    lines.append("  return Acts::Result<bool>::success(true);")
    lines.append("}")

    return "\n".join(lines)
0623 
0624 
# Generate both stepper functions. The try/finally guarantees the output
# file is closed even if expression generation or printing raises, instead
# of relying on interpreter shutdown. stdout is never closed.
try:
    output.write(
        """
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_stepper.py
//       Do not modify it manually.

#pragma once

#include <Acts/Utilities/Result.hpp>

#include <cmath>
""".strip()
    )

    output.write("\n\n")

    all_name_exprs = rk4_vacuum_tunedexpr()
    # all_name_exprs = rk4_vacuum_fullexpr()
    # all_name_exprs = rk4_vacuum_fullexpr2()

    # attempted manual CSE which turns out a bit slower
    #
    # sub_name_exprs = [
    #     name_expr("hlB1", h * l * B1),
    #     name_expr("hlB2", h * l * B2),
    #     name_expr("hlB3", h * l * B3),
    #     name_expr("lB1", l * B1),
    #     name_expr("lB2", l * B2),
    #     name_expr("lB3", l * B3),
    #     name_expr("h2_2", h**2 / 2),
    #     name_expr("h_8", h / 8),
    #     name_expr("h_6", h / 6),
    #     name_expr("h_2", h / 2),
    # ]
    # all_name_exprs = [
    #     NamedExpr(name, my_subs(expr, sub_name_exprs)) for name, expr in all_name_exprs
    # ]
    # all_name_exprs.extend(sub_name_exprs)

    code = print_rk4_vacuum(
        all_name_exprs,
        run_cse=True,
    )
    output.write(code + "\n")

    output.write("\n")

    all_name_exprs = rk4_dense_tunedexpr()

    code = print_rk4_dense(
        all_name_exprs,
        run_cse=True,
    )
    output.write(code + "\n")
finally:
    if output is not sys.stdout:
        output.close()