Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:10:44

0001 import numpy as np
0002 
0003 import sympy as sym
0004 from sympy import MatrixSymbol
0005 
0006 from sympy_common import name_expr, find_by_name, cxx_printer, my_expression_print
0007 
0008 
# Symbolic 6x6 covariance made symmetric: every lower-triangle entry (row > col)
# is overwritten with the symbol of its upper-triangle mirror C[col, row], so
# only the 21 upper-triangle symbols (diagonal included) remain in use.
C = MatrixSymbol("C", 6, 6).as_explicit().as_mutable()
for row, col in np.ndindex(C.shape):
    if row > col:
        C[row, col] = C[col, row]

# Symbolic 6x6 Jacobian embedded in an identity matrix: rows 0-3 and row 5
# take their symbolic entries in columns 0-4, while row 4 and column 5 keep
# the identity values from sym.eye(6).
J_full = MatrixSymbol("J_full", 6, 6).as_explicit().as_mutable()
tmp = sym.eye(6)
for row in (0, 1, 2, 3, 5):
    tmp[row:row + 1, 0:5] = J_full[row:row + 1, 0:5]
J_full = tmp
0018 
0019 
def covariance_transport_generic():
    """Return the symbolic transported covariance as one named expression.

    The expression is ``J_full * C * J_full^T``, built from the module-level
    symbolic matrices ``J_full`` and ``C``. It is wrapped with ``name_expr``
    under the name ``"new_C"``, which is the name looked up when the C++
    function is printed.

    Returns:
        A one-element list holding the named ``new_C`` expression.
    """
    transported = J_full * C * J_full.T
    named = name_expr("new_C", transported)
    return [named]
0024 
0025 
def my_covariance_transport_generic_function_print(name_exprs, run_cse=True):
    """Render the ``transportCovarianceToBoundImpl`` C++ function template.

    The entry named ``"new_C"`` is picked from ``name_exprs`` as the single
    output. The function body is produced by ``my_expression_print`` using
    ``cxx_printer``, and each body line is indented by two spaces between the
    signature line and the closing brace.

    Args:
        name_exprs: named expressions, e.g. as returned by
            ``covariance_transport_generic()``; expected to contain ``"new_C"``.
        run_cse: forwarded unchanged to ``my_expression_print`` (presumably
            toggles common-subexpression elimination -- confirm in sympy_common).

    Returns:
        The complete C++ function text as one newline-joined string.
    """
    selected = [find_by_name(name_exprs, "new_C")[0]]

    body = my_expression_print(cxx_printer, name_exprs, selected, run_cse=run_cse)

    signature = "template <typename T> void transportCovarianceToBoundImpl(const T* C, const T* J_full, T* new_C) {"
    return "\n".join([signature, *("  " + line for line in body.split("\n")), "}"])
0046 
0047 
# Banner written at the top of the generated header. It is printed before any
# expressions are built, followed by the generated function template.
_PREAMBLE = """// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_cov.py
//       Do not modify it manually.

#pragma once

#include <cmath>
"""
print(_PREAMBLE)

all_name_exprs = covariance_transport_generic()
code = my_covariance_transport_generic_function_print(all_name_exprs, run_cse=True)
print(code)