File indexing completed on 2025-01-18 09:10:44
0001 import numpy as np
0002
0003 import sympy as sym
0004 from sympy import MatrixSymbol
0005
0006 from sympy_common import name_expr, find_by_name, cxx_printer, my_expression_print
0007
0008
# Symbolic 6x6 covariance matrix. Every entry starts out as its own symbol
# C[i, j]; the lower triangle (row > col) is then overwritten with the mirrored
# upper-triangle entry, so C is symmetric and only C[i, j] with i <= j appear.
C = MatrixSymbol("C", 6, 6).as_explicit().as_mutable()
for row, col in np.ndindex(C.shape):
    if row > col:
        C[row, col] = C[col, row]

# Symbolic 6x6 Jacobian with a fixed sparsity pattern: start from the identity
# and replace only rows 0-3 and 5, columns 0-4, with free symbols J_full[i, j].
# Row 4 and column 5 therefore keep their identity-matrix values.
# NOTE(review): presumably these parameters are unaffected by the transport —
# confirm against the consumer of the generated code.
J_symbols = MatrixSymbol("J_full", 6, 6).as_explicit()
J_full = sym.eye(6)
for row in (0, 1, 2, 3, 5):
    J_full[row, 0:5] = J_symbols[row, 0:5]
0018
0019
def covariance_transport_generic():
    """Build the symbolic transported covariance.

    Forms the matrix product ``J_full * C * J_full.T`` from the module-level
    symbolic ``J_full`` and ``C`` and wraps it with ``name_expr`` under the
    name ``"new_C"``.

    Returns:
        A one-element list containing the ``name_expr`` result for ``new_C``.
    """
    product = J_full * C * J_full.T
    return [name_expr("new_C", product)]
0024
0025
def my_covariance_transport_generic_function_print(name_exprs, run_cse=True):
    """Render the C++ function template for the covariance transport.

    Args:
        name_exprs: Named expressions to print; an entry named ``"new_C"``
            is looked up with ``find_by_name`` and used as the output.
        run_cse: Forwarded unchanged to ``my_expression_print``
            (presumably toggles common-subexpression elimination — confirm
            in ``sympy_common``).

    Returns:
        A newline-joined string: the ``transportCovarianceToBoundImpl``
        template signature line, then every line of the printed expression
        code prefixed with a single space, then a closing ``}``.
    """
    outputs = [find_by_name(name_exprs, "new_C")[0]]

    body = my_expression_print(
        cxx_printer,
        name_exprs,
        outputs,
        run_cse=run_cse,
    )

    signature = "template <typename T> void transportCovarianceToBoundImpl(const T* C, const T* J_full, T* new_C) {"
    indented = [f" {line}" for line in body.split("\n")]

    return "\n".join([signature, *indented, "}"])
0046
0047
# License banner and preamble written at the top of the generated C++ header.
header = """// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_cov.py
// Do not modify it manually.

#pragma once

#include <cmath>
"""
print(header)

# Build the symbolic expressions, then print them as the C++ implementation.
all_name_exprs = covariance_transport_generic()
code = my_covariance_transport_generic_function_print(all_name_exprs, run_cse=True)
print(code)
0071 print(code)