Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:10:44

0001 import numpy as np
0002 
0003 import sympy as sym
0004 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0005 from sympy.codegen.ast import Assignment
0006 
0007 from sympy_common import (
0008     NamedExpr,
0009     name_expr,
0010     find_by_name,
0011     my_subs,
0012     cxx_printer,
0013     my_expression_print,
0014 )
0015 
0016 
# Scalar symbols of the stepper state and step:
#   l -- lambda, i.e. q/p
#   h -- step length
#   t -- time
#   m -- particle mass
l, h, t, m = (Symbol(name, real=True) for name in ("lambda", "h", "t", "m"))

# Absolute momentum; declared positive in addition to real.
p_abs = Symbol("p_abs", real=True, positive=True)

# Position and direction as symbolic 3x1 column vectors.
p, d = (MatrixSymbol(name, 3, 1) for name in ("p", "d"))

# Magnetic field values: B1 is looked up at p, B2 at p2 and B3 at p3
# (see the getB calls emitted by my_step_function_print).
B1, B2, B3 = (MatrixSymbol(f"B{i}", 3, 1) for i in (1, 2, 3))
0042 
0043 
def rk4_full_math():
    """Build the RK4 step with every intermediate fully substituted.

    Later expressions embed earlier ones through ``.expr``; no intermediate
    is referenced by name. The result is therefore written purely in terms
    of the input symbols (p, d, t, h, lambda, m, p_abs, B1..B3). The position
    and direction rows of the transport matrix come from one direct Jacobian
    with respect to (p, t, d, lambda).

    Returns:
        list of NamedExpr in the order p2, p3, err, new_p, new_d, new_time,
        path_derivatives, new_J.
    """
    # lambda * B at the three field samples; B2 serves both k2 and k3.
    lB1 = l * B1
    lB2 = l * B2
    lB3 = l * B3

    # Stage slopes of the direction and the intermediate positions p2, p3.
    k1 = name_expr("k1", d.as_explicit().cross(lB1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.expr)
    k2 = name_expr("k2", (d + h / 2 * k1.expr).as_explicit().cross(lB2))
    k3 = name_expr("k3", (d + h / 2 * k2.expr).as_explicit().cross(lB2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.expr)
    k4 = name_expr("k4", (d + h * k3.expr).as_explicit().cross(lB3))
    s1, s2, s3, s4 = k1.expr, k2.expr, k3.expr, k4.expr

    # Error estimate and the updated state.
    err = name_expr("err", h**2 * (s1 - s2 - s3 + s4).norm(1))
    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (s1 + s2 + s3))
    new_d_tmp = name_expr("new_d_tmp", d + h / 6 * (s1 + 2 * (s2 + s3) + s4))
    unnormalized = new_d_tmp.expr
    new_d = name_expr("new_d", unnormalized / unnormalized.as_explicit().norm())

    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.expr)

    # Path derivatives at the end point: rows 0-2 new direction, row 3 dt/ds,
    # rows 4-6 the last slope k4; row 7 is left at zero.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    derivs = path_derivatives.expr
    derivs[0:3, 0] = new_d.expr.as_explicit()
    derivs[3, 0] = dtds.expr
    derivs[4:7, 0] = s4.as_explicit()

    # One-step transport matrix; its columns follow the Jacobian variables
    # (p, t, d, lambda). Rows not overwritten keep the identity.
    state = [p, t, d, l]
    transport = sym.eye(8)
    transport[0:3, :] = new_p.expr.as_explicit().jacobian(state)
    transport[4:7, :] = unnormalized.as_explicit().jacobian(state)
    transport[3, 7] = h * m**2 * l / dtds.expr

    # Symbolic incoming Jacobian, with every entry where the transport matrix
    # is a literal 0 or 1 replaced by that constant.
    # NOTE(review): presumably J is assumed to share this fixed structure.
    prev_J = MatrixSymbol("J", 8, 8).as_explicit().as_mutable()
    for idx in np.ndindex(prev_J.shape):
        fixed = transport[idx]
        if fixed in [0, 1]:
            prev_J[idx] = fixed
    prev_J = ImmutableMatrix(prev_J)

    new_J = name_expr("new_J", prev_J * transport)

    return [p2, p3, err, new_p, new_d, new_time, path_derivatives, new_J]
