File indexing completed on 2026-04-12 07:34:39
0001 import sys
0002 import numpy as np
0003
0004 import sympy as sym
0005 from sympy import MatrixSymbol
0006
0007
0008 from codegen.sympy_common import (
0009 name_expr,
0010 find_by_name,
0011 cxx_printer,
0012 my_expression_print,
0013 )
0014
# Destination for the generated header: the path given as the first
# command-line argument (opened for writing, truncating it), or stdout when
# no argument is supplied. The handle is closed at the end of the script.
output = open(sys.argv[1], "w") if len(sys.argv) > 1 else sys.stdout
0018
0019
# Symmetric 6x6 symbolic covariance: entries (i, j) and (j, i) both refer to
# the single symbol C[min(i, j), max(i, j)], so only the upper triangle of
# the "C" symbol appears in the generated code.
C_sym = MatrixSymbol("C", 6, 6)
C = sym.Matrix(6, 6, [C_sym[tuple(sorted(ij))] for ij in np.ndindex(6, 6)])

# 6x6 symbolic Jacobian. Row 4 and column 5 are copied from the identity
# matrix; every other entry (i, j) is the symbol J_full[i, j].
# NOTE(review): presumably these fixed entries encode derivatives known to be
# trivial for this transport (e.g. an unchanged parameter 4) -- confirm
# against the parameter ordering used by the consuming C++ code.
J_sym = MatrixSymbol("J_full", 6, 6)
identity = sym.eye(6)
J_full = sym.Matrix(
    [
        [identity[r, c] if r == 4 or c == 5 else J_sym[r, c] for c in range(6)]
        for r in range(6)
    ]
)
0029
0030
def covariance_transport_generic():
    """Build the named symbolic expression for the transported covariance.

    Uses the module-level symbolic matrices ``J_full`` and ``C``.

    Returns:
        A one-element list containing ``name_expr("new_C", J_full * C * J_full.T)``.
    """
    transported = J_full * C * J_full.T
    return [name_expr("new_C", transported)]
0035
0036
def my_covariance_transport_generic_function_print(name_exprs, run_cse=True):
    """Render the C++ template function that evaluates ``new_C``.

    Args:
        name_exprs: Named expressions, searched with ``find_by_name`` for the
            entry called ``"new_C"``.
        run_cse: Forwarded to ``my_expression_print`` to toggle common
            subexpression elimination.

    Returns:
        The C++ source text: the function signature line, each line of the
        printed body prefixed with a space, and a closing brace, joined with
        newlines (no trailing newline).
    """
    wanted = ["new_C"]
    outputs = [find_by_name(name_exprs, name)[0] for name in wanted]

    body = my_expression_print(
        cxx_printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
    )
    indented_body = [f" {line}" for line in body.split("\n")]

    signature = "template <typename T> void transportCovarianceToBoundImpl(const T* C, const T* J_full, T* new_C) {"
    return "\n".join([signature, *indented_body, "}"])
0057
0058
# Fixed preamble emitted before the generated function.
header = """// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_cov.py
// Do not modify it manually.

#pragma once

#include <cmath>
"""
output.write(header)

# Build the symbolic expressions, print them as C++ (with CSE), and append
# the result followed by a newline.
all_name_exprs = covariance_transport_generic()
generated = my_covariance_transport_generic_function_print(
    all_name_exprs,
    run_cse=True,
)
output.write(f"{generated}\n")

# Only close a file we opened ourselves; leave stdout open.
if output is not sys.stdout:
    output.close()