File indexing completed on 2026-05-27 07:23:54
0001 from detray_sympy.common import (
0002 cxx_printer,
0003 my_expression_print,
0004 )
0005
0006
def render_dim_requirement(name, j):
    """Render the C++ ``requires`` constraint for one generated parameter.

    The constrained C++ type is the template parameter ``<name>_t``.

    Args:
        name: Parameter name; the constrained type is ``<name>_t``.
        j: Symbolic value bound to the parameter. Objects without a ``shape``
            attribute are treated as scalars. Otherwise ``j.shape`` selects the
            concept:

            * rank 1, or rank 2 with one column and at most 3 rows -> vector
              (``vector3D`` when there are exactly 3 rows);
            * square rank 2 -> ``square_matrix``;
            * rank 2 with a single row -> ``row_matrix``;
            * any other rank 2 -> general ``matrix`` with rows and columns.

    Returns:
        A parenthesised C++ constraint expression.

    Raises:
        ValueError: If ``j.shape`` has a rank other than 1 or 2 (previously
            such shapes silently yielded ``None``).
    """
    if not hasattr(j, "shape"):
        return "(detray::concepts::scalar<%s_t>)" % (name)

    shape = j.shape
    rank = len(shape)

    # Rank-1 values and short (<= 3 rows) column matrices map onto the
    # vector concepts rather than the matrix ones.
    if rank == 1 or (rank == 2 and shape[1] == 1 and shape[0] <= 3):
        rows = shape[0]
        if rows == 3:
            return "(detray::concepts::vector3D<%s_t>)" % (name)
        return (
            "(detray::concepts::vector<%s_t> && detray::traits::rows<%s_t> == %d)"
            % (name, name, rows)
        )

    if rank == 2:
        rows, cols = shape
        if rows == cols:
            return (
                "(detray::concepts::square_matrix<%s_t> && detray::traits::max_rank<%s_t> == %d)"
                % (name, name, rows)
            )
        if rows == 1:
            return (
                "(detray::concepts::row_matrix<%s_t> && detray::traits::columns<%s_t> == %d)"
                % (name, name, cols)
            )
        return (
            "(detray::concepts::matrix<%s_t> && detray::traits::rows<%s_t> == %d && detray::traits::columns<%s_t> == %d)"
            % (name, name, rows, name, cols)
        )

    # Fail loudly instead of returning None, which would otherwise surface
    # later as an obscure TypeError when the constraints are joined.
    raise ValueError(
        "unsupported shape %r for parameter %r: expected a scalar or a "
        "rank-1/rank-2 value" % (tuple(shape), name)
    )
0037
0038
def _gen_cxx_input_assertions(name, expr):
    """Return C++ ``assert`` lines for the entries of ``expr`` that are 0 or 1.

    Each entry that compares equal to 0 or 1 yields an assertion that the
    corresponding element of the C++ argument ``name`` holds that value.
    Presumably this guards against callers passing data whose structure
    differs from the symbolic input the code was generated for (TODO confirm).

    Objects without a ``shape`` attribute, and shapes whose rank is not 1 or 2,
    produce no assertions.
    """
    if not hasattr(expr, "shape"):
        return []

    shape = expr.shape
    if len(shape) == 1 or (len(shape) == 2 and shape[1] == 1 and shape[0] <= 3):
        # Vector-like values are accessed with a single element index.
        entries = [((row,), expr[row]) for row in range(shape[0])]
    elif len(shape) == 2:
        entries = [
            ((row, col), expr[row, col])
            for row in range(shape[0])
            for col in range(shape[1])
        ]
    else:
        return []

    assertions = []
    for index, value in entries:
        if value == 0:
            literal = "0.f"
        elif value == 1:
            literal = "1.f"
        else:
            continue
        assertions.append(
            "assert((getter::element<%s>(%s) == %s));"
            % (", ".join(str(i) for i in index), name, literal)
        )
    return assertions


def gen_cxx_code(function_name, inputs, outputs, run_cse=True, printer=None):
    """Generate a constrained C++ function template that evaluates ``outputs``.

    Every input and output becomes a template parameter ``<name>_t``. Inputs
    are taken by const reference and outputs by non-const reference. When
    there are parameters, a ``requires`` clause built from
    ``render_dim_requirement`` constrains them. Input entries that are
    literally 0 or 1 are turned into runtime ``assert`` statements, and the
    body is produced by ``my_expression_print``.

    Args:
        function_name: Name of the generated C++ function.
        inputs: Sequence of ``(name, expr)`` pairs for the read-only arguments.
        outputs: Sequence of ``(name, expr)`` pairs. The pairs and their names
            are handed to ``my_expression_print`` to generate the body.
        run_cse: Forwarded to ``my_expression_print`` (presumably toggles
            common-subexpression elimination).
        printer: Code printer forwarded to ``my_expression_print``. Defaults
            to ``cxx_printer``.

    Returns:
        The C++ source of the function as one newline-joined string.
    """
    if printer is None:
        printer = cxx_printer

    # Unpack into a list so inputs/outputs may be any sequence type (the plain
    # ``inputs + outputs`` required both to be lists or both tuples).
    parameters = [*inputs, *outputs]
    template_types = ["%s_t" % name for name, _ in parameters]

    lines = []

    # ``template <>`` declares an explicit specialisation, so only emit the
    # template header when there is at least one type to parameterise.
    if template_types:
        lines.append(
            "template <%s>" % (", ".join("typename %s" % t for t in template_types))
        )
    lines.append("DETRAY_HOST_DEVICE void inline %s (" % function_name)

    param_groups = [
        ", ".join("const %s_t & %s" % (name, name) for name, _ in inputs),
        ", ".join("%s_t & %s" % (name, name) for name, _ in outputs),
    ]
    param_groups = [group for group in param_groups if group]
    # A separating comma is only valid between two non-empty groups. A
    # dangling or leading comma would make the signature invalid C++.
    for position, group in enumerate(param_groups, start=1):
        lines.append(group + ("," if position < len(param_groups) else ""))
    lines.append(")")

    if parameters:
        lines.append("requires(")
        lines.append(
            " && ".join(render_dim_requirement(name, expr) for name, expr in parameters)
        )
        lines.append(")")
    lines.append("{")

    for name, expr in inputs:
        lines.extend(_gen_cxx_input_assertions(name, expr))

    code = my_expression_print(
        printer,
        outputs,
        [name for name, _ in outputs],
        run_cse=run_cse,
    )
    lines.extend(f" {line}" for line in code.split("\n"))

    lines.append("}")

    return "\n".join(lines)
0113 return "\n".join(lines)