0084 
0085 
def rk4_short_math():
    """Build the RK4 step as a chain of named, mutually referencing expressions.

    This is the same Runge-Kutta scheme as ``rk4_full_math``. The difference
    is that later expressions refer to earlier intermediates through their
    symbols (``.name``) instead of substituting their full expressions
    (``.expr``). The referenced intermediates are therefore part of the
    returned list as well.

    The transport matrix ``D`` is not taken from a single Jacobian. Instead:
      * the derivatives of the stage slopes with respect to (d, lambda) are
        accumulated stage by stage with the chain rule (``dk2dTL``..``dk4dTL``);
      * these are combined into the position block (``dFdTL``) and the
        direction block (``dGdTL``).

    Returns:
        list of NamedExpr, in this order: k1, p2, k2, k3, p3, k4, err, new_p,
        new_d_tmp, new_d, dtds, new_time, path_derivatives, dk2dTL, dk3dTL,
        dk4dTL, dFdTL, dGdTL, new_J.
        NOTE(review): the order appears to list dependencies before their
        users; confirm whether my_expression_print relies on that.
    """
    # Stage slopes k_i = (direction estimate) x (lambda * B_i), plus the
    # intermediate positions p2 and p3 at which B2 and B3 are sampled.
    k1 = name_expr("k1", d.as_explicit().cross(l * B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.name)

    # k2 and k3 both use B2, the field looked up at p2.
    k2 = name_expr("k2", (d + h / 2 * k1.name).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.name).as_explicit().cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.name)

    k4 = name_expr("k4", (d + h * k3.name).as_explicit().cross(l * B3))

    # Error estimate: h^2 times the 1-norm of (k1 - k2 - k3 + k4). The
    # generated code compares it against errTol.
    err = name_expr(
        "err", h**2 * (k1.name - k2.name - k3.name + k4.name).as_explicit().norm(1)
    )

    # Updated position and direction. new_p uses k1..k3 only. The direction
    # is first formed unnormalized (new_d_tmp) and then scaled to unit norm.
    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.name + k2.name + k3.name))
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.name + 2 * (k2.name + k3.name) + k4.name)
    )
    new_d = name_expr("new_d", new_d_tmp.name / new_d_tmp.name.as_explicit().norm())

    # Time update with dt/ds = sqrt(1 + m^2 / p_abs^2).
    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.name)

    # Path derivatives at the end point: rows 0-2 the new direction, row 3
    # dt/ds, rows 4-6 the last slope k4; row 7 stays zero.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.name.as_explicit()
    path_derivatives.expr[3, 0] = dtds.name
    path_derivatives.expr[4:7, 0] = k4.name.as_explicit()

    # Derivatives of the stage slopes with respect to (d, lambda), i.e. the
    # columns of jacobian([d, l]).
    # - Each k_i depends on (d, lambda) directly and through k_{i-1}, hence
    #   the extra chain-rule term.
    # - dk1dTL is always inlined via .expr, so it is not in the returned list.
    # - B1..B3 are independent symbols here, so any dependence of the field
    #   on position does not enter these derivatives.
    dk1dTL = name_expr("dk1dTL", k1.expr.jacobian([d, l]))
    dk2dTL = name_expr(
        "dk2dTL", k2.expr.jacobian([d, l]) + k2.expr.jacobian(k1.name) * dk1dTL.expr
    )
    dk3dTL = name_expr(
        "dk3dTL",
        k3.expr.jacobian([d, l])
        + k3.expr.jacobian(k2.name) * dk2dTL.name.as_explicit(),
    )
    dk4dTL = name_expr(
        "dk4dTL",
        k4.expr.jacobian([d, l])
        + k4.expr.jacobian(k3.name) * dk3dTL.name.as_explicit(),
    )

    # Total derivative of new_p with respect to (d, lambda); new_p does not
    # depend on k4.
    dFdTL = name_expr(
        "dFdTL",
        new_p.expr.as_explicit().jacobian([d, l])
        + new_p.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_p.expr.as_explicit().jacobian(k2.name) * dk2dTL.name.as_explicit()
        + new_p.expr.as_explicit().jacobian(k3.name) * dk3dTL.name.as_explicit(),
    )
    # Total derivative of the unnormalized direction with respect to
    # (d, lambda), including the k4 term.
    dGdTL = name_expr(
        "dGdTL",
        new_d_tmp.expr.as_explicit().jacobian([d, l])
        + new_d_tmp.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_d_tmp.expr.as_explicit().jacobian(k2.name) * dk2dTL.name.as_explicit()
        + new_d_tmp.expr.as_explicit().jacobian(k3.name) * dk3dTL.name.as_explicit()
        + new_d_tmp.expr.as_explicit().jacobian(k4.name) * dk4dTL.name.as_explicit(),
    )

    # One-step transport matrix, starting from the identity. Index layout, as
    # in rk4_full_math's jacobian([p, t, d, l]): 0-2 position, 3 time,
    # 4-6 direction, 7 lambda. Only the (d, lambda) columns of the position
    # and direction rows, plus entry (3, 7), differ from the identity.
    D = sym.eye(8)
    D[0:3, 4:8] = dFdTL.name.as_explicit()
    D[4:7, 4:8] = dGdTL.name.as_explicit()
    # NOTE(review): this matches d(new_time)/d(lambda) if p_abs = 1/|lambda|
    # (unit charge); confirm the intended convention.
    D[3, 7] = h * m**2 * l / dtds.name

    # Symbolic incoming Jacobian. Wherever D holds a literal 0 or 1, the
    # corresponding J entry is replaced by that constant.
    # NOTE(review): presumably this relies on J sharing D's fixed zero/one
    # structure so those terms drop out of J * D; confirm.
    J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)
    new_J = name_expr("new_J", J * D)

    return [
        k1,
        p2,
        k2,
        k3,
        p3,
        k4,
        err,
        new_p,
        new_d_tmp,
        new_d,
        dtds,
        new_time,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0178 
0179 
def my_step_function_print(name_exprs, run_cse=True):
    """Render the RK4 step expressions as the C++ ``rk4`` function template.

    The generated function behaves as follows:
      * it looks up the field through ``getB`` at ``p``, ``p2`` and ``p3``,
        returning ``Acts::Result<bool>::failure`` if a lookup fails;
      * it returns ``success(false)`` when ``*err`` exceeds ``errTol``;
      * it returns ``success(true)`` early when ``J == nullptr``;
      * otherwise it also writes the updated 8x8 Jacobian into ``J``.

    Args:
        name_exprs: NamedExpr list (e.g. from ``rk4_short_math``) expected to
            contain entries named p2, p3, err, new_p, new_d, new_time,
            path_derivatives and new_J.
        run_cse: forwarded unchanged to ``my_expression_print`` (presumably
            toggles its common-subexpression elimination).

    Returns:
        The complete C++ function definition as one newline-joined string.
    """
    printer = cxx_printer
    output_names = [
        "p2",
        "p3",
        "err",
        "new_p",
        "new_d",
        "new_time",
        "path_derivatives",
        "new_J",
    ]
    outputs = [find_by_name(name_exprs, name)[0] for name in output_names]

    def field_lookup(b_name, pos_name):
        """Return C++ lines that fetch ``getB(pos_name)`` into ``b_name``.

        A failed lookup is propagated as the function's failure result. The
        lines carry no base indentation. Nested lines are indented two spaces
        relative to their enclosing statement, like the other hook snippets,
        so every lookup is formatted the same way once the caller indents it.
        """
        res = f"{b_name}res"
        return [
            f"const auto {res} = getB({pos_name});",
            f"if (!{res}.ok()) {{",
            f"  return Acts::Result<bool>::failure({res}.error());",
            "}",
            f"const auto {b_name} = *{res};",
        ]

    head = "template <typename T, typename GetB> Acts::Result<bool> rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {"
    lines = [head]
    # B1 is needed before any generated expression, so it is emitted directly.
    lines.extend(f"  {line}" for line in field_lookup("B1", "p"))

    # Local arrays declared via pre_expr_hook for these named results.
    declarations = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "new_J": "T new_J[64];",
    }

    def pre_expr_hook(var):
        return declarations.get(str(var))

    def post_expr_hook(var):
        name = str(var)
        # The field at p2 (for k2, k3) and at p3 (for k4) is fetched as soon
        # as the corresponding position is available.
        if name == "p2":
            return "\n".join(field_lookup("B2", "p2"))
        if name == "p3":
            return "\n".join(field_lookup("B3", "p3"))
        if name == "err":
            return "if (*err > errTol) {\n  return Acts::Result<bool>::success(false);\n}"
        if name == "new_time":
            return "if (J == nullptr) {\n  return Acts::Result<bool>::success(true);\n}"
        if name == "new_J":
            # Copy the freshly computed Jacobian back into the output J.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return None

    code = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )
    # Indent the generated body one level. The loop variable is `line`, not
    # `l`, to avoid shadowing the module-level lambda symbol.
    lines.extend(f"  {line}" for line in code.split("\n"))

    lines.append("  return Acts::Result<bool>::success(true);")
    lines.append("}")

    return "\n".join(lines)
