Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:23:54

0001 from detray_sympy.common import (
0002     cxx_printer,
0003     my_expression_print,
0004 )
0005 
0006 
def render_dim_requirement(name, j):
    """Render the C++ ``requires`` constraint for one generated-function argument.

    The constraint restricts the template type ``<name>_t`` to a detray
    algebra concept whose dimensions match those of ``j``.

    :param name: Argument name; the constrained template type is ``<name>_t``.
    :param j: The symbolic value bound to the argument. Anything without a
        ``shape`` attribute is treated as a scalar; otherwise ``j.shape`` must
        have one or two dimensions.
    :returns: A parenthesised C++ boolean expression, suitable for joining
        with ``&&`` into a ``requires`` clause.
    :raises ValueError: If ``j.shape`` has neither one nor two dimensions
        (previously such input silently produced ``None``).
    """
    if not hasattr(j, "shape"):
        return f"(detray::concepts::scalar<{name}_t>)"

    shape = j.shape

    # 1-D values and single-column values are rendered as detray vectors.
    # HACK: the ``<= 3`` bound keeps taller single-column matrices (the 8x1
    # path-to-free matrix) out of this branch so they fall through to the
    # generic matrix constraint below.
    if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1 and shape[0] <= 3):
        rows = shape[0]
        if rows == 3:
            return f"(detray::concepts::vector3D<{name}_t>)"
        return (
            f"(detray::concepts::vector<{name}_t> && "
            f"detray::traits::rows<{name}_t> == {rows})"
        )

    if len(shape) == 2:
        rows, cols = shape
        if rows == cols:
            return (
                f"(detray::concepts::square_matrix<{name}_t> && "
                f"detray::traits::max_rank<{name}_t> == {rows})"
            )
        if rows == 1:
            return (
                f"(detray::concepts::row_matrix<{name}_t> && "
                f"detray::traits::columns<{name}_t> == {cols})"
            )
        return (
            f"(detray::concepts::matrix<{name}_t> && "
            f"detray::traits::rows<{name}_t> == {rows} && "
            f"detray::traits::columns<{name}_t> == {cols})"
        )

    # Fail loudly instead of returning None, which would otherwise surface
    # later as an unrelated TypeError when the requirements are joined.
    raise ValueError(
        f"cannot render a dimension requirement for '{name}' with shape {shape!r}"
    )
0037 
0038 
def _input_constant_asserts(name, expr):
    """Return C++ ``assert`` lines for entries of ``expr`` that equal 0 or 1.

    Scalars (no ``shape`` attribute) and values with more than two dimensions
    produce no assertions.
    """
    if not hasattr(expr, "shape"):
        return []

    shape = expr.shape
    # Same vector/matrix split as render_dim_requirement: 1-D values and
    # single-column values with at most three rows are indexed as vectors.
    if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1 and shape[0] <= 3):
        entries = [((k,), expr[k]) for k in range(shape[0])]
    elif len(shape) == 2:
        entries = [
            ((row, col), expr[row, col])
            for row in range(shape[0])
            for col in range(shape[1])
        ]
    else:
        return []

    asserts = []
    for index, value in entries:
        for constant, literal in ((0, "0.f"), (1, "1.f")):
            if value == constant:
                idx = ", ".join(str(i) for i in index)
                asserts.append(
                    f"assert((getter::element<{idx}>({name}) == {literal}));"
                )
                break
    return asserts


def gen_cxx_code(function_name, inputs, outputs, run_cse=True, printer=None):
    """Generate a constrained C++ function template from symbolic expressions.

    Each argument ``<name>`` gets its own template type ``<name>_t``.
    Inputs become ``const <name>_t &`` parameters and outputs become
    ``<name>_t &`` parameters. A ``requires`` clause (built with
    ``render_dim_requirement``) constrains every type to its argument's shape.

    :param function_name: Name of the generated C++ function.
    :param inputs: Sequence of ``(name, expr)`` pairs describing the inputs.
        Input entries whose expression is literally 0 or 1 are checked with
        runtime ``assert``s in the generated body. Presumably this guards the
        constant values the output expressions were derived with — confirm.
    :param outputs: Sequence of ``(name, expr)`` pairs. They are forwarded,
        with their names, to ``my_expression_print`` to emit the body.
    :param run_cse: Forwarded to ``my_expression_print`` (presumably toggles
        common-subexpression elimination — verify against that helper).
    :param printer: Code printer forwarded to ``my_expression_print``;
        defaults to ``cxx_printer``.
    :returns: The generated C++ source as a single newline-joined string.
    """
    if printer is None:
        printer = cxx_printer

    # Materialise as a list so tuple/list combinations of inputs and outputs work.
    arguments = [*inputs, *outputs]

    lines = []

    # An empty parameter list would render ``template <>``, which declares an
    # explicit specialisation; only emit the header when there are parameters.
    if arguments:
        template_params = ", ".join(f"typename {name}_t" for name, _ in arguments)
        lines.append(f"template <{template_params}>")
    lines.append(f"DETRAY_HOST_DEVICE void inline {function_name} (")

    input_params = ", ".join(f"const {name}_t & {name}" for name, _ in inputs)
    output_params = ", ".join(f"{name}_t & {name}" for name, _ in outputs)
    # Separate the two groups with a comma only when both are present, so an
    # empty group never leaves a dangling comma in the signature.
    if input_params and output_params:
        input_params += ","
    lines.extend(chunk for chunk in (input_params, output_params) if chunk)
    lines.append(")")

    if arguments:
        lines.append("requires(")
        lines.append(
            " && ".join(render_dim_requirement(name, expr) for name, expr in arguments)
        )
        lines.append(")")
    lines.append("{")

    for name, expr in inputs:
        lines.extend(_input_constant_asserts(name, expr))

    code = my_expression_print(
        printer,
        outputs,
        [output[0] for output in outputs],
        run_cse=run_cse,
    )
    lines.extend(f"  {code_line}" for code_line in code.split("\n"))

    lines.append("}")

    return "\n".join(lines)