File indexing completed on 2026-05-27 07:23:54
0001 import sys
0002 import sympy
0003
0004 from detray_sympy.common import name_expr, cxx_printer_wo_known
0005 from detray_sympy.output import write_out_file
0006 from detray_sympy.codegen import gen_cxx_code
0007 import detray_sympy.matrices
0008
0009
def gen_code(gradient=True):
    """Generate the C++ transport-Jacobian update function with SymPy.

    Declares symbolic placeholders for the derivative blocks
    (``dFdr``/``dGdr``/``dFdt``/``dGdt``: 3x3, ``dFdqop``/``dGdqop``: 3x1,
    ``dqopqop``: scalar). From them it builds ``D`` via
    ``detray_sympy.matrices.get_matrix_D``. It then left-multiplies ``D`` onto an
    8x8 symbolic ``J_transport`` whose structure is filled in by
    ``add_transport_jacobian_substructure``. The product is emitted under the
    output name ``new_J`` through ``gen_cxx_code`` (with ``run_cse=True`` and
    the ``cxx_printer_wo_known`` printer).

    Args:
        gradient: Forwarded to both ``detray_sympy.matrices`` helpers. When
            true, ``dFdr`` and ``dGdr`` are also declared as inputs of the
            generated function, and the function is named
            ``update_transport_jacobian_with_gradient_impl``. Otherwise it is
            named ``update_transport_jacobian_without_gradient_impl``.

    Returns:
        The result of ``gen_cxx_code``. The ``__main__`` block concatenates
        two of these with ``+``, so presumably a string — verify against
        ``detray_sympy.codegen``.
    """

    def explicit_matrix(label, rows, cols):
        # A mutable, element-addressable matrix of named symbols.
        return sympy.MatrixSymbol(label, rows, cols).as_explicit().as_mutable()

    dFdr = explicit_matrix("dFdr", 3, 3)
    dGdr = explicit_matrix("dGdr", 3, 3)

    dFdt = explicit_matrix("dFdt", 3, 3)
    dGdt = explicit_matrix("dGdt", 3, 3)

    dFdqop = explicit_matrix("dFdqop", 3, 1)
    dGdqop = explicit_matrix("dGdqop", 3, 1)

    dqopqop = sympy.Symbol("dqopqop")

    D = detray_sympy.matrices.get_matrix_D(
        dFdr, dGdr, dFdqop, dGdqop, dFdt, dGdt, dqopqop, gradient=gradient
    )

    # The helper's return value (not the raw symbol matrix) is what enters
    # the product and the input list below.
    J_transport = detray_sympy.matrices.add_transport_jacobian_substructure(
        explicit_matrix("J_transport", 8, 8), gradient=gradient
    )

    new_transport_jacobian = D * J_transport

    # Order of this list fixes the order of the generated function's inputs.
    inputs = [
        name_expr("J_transport", J_transport),
        name_expr("dFdt", dFdt),
        name_expr("dGdt", dGdt),
    ]
    if gradient:
        # Position-derivative blocks are only inputs of the gradient variant.
        inputs += [name_expr("dFdr", dFdr), name_expr("dGdr", dGdr)]
    inputs += [
        name_expr("dFdqop", dFdqop),
        name_expr("dGdqop", dGdqop),
        name_expr("dqopqop", dqopqop),
    ]

    outputs = [name_expr("new_J", new_transport_jacobian)]

    variant = "with_gradient" if gradient else "without_gradient"
    return gen_cxx_code(
        f"update_transport_jacobian_{variant}_impl",
        inputs,
        outputs,
        run_cse=True,
        printer=cxx_printer_wo_known,
    )
0054
0055
if __name__ == "__main__":
    # Optional first CLI argument is forwarded to write_out_file as the
    # destination; None otherwise (its meaning for None is defined in
    # detray_sympy.output — presumably stdout, TODO confirm).
    output = sys.argv[1] if len(sys.argv) > 1 else None

    # Generate the no-gradient variant first, then the gradient variant, and
    # write both into a single output.
    without_gradient = gen_code(False)
    with_gradient = gen_code(True)

    write_out_file(without_gradient + with_gradient, output)
0065 write_out_file(c1 + c2, output)