0246 
0247 
# Generate from the short (named-intermediate) formulation; rk4_full_math
# is not used here.
all_name_exprs = rk4_short_math()

# A manual common-subexpression substitution was also tried. The generated
# code turned out slightly slower, so it stays disabled:
#
# sub_name_exprs = [
#     name_expr("hlB1", h * l * B1),
#     name_expr("hlB2", h * l * B2),
#     name_expr("hlB3", h * l * B3),
#     name_expr("lB1", l * B1),
#     name_expr("lB2", l * B2),
#     name_expr("lB3", l * B3),
#     name_expr("h2_2", h**2 / 2),
#     name_expr("h_8", h / 8),
#     name_expr("h_6", h / 6),
#     name_expr("h_2", h / 2),
# ]
# all_name_exprs = [
#     NamedExpr(name, my_subs(expr, sub_name_exprs)) for name, expr in all_name_exprs
# ]
# all_name_exprs.extend(sub_name_exprs)

code = my_step_function_print(all_name_exprs, run_cse=True)

# Preamble of the generated C++ header (license banner, pragma, includes);
# the rk4 function template is printed right after it.
_HEADER_PREAMBLE = """// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_stepper.py
//       Do not modify it manually.

#pragma once

#include <Acts/Utilities/Result.hpp>

#include <cmath>
"""

print(_HEADER_PREAMBLE)
print(code)