Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-04 07:58:04

0001 import sys
0002 
0003 import numpy as np
0004 
0005 import sympy as sym
0006 from sympy import Symbol, Matrix, ImmutableMatrix, MatrixSymbol
0007 from sympy.codegen.ast import Assignment
0008 
0009 from codegen.sympy_common import (
0010     NamedExpr,
0011     name_expr,
0012     find_by_name,
0013     my_subs,
0014     cxx_printer,
0015     my_expression_print,
0016 )
0017 
0018 
# Destination of the generated C++ header: the path given as the first
# command-line argument, or stdout when no argument is given. A real file is
# closed explicitly at the very end of the script.
# NOTE(review): mode "w" truncates the file right here, before the symbolic
# generation further down runs, so an exception during generation leaves an
# empty file behind; consider opening it only once the code string is ready.
output = open(sys.argv[1], "w") if len(sys.argv) > 1 else sys.stdout
0022 
0023 
# Symbolic inputs of one stepper update. Their names match the identifiers the
# emitted C++ relies on (the `rk4` parameters p, d, t, h, lambda, m, p_abs and
# the B1/B2/B3 locals it declares), presumably because the printer emits the
# symbol names verbatim -- keep the two in sync.

# Scalars.
l = Symbol("lambda", real=True)  # lambda = q/p
h = Symbol("h", real=True)  # step length
t = Symbol("t", real=True)  # time
m = Symbol("m", real=True)  # mass
p_abs = Symbol("p_abs", real=True, positive=True)  # absolute momentum

# Position and direction 3-vectors.
p, d = (MatrixSymbol(vec_name, 3, 1) for vec_name in ("p", "d"))

# Magnetic field values used by the RK4 stages. In the generated code these
# are obtained from getB at p, p2 and p3 respectively.
B1, B2, B3 = (MatrixSymbol(field_name, 3, 1) for field_name in ("B1", "B2", "B3"))
0049 
0050 
def rk4_full_math():
    """Build one RK4 step as fully expanded symbolic expressions.

    Every intermediate (k1..k4, new_d_tmp, dtds, ...) is substituted by its
    full expression via ``.expr``, so each returned result is self-contained
    and no intermediate is emitted on its own. The driver at the bottom of this
    script calls ``rk4_short_math`` instead; this variant is not used there.

    Returns:
        list: Named expressions in the order p2, p3, err, new_p, new_d,
        new_time, path_derivatives, new_J -- the same names that
        ``my_step_function_print`` looks up.
    """
    # Direction derivative at each stage: k_i = d_i x (lambda * B_i).
    # B1 is the field at p, B2 the field at p2 (shared by k2 and k3) and B3
    # the field at p3; p2 and p3 are the intermediate points passed to getB.
    k1 = name_expr("k1", d.as_explicit().cross(l * B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.expr)

    k2 = name_expr("k2", (d + h / 2 * k1.expr).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.expr).as_explicit().cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.expr)

    k4 = name_expr("k4", (d + h * k3.expr).as_explicit().cross(l * B3))

    # Error estimate of the step: h^2 * ||k1 - k2 - k3 + k4||_1.
    err = name_expr("err", h**2 * (k1.expr - k2.expr - k3.expr + k4.expr).norm(1))

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.expr + k2.expr + k3.expr))
    # Updated direction before normalization; new_d rescales it to unit length.
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.expr + 2 * (k2.expr + k3.expr) + k4.expr)
    )
    new_d = name_expr("new_d", new_d_tmp.expr / new_d_tmp.expr.as_explicit().norm())

    # dt/ds = sqrt(1 + m^2 / p_abs^2); the time advances by h * dt/ds.
    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.expr)

    # Path derivatives of the 8 parameters, ordered [p, t, d, lambda] like the
    # Jacobian columns below: new direction, dt/ds, k4; entry 7 stays 0.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.expr.as_explicit()
    path_derivatives.expr[3, 0] = dtds.expr
    path_derivatives.expr[4:7, 0] = k4.expr.as_explicit()

    # Transport matrix of this step over [p, t, d, lambda]. Rows 0-2 are the
    # Jacobian of new_p and rows 4-6 that of the *unnormalized* new_d_tmp;
    # rows 3 and 7 stay identity rows apart from D[3, 7].
    D = sym.eye(8)
    D[0:3, :] = new_p.expr.as_explicit().jacobian([p, t, d, l])
    D[4:7, :] = new_d_tmp.expr.as_explicit().jacobian([p, t, d, l])
    # d(new_time)/d(lambda), written by hand; it equals d(h * dtds)/d(lambda)
    # if p_abs = 1/|lambda| (unit charge). NOTE(review): confirm.
    D[3, 7] = h * m**2 * l / dtds.expr

    # Incoming Jacobian J is fully symbolic, except that each entry where D
    # holds an exact 0 or 1 is pinned to that constant, so J * D simplifies.
    # This assumes J shares D's structural 0/1 pattern.
    # NOTE(review): confirm this holds for every J passed to the stepper.
    J = MatrixSymbol("J", 8, 8).as_explicit().as_mutable()
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)

    # Jacobian accumulated through this step.
    new_J = name_expr("new_J", J * D)

    return [p2, p3, err, new_p, new_d, new_time, path_derivatives, new_J]
