File indexing completed on 2025-04-10 07:58:29
0001 import sys
0002 import numpy as np
0003
0004 import sympy as sym
0005 from sympy import MatrixSymbol
0006
0007
0008 from codegen.sympy_common import (
0009 name_expr,
0010 find_by_name,
0011 cxx_printer,
0012 my_expression_print,
0013 )
0014
0015
# Destination for the generated C++ header: the path given as the first
# command-line argument, or stdout when no argument is supplied.  The file is
# opened with an explicit UTF-8 encoding so the written bytes do not depend on
# the platform's locale default (PEP 597).
if len(sys.argv) > 1:
    output = open(sys.argv[1], "w", encoding="utf-8")
else:
    output = sys.stdout
0019
0020
# Symbolic 6x6 covariance made symmetric by construction: every entry (r, c)
# is replaced by its upper-triangle counterpart (min(r, c), max(r, c)), so
# C[r, c] and C[c, r] are the very same symbol.
C = MatrixSymbol("C", 6, 6).as_explicit().as_mutable()
for row, col in np.ndindex(C.shape):
    C[row, col] = C[min(row, col), max(row, col)]

# Jacobian with a fixed sparsity pattern: start from the 6x6 identity and copy
# the symbolic entries J_full[r, 0:5] only for rows r in {0, 1, 2, 3, 5}.
# Row 4 and column 5 are therefore left exactly as in the identity matrix.
J_full = MatrixSymbol("J_full", 6, 6).as_explicit()
tmp = sym.eye(6)
for row in (0, 1, 2, 3, 5):
    tmp[row, 0:5] = J_full[row, 0:5]
J_full = tmp
0030
0031
def covariance_transport_generic():
    """Build the symbolic transported covariance.

    Forms ``J_full * C * J_full.T`` from the module-level symbolic matrices
    ``J_full`` and ``C``. The product is wrapped with ``name_expr`` under the
    name ``"new_C"``.

    Returns:
        A one-element list holding the object produced by
        ``name_expr("new_C", ...)``.
    """
    transported = J_full * C * J_full.T
    return [name_expr("new_C", transported)]
0036
0037
def my_covariance_transport_generic_function_print(name_exprs, run_cse=True):
    """Render the covariance-transport expressions as a C++ function template.

    The expression named ``"new_C"`` is looked up with ``find_by_name``; the
    first element of its result is the single requested output. The
    statements are printed by ``my_expression_print`` using ``cxx_printer``.
    Each printed line is prefixed with one space and wrapped in the template
    ``transportCovarianceToBoundImpl(const T* C, const T* J_full, T* new_C)``.

    Args:
        name_exprs: The named expressions to print, as produced by
            ``covariance_transport_generic``.
        run_cse: Forwarded unchanged to ``my_expression_print`` (presumably
            toggles common-subexpression elimination -- confirm in
            ``codegen.sympy_common``).

    Returns:
        The complete C++ function definition as a newline-joined string,
        without a trailing newline.
    """
    requested = [find_by_name(name_exprs, "new_C")[0]]
    body = my_expression_print(
        cxx_printer,
        name_exprs,
        requested,
        run_cse=run_cse,
    )
    indented_body = "\n".join(f" {line}" for line in body.split("\n"))

    signature = "template <typename T> void transportCovarianceToBoundImpl(const T* C, const T* J_full, T* new_C) {"
    return "\n".join([signature, indented_body, "}"])
0058
0059
# Preamble emitted verbatim at the top of the generated C++ header.
_HEADER = """// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

// Note: This file is generated by generate_sympy_cov.py
// Do not modify it manually.

#pragma once

#include <cmath>
"""

try:
    # Run the symbolic generation (which can fail inside sympy or the printer)
    # before writing anything. A failure then cannot leave behind a header-only
    # file that compiles but lacks the transport function.
    all_name_exprs = covariance_transport_generic()
    code = my_covariance_transport_generic_function_print(
        all_name_exprs,
        run_cse=True,
    )
    output.write(_HEADER)
    output.write(code + "\n")
finally:
    # Close only a file this script opened; never close the interpreter's
    # stdout. The finally clause releases (and flushes) the handle even when
    # generation raises.
    if output is not sys.stdout:
        output.close()