Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:23:54

0001 import sys
0002 import sympy
0003 
0004 from detray_sympy.common import name_expr, cxx_printer_wo_known
0005 from detray_sympy.output import write_out_file
0006 from detray_sympy.codegen import gen_cxx_code
0007 import detray_sympy.matrices
0008 
0009 
def gen_code(gradient=True):
    """Generate the C++ routine that updates the transport Jacobian.

    The generated routine computes ``new_J = D * J_transport``:

    * ``D`` comes from ``detray_sympy.matrices.get_matrix_D``. It is built from
      these symbolic blocks: 3x3 ``dFdr``, ``dGdr``, ``dFdt`` and ``dGdt``;
      3x1 ``dFdqop`` and ``dGdqop``; and the scalar ``dqopqop``.
    * ``J_transport`` is an 8x8 symbolic matrix. Its structure is filled in by
      ``detray_sympy.matrices.add_transport_jacobian_substructure``.

    Args:
        gradient: Forwarded to ``get_matrix_D`` and to
            ``add_transport_jacobian_substructure``. When true:

            * ``dFdr`` and ``dGdr`` are also declared as inputs of the
              generated routine.
            * The routine is named
              ``update_transport_jacobian_with_gradient_impl``.

            Otherwise the routine is named
            ``update_transport_jacobian_without_gradient_impl``.

    Returns:
        Whatever ``gen_cxx_code`` produces for this routine. Common
        subexpression elimination is enabled, and the output is printed with
        ``cxx_printer_wo_known``.
    """

    def explicit_matrix(label, rows, cols):
        # Expand a named MatrixSymbol into an explicit, mutable matrix whose
        # entries are the individual element symbols.
        return sympy.MatrixSymbol(label, rows, cols).as_explicit().as_mutable()

    dFdr = explicit_matrix("dFdr", 3, 3)
    dGdr = explicit_matrix("dGdr", 3, 3)

    dFdt = explicit_matrix("dFdt", 3, 3)
    dGdt = explicit_matrix("dGdt", 3, 3)

    dFdqop = explicit_matrix("dFdqop", 3, 1)
    dGdqop = explicit_matrix("dGdqop", 3, 1)

    dqopqop = sympy.Symbol("dqopqop")

    D = detray_sympy.matrices.get_matrix_D(
        dFdr, dGdr, dFdqop, dGdqop, dFdt, dGdt, dqopqop, gradient=gradient
    )

    # Start from a fully symbolic 8x8 matrix. The helper then returns the
    # version with its substructure applied; that returned matrix is both
    # multiplied and declared as the "J_transport" input below.
    J_transport = explicit_matrix("J_transport", 8, 8)
    J_transport = detray_sympy.matrices.add_transport_jacobian_substructure(
        J_transport, gradient=gradient
    )

    new_transport_jacobian = D * J_transport

    # Declared inputs of the generated routine, in signature order.
    # dFdr/dGdr are only declared when the gradient variant is requested.
    input_name_exprs = [
        name_expr("J_transport", J_transport),
        name_expr("dFdt", dFdt),
        name_expr("dGdt", dGdt),
    ]
    if gradient:
        input_name_exprs += [
            name_expr("dFdr", dFdr),
            name_expr("dGdr", dGdr),
        ]
    input_name_exprs += [
        name_expr("dFdqop", dFdqop),
        name_expr("dGdqop", dGdqop),
        name_expr("dqopqop", dqopqop),
    ]

    variant = "with_gradient" if gradient else "without_gradient"
    return gen_cxx_code(
        "update_transport_jacobian_" + variant + "_impl",
        input_name_exprs,
        [name_expr("new_J", new_transport_jacobian)],
        run_cse=True,
        printer=cxx_printer_wo_known,
    )
0054 
0055 
if __name__ == "__main__":
    # The optional first command-line argument is passed to write_out_file
    # as the output destination. None is passed when no argument is given;
    # presumably write_out_file then falls back to a default sink (confirm).
    output = sys.argv[1] if len(sys.argv) > 1 else None

    # Emit the variant without the gradient first, then the one with it.
    # Operands of "+" are evaluated left to right, so gen_code(False) runs first.
    write_out_file(gen_code(False) + gen_code(True), output)