0091 
0092 
def rk4_short_math():
    """Build one RK4 step whose expressions reference intermediates by name.

    Unlike ``rk4_full_math``, later expressions refer to earlier results
    through their symbol (``.name``) rather than substituting the full
    expression. The intermediates are returned too, so each one can be
    emitted as its own assignment. This is the variant used by the driver at
    the bottom of the script.

    Returns:
        list: Named expressions in dependency order: the stages k1..k4 and the
        intermediate points p2/p3, err, new_p, new_d_tmp, new_d, dtds,
        new_time, path_derivatives, the chained Jacobian blocks
        dk2dTL..dk4dTL, dFdTL, dGdTL, and finally new_J.
    """
    # Direction derivative at each stage: k_i = d_i x (lambda * B_i), with
    # B1 at p, B2 at p2 (shared by k2 and k3) and B3 at p3.
    k1 = name_expr("k1", d.as_explicit().cross(l * B1))
    p2 = name_expr("p2", p + h / 2 * d + h**2 / 8 * k1.name)

    k2 = name_expr("k2", (d + h / 2 * k1.name).as_explicit().cross(l * B2))
    k3 = name_expr("k3", (d + h / 2 * k2.name).as_explicit().cross(l * B2))
    p3 = name_expr("p3", p + h * d + h**2 / 2 * k3.name)

    k4 = name_expr("k4", (d + h * k3.name).as_explicit().cross(l * B3))

    # Error estimate: h^2 * ||k1 - k2 - k3 + k4||_1. The names are matrix
    # symbols, hence the .as_explicit() before taking the norm.
    err = name_expr(
        "err", h**2 * (k1.name - k2.name - k3.name + k4.name).as_explicit().norm(1)
    )

    new_p = name_expr("new_p", p + h * d + h**2 / 6 * (k1.name + k2.name + k3.name))
    # Updated direction before normalization; new_d rescales it to unit length.
    new_d_tmp = name_expr(
        "new_d_tmp", d + h / 6 * (k1.name + 2 * (k2.name + k3.name) + k4.name)
    )
    new_d = name_expr("new_d", new_d_tmp.name / new_d_tmp.name.as_explicit().norm())

    # dt/ds = sqrt(1 + m^2 / p_abs^2); the time advances by h * dt/ds.
    dtds = name_expr("dtds", sym.sqrt(1 + m**2 / p_abs**2))
    new_time = name_expr("new_time", t + h * dtds.name)

    # Path derivatives of the 8 parameters [p, t, d, lambda]: new direction,
    # dt/ds, k4; entry 7 stays 0.
    path_derivatives = name_expr("path_derivatives", sym.zeros(8, 1))
    path_derivatives.expr[0:3, 0] = new_d.name.as_explicit()
    path_derivatives.expr[3, 0] = dtds.name
    path_derivatives.expr[4:7, 0] = k4.name.as_explicit()

    # Total derivatives of each stage with respect to (d, lambda) -- "TL"
    # presumably stands for (T, lambda) -- chained through the previous stage:
    # dk_i = dk_i/d(d, lambda) + dk_i/dk_{i-1} * dk_{i-1}.
    # dk1dTL is not returned, so its expression is inlined (.expr) where used,
    # whereas the later blocks are referenced by name.
    dk1dTL = name_expr("dk1dTL", k1.expr.jacobian([d, l]))
    dk2dTL = name_expr(
        "dk2dTL", k2.expr.jacobian([d, l]) + k2.expr.jacobian(k1.name) * dk1dTL.expr
    )
    dk3dTL = name_expr(
        "dk3dTL",
        k3.expr.jacobian([d, l])
        + k3.expr.jacobian(k2.name) * dk2dTL.name.as_explicit(),
    )
    dk4dTL = name_expr(
        "dk4dTL",
        k4.expr.jacobian([d, l])
        + k4.expr.jacobian(k3.name) * dk3dTL.name.as_explicit(),
    )

    # Derivatives of new_p (F) and of the unnormalized direction new_d_tmp (G)
    # with respect to (d, lambda), via the chain rule through the stages.
    # new_p does not depend on k4, so dFdTL has no k4 term.
    dFdTL = name_expr(
        "dFdTL",
        new_p.expr.as_explicit().jacobian([d, l])
        + new_p.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_p.expr.as_explicit().jacobian(k2.name) * dk2dTL.name.as_explicit()
        + new_p.expr.as_explicit().jacobian(k3.name) * dk3dTL.name.as_explicit(),
    )
    dGdTL = name_expr(
        "dGdTL",
        new_d_tmp.expr.as_explicit().jacobian([d, l])
        + new_d_tmp.expr.as_explicit().jacobian(k1.name) * dk1dTL.expr
        + new_d_tmp.expr.as_explicit().jacobian(k2.name) * dk2dTL.name.as_explicit()
        + new_d_tmp.expr.as_explicit().jacobian(k3.name) * dk3dTL.name.as_explicit()
        + new_d_tmp.expr.as_explicit().jacobian(k4.name) * dk4dTL.name.as_explicit(),
    )

    # Transport matrix over [p, t, d, lambda]. The (p, t) columns stay those
    # of the identity: B1..B3 are free symbols, so no dependence of the field
    # on position is modelled.
    D = sym.eye(8)
    D[0:3, 4:8] = dFdTL.name.as_explicit()
    D[4:7, 4:8] = dGdTL.name.as_explicit()
    # d(new_time)/d(lambda), written by hand; it equals d(h * dtds)/d(lambda)
    # if p_abs = 1/|lambda| (unit charge). NOTE(review): confirm.
    D[3, 7] = h * m**2 * l / dtds.name

    # Incoming Jacobian J is fully symbolic, except that each entry where D
    # holds an exact 0 or 1 is pinned to that constant, so J * D simplifies.
    # This assumes J shares D's structural 0/1 pattern.
    # NOTE(review): confirm this holds for every J passed to the stepper.
    J = Matrix(MatrixSymbol("J", 8, 8).as_explicit())
    for indices in np.ndindex(J.shape):
        if D[indices] in [0, 1]:
            J[indices] = D[indices]
    J = ImmutableMatrix(J)
    new_J = name_expr("new_J", J * D)

    # Dependency order: each entry only references names listed before it
    # (presumably required by the printer, which emits them in sequence).
    return [
        k1,
        p2,
        k2,
        k3,
        p3,
        k4,
        err,
        new_p,
        new_d_tmp,
        new_d,
        dtds,
        new_time,
        path_derivatives,
        dk2dTL,
        dk3dTL,
        dk4dTL,
        dFdTL,
        dGdTL,
        new_J,
    ]
0185 
0186 
def my_step_function_print(name_exprs, run_cse=True):
    """Render the C++ ``rk4`` step function template as source text.

    The generated function:

    - looks up the field at ``p`` (and later at ``p2`` and ``p3``) via
      ``getB``, forwarding any lookup failure as ``Result::failure``;
    - returns ``success(false)`` as soon as ``*err`` exceeds ``errTol``;
    - returns ``success(true)`` right after ``new_time`` when ``J`` is
      ``nullptr``, so anything the printer emits after that point (presumably
      the path derivatives and the Jacobian) is skipped;
    - otherwise copies the updated Jacobian ``new_J`` back into ``J``.

    Args:
        name_exprs: Named expressions to print, e.g. from ``rk4_short_math``.
        run_cse: Forwarded to ``my_expression_print`` to enable common
            subexpression elimination.

    Returns:
        str: The complete function definition, lines joined with ``"\\n"``.
    """
    printer = cxx_printer

    # Results the generated function must produce, looked up in this order.
    output_names = (
        "p2",
        "p3",
        "err",
        "new_p",
        "new_d",
        "new_time",
        "path_derivatives",
        "new_J",
    )
    outputs = [find_by_name(name_exprs, out_name)[0] for out_name in output_names]

    # Local arrays that are not function parameters and so must be declared
    # before their first assignment.
    declarations = {
        "p2": "T p2[3];",
        "p3": "T p3[3];",
        "new_J": "T new_J[64];",
    }

    # Fixed C++ snippets spliced in right after a variable has been assigned:
    # field lookups at the intermediate points, the error-tolerance exit and
    # the early exit when no Jacobian is requested.
    follow_ups = {
        "p2": "const auto B2res = getB(p2);\n  if (!B2res.ok()) {\n    return Acts::Result<bool>::failure(B2res.error());\n  }\n  const auto B2 = *B2res;",
        "p3": "const auto B3res = getB(p3);\n  if (!B3res.ok()) {\n    return Acts::Result<bool>::failure(B3res.error());\n  }\n  const auto B3 = *B3res;",
        "err": "if (*err > errTol) {\n  return Acts::Result<bool>::success(false);\n}",
        "new_time": "if (J == nullptr) {\n  return Acts::Result<bool>::success(true);\n}",
    }

    def pre_expr_hook(var):
        return declarations.get(str(var))

    def post_expr_hook(var):
        var_name = str(var)
        if var_name == "new_J":
            # Copy the accumulated Jacobian into the caller's J buffer.
            return printer.doprint(Assignment(MatrixSymbol("J", 8, 8), var))
        return follow_ups.get(var_name)

    body = my_expression_print(
        printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
        pre_expr_hook=pre_expr_hook,
        post_expr_hook=post_expr_hook,
    )

    head = "template <typename T, typename GetB> Acts::Result<bool> rk4(const T* p, const T* d, const T t, const T h, const T lambda, const T m, const T p_abs, GetB getB, T* err, const T errTol, T* new_p, T* new_d, T* new_time, T* path_derivatives, T* J) {"
    lines = [
        head,
        # Field at the start point, needed by the first stage.
        "  const auto B1res = getB(p);",
        "  if (!B1res.ok()) {\n    return Acts::Result<bool>::failure(B1res.error());\n  }",
        "  const auto B1 = *B1res;",
    ]
    # Indent every printed statement by one level inside the function body.
    lines.extend("  " + body_line for body_line in body.split("\n"))
    lines.append("  return Acts::Result<bool>::success(true);")
    lines.append("}")

    return "\n".join(lines)
0253 
0254 
# Symbolic step in the "short" form: intermediates are referenced by name and
# emitted as separate assignments.
all_name_exprs = rk4_short_math()

# A manual common-subexpression pass was also tried: it pre-substituted the
# products below before printing. It turned out a bit slower, so it is kept
# here for reference only:
#
# sub_name_exprs = [
#     name_expr("hlB1", h * l * B1),
#     name_expr("hlB2", h * l * B2),
#     name_expr("hlB3", h * l * B3),
#     name_expr("lB1", l * B1),
#     name_expr("lB2", l * B2),
#     name_expr("lB3", l * B3),
#     name_expr("h2_2", h**2 / 2),
#     name_expr("h_8", h / 8),
#     name_expr("h_6", h / 6),
#     name_expr("h_2", h / 2),
# ]
# all_name_exprs = [
#     NamedExpr(name, my_subs(expr, sub_name_exprs)) for name, expr in all_name_exprs
# ]
# all_name_exprs.extend(sub_name_exprs)

code = my_step_function_print(all_name_exprs, run_cse=True)

# License banner and includes written ahead of the generated function.
_GENERATED_PREAMBLE = """// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_stepper.py
//       Do not modify it manually.

#pragma once

#include <Acts/Utilities/Result.hpp>

#include <cmath>
"""

output.writelines((_GENERATED_PREAMBLE, code, "\n"))

# Only close a file we opened ourselves; stdout stays open.
if output is not sys.stdout:
    output.close()