File indexing completed on 2026-05-27 07:24:10
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 from .type_helpers import link
0011 from .definitions import (
0012 Type,
0013 Algebra,
0014 Shape,
0015 Material,
0016 Accelerator,
0017 GridBin,
0018 GridSerializer,
0019 )
0020
0021
0022 from collections import Counter
0023 from datetime import datetime
0024 import itertools
0025 import logging
0026 import numbers
0027 import os
0028 from typing import Optional
0029 import subprocess
0030
0031 """ Class that represents the c++ metadata struct with all of its types """
0032
0033
class metadata:
    """Accumulates the type information of a detector's C++ metadata struct.

    The class only records data: the algebra plugin, the link types, the
    surface shapes, the material types and the acceleration structures.
    Shapes and materials are stored in dictionaries that map an integer type
    ID to the list of types registered under that ID.
    """

    def __init__(
        self,
        detector_name,
    ):
        """Set up an empty type collection for the detector `detector_name`."""
        self.logger = logging.getLogger(__name__)
        self.det_name = detector_name

        # Algebra plugin and scalar precision (precision only set for a
        # concrete algebra plugin, see set_algebra_plugin)
        self.algebra = Algebra.ANY
        self.precision = None

        # Base type of the type-ID enums
        self.id_base = Type.UINT_LEAST_8

        # Link types: navigation link, mask link and material link
        self.nav_link = link(link_type="single", data_type=Type.UINT_LEAST_16)
        self.mask_link = link(link_type="range", data_type=Type.UINT_32)
        self.material_link = link(link_type="single", data_type=Type.UINT_32)

        # Surface categories that were registered ("portal", "sensitive", ...)
        self.surface_types = []

        # type ID -> list of shapes / materials
        self.shapes = {}
        self.materials = {}

        # Default acceleration structures for surfaces and volumes
        self.default_surface_accel = Accelerator.BRUTE_FORCE
        self.default_volume_accel = Accelerator.BRUTE_FORCE

        # geometry object type -> list of (accelerator, type ID, value type)
        self.acceleration_structs = {}

        self.logger.info(f'Detector: "{self.det_name}"')

    def __resolve_id_priority(self, type_dict, type_id: int = -1):
        """Find the ID for a new entry of `type_dict` and make room for it.

        A negative `type_id` requests the next automatic ID. If the requested
        ID is already taken, the earlier entry keeps it: the new entry gets
        the following ID and all existing entries at or above that ID are
        shifted up by one, e.g. requesting 0 in {0: a, 1: b} yields
        (1, {0: a, 2: b}).

        Returns the tuple (ID for the new entry, possibly re-keyed dictionary).
        """
        if type_id >= 0:
            priority = type_id
        elif not type_dict:
            priority = 0
        else:
            priority = max(len(type_dict), max(type_dict.keys()) + 1)

        if priority in type_dict:
            # Entries that were added earlier take precedence
            priority = priority + 1
            type_dict = {
                (key + 1 if key >= priority else key): entry
                for key, entry in sorted(type_dict.items())
            }

        return priority, type_dict

    def __accel_type_name(self, accel):
        """Return the accelerator specifier for log messages.

        Grid accelerators are shown together with their shape specifier.
        """
        if "grid" not in accel.specifier:
            return f"{accel.specifier}"
        return f"{accel.specifier}<{accel.param['shape'].specifier}>"

    def set_algebra_plugin(self, plugin: Algebra, precision: Optional[Type] = None):
        """Select the algebra plugin and, for concrete plugins, the precision.

        For a concrete plugin the precision has to be Type.SINGLE or
        Type.DOUBLE; anything else (including None) triggers a warning and
        falls back to Type.SINGLE. For Algebra.ANY a given precision is
        ignored with a warning.
        """
        self.algebra = plugin

        if self.algebra is not Algebra.ANY:
            # A precision is only invalid if it matches neither allowed type
            if precision is not Type.SINGLE and precision is not Type.DOUBLE:
                self.logger.warning(
                    f'Incorrect precision "{precision}" for algebra type. Using "float" instead'
                )
                self.precision = Type.SINGLE
            else:
                self.precision = precision
            self.logger.info(f"Algebra plugin: {self.algebra}<{self.precision}>")
        elif precision is not None:
            self.logger.warning(
                f'Precision "{precision}" will be ignored for generic algebra type'
            )

    def add_portal(self, shape: Shape, type_id: int = -1):
        """Register `shape` and mark "portal" as a used surface type."""
        self.add_shape(shape, type_id)
        if "portal" not in self.surface_types:
            self.surface_types.append("portal")

    def add_sensitive(self, shape: Shape, type_id: int = -1):
        """Register `shape` and mark "sensitive" as a used surface type."""
        self.add_shape(shape, type_id)
        if "sensitive" not in self.surface_types:
            self.surface_types.append("sensitive")

    def add_passive(self, shape: Shape, type_id: int = -1):
        """Register `shape` and mark "passive" as a used surface type."""
        self.add_shape(shape, type_id)
        if "passive" not in self.surface_types:
            self.surface_types.append("passive")

    def add_shape(self, shape: Shape, type_id: int):
        """Register `shape` under `type_id` unless it is already known."""
        if shape in itertools.chain(*self.shapes.values()):
            return

        self.logger.debug(f'--> surface shape "{shape.specifier}"')

        # Resolve ID clashes, shifting existing shapes if necessary
        i, self.shapes = self.__resolve_id_priority(self.shapes, type_id)
        self.shapes.setdefault(i, []).append(shape)

    def add_material(self, mat: Material, type_id: int = -1):
        """Register the material `mat` under `type_id` unless already known.

        Non-homogeneous materials (everything except SLAB, ROD and RAW) are
        added to the lowest existing type ID whose first material has the
        same "frame" specifier, if there is one.
        """
        if mat in itertools.chain(*self.materials.values()):
            return

        is_hom_mat = mat is Material.SLAB or mat is Material.ROD or mat is Material.RAW

        if is_hom_mat:
            self.logger.debug(f'--> material type "{mat.specifier}"')
        else:
            shape_specifier = mat.param["shape"].specifier
            self.logger.debug(
                f'--> material type "{mat.specifier}<{shape_specifier}>"'
            )

        i, self.materials = self.__resolve_id_priority(self.materials, type_id)

        if not is_hom_mat:
            frame = mat.param["frame"].specifier
            # Guard against materials that carry no parameters at all
            compatible_types = sorted(
                k
                for k, v in self.materials.items()
                if v[0].param
                and "frame" in v[0].param
                and v[0].param["frame"].specifier == frame
            )
            if compatible_types:
                i = compatible_types[0]

        self.materials.setdefault(i, []).append(mat)

    def add_accel_struct(
        self,
        accel: Accelerator,
        obj_type: str = "sensitive",
        type_id: int = -1,
        value_type: str = "surface",
        is_default: bool = False,
        grid_bin: GridBin = GridBin.DYNAMIC,
        grid_serializer: GridSerializer = GridSerializer.SIMPLE,
    ):
        """Register the acceleration structure `accel` for `obj_type`.

        For grid accelerators the bin and serializer types are written into
        `accel.param` (this modifies the passed accelerator object). The value
        type is forced to "index" when `obj_type` contains "volume". With
        `is_default` the accelerator is also set as default for `obj_type`.
        """
        if "grid" in accel.specifier:
            accel.param["bin"] = grid_bin
            # Key must match the one read back when the grid type is declared
            accel.param["serializer"] = grid_serializer

        value_type = "index" if "volume" in obj_type else value_type

        entry = (accel, type_id, value_type)
        registered = self.acceleration_structs.setdefault(obj_type, [])
        if entry not in registered:
            registered.append(entry)
            self.logger.debug(
                f"--> accel. struct ({obj_type}): {self.__accel_type_name(accel)}"
            )

        if is_default:
            # Pass the type ID explicitly so that it does not receive the
            # value type by position
            self.set_default_accel_struct(accel, obj_type, type_id, value_type)

    def set_default_accel_struct(
        self,
        accel: Accelerator,
        obj_type: str,
        type_id: int = -1,
        value_type: str = "surface",
    ):
        """Set `accel` as default acceleration structure for `obj_type`.

        Only "portal" (default surface accelerator) and "volume" (default
        volume accelerator) are accepted, other object types produce a
        warning. If the accelerator has not been registered for `obj_type`
        yet, it is added using `type_id`.
        """
        value_type = "index" if "volume" in obj_type else value_type
        accel_type = self.__accel_type_name(accel)

        if obj_type == "portal":
            self.logger.debug(
                f"--> setting default surface accel. struct: {accel_type}"
            )
            self.default_surface_accel = accel
        elif obj_type == "volume":
            self.logger.debug(f"--> setting default volume accel. struct: {accel_type}")
            self.default_volume_accel = accel
        else:
            self.logger.warning(
                f'Cannot set default acceleration structure for geometry object type "{obj_type}"'
            )

        # Make sure the default accelerator is part of the metadata
        if obj_type not in self.acceleration_structs.keys() or (
            accel,
            value_type,
        ) not in [(a, v) for a, i, v in self.acceleration_structs[obj_type]]:
            self.logger.warning(
                f"Requested default acceleration structure ({obj_type}, {accel_type}) not defined in metadata: Adding it now..."
            )
            self.add_accel_struct(accel, obj_type, type_id, value_type)
0257
0258
0259 """ Class that accumulates detector type data and writes a metadata header """
0260
0261
class metadata_generator:
    """Writes the C++ metadata header for the types collected in a `metadata`.

    The header is generated during construction: the file
    "<output>/<detector name>_metadata.hpp" is opened, filled section by
    section and, if requested, formatted with clang-format.
    """

    def __init__(
        self,
        md: metadata,
        output="../Detray/detectors/include/detray/detectors/",
        format_header=True,
    ):
        """Generate the metadata header for `md` in the directory `output`.

        `format_header` controls whether clang-format is run on the result.
        """
        self.out_dir = output
        self.file = None
        self.logger = logging.getLogger(__name__)
        self.format_header = format_header
        # Current indentation level (number of tabs) of the generated code
        self.indent = 0

        # Type names that were declared, keyed by their (contiguous) type ID
        self.shape_specifiers = {}
        self.material_specifiers = {}
        self.accel_specifiers = {}
        # Names of the grid alias templates that were already declared
        self.grid_specifiers = []

        # Includes that every metadata header needs
        root_dir = '#include "detray'
        self.__common_includes = (
            "#pragma once\n\n// Project include(s)\n"
            f'{root_dir}/core/detail/multi_store.hpp"\n'
            f'{root_dir}/core/detail/single_store.hpp"\n'
            f'{root_dir}/definitions/algebra.hpp"\n'
            f'{root_dir}/definitions/containers.hpp"\n'
            f'{root_dir}/definitions/indexing.hpp"\n'
            f'{root_dir}/geometry/mask.hpp"\n'
            f'{root_dir}/geometry/surface_descriptor.hpp"\n'
        )

        self.__generate(md, self.out_dir)

    def __generate(
        self, md: metadata, src_dir="../Detray/detectors/include/detray/detectors/"
    ):
        """Open the header file in `src_dir` and write all of its sections."""
        det_name = md.det_name.replace("_detector", "")
        filename = f"{os.path.abspath(src_dir)}/{det_name}_metadata.hpp"
        self.logger.debug(f'Open "{filename}"')
        self.file = open(filename, "w+")

        try:
            self.logger.info("Generating metadata...")

            self.__preamble(md)
            self.__local_typedefs(md)

            self.logger.info(" -> Transforms")
            self.__declare_transform_store()

            self.logger.info(" -> Masks")
            self.__declare_mask_store(md)

            self.logger.info(" -> Material")
            self.__declare_material_store(md)

            self.__lines(2)
            self.__declare_surface_descriptor(md)

            self.logger.info(" -> Acceleration Structures")
            self.__declare_accel_store(md)

            self.__declare_geometry_objects(md)

            # Closes the struct, the namespace and the file
            self.__finish()
        except BaseException:
            # Do not leak the file handle if a section fails to generate
            self.file.close()
            raise

        self.logger.info(f'Finished metadata for "{md.det_name} detector"')

    def __put(self, string):
        """Write `string` to the file, prefixed by the current indentation."""
        self.file.write(f"{self.__tabs()}{string}")

    def __lines(self, n):
        """Write `n` newline characters to the file."""
        self.file.write("\n" * n)

    def __tabs(self):
        """Return the tab characters for the current indentation level."""
        return "\t" * self.indent

    def __typedef(self, name, type):
        """Write the alias declaration "using name = type;"."""
        self.__put(f"using {name} = {type};\n")

    def __template_list(self, params):
        """Return "template <p1,p2,...>" or "" if there are no parameters.

        Empty strings in `params` are skipped; `params` is not modified.
        """
        entries = [p for p in (params or []) if p != ""]
        return f"template <{','.join(entries)}>" if entries else ""

    def __name_from_specifier(self, specifier):
        """Strip template arguments and namespaces from a type specifier."""
        type_name = specifier.split("<")[0]
        return type_name.split(":")[-1]

    def __preamble(self, md: metadata):
        """Write the licence header, the includes and open the struct."""
        # One comment per line: explicit newlines are required, otherwise the
        # whole licence text ends up on a single line of the header
        copy_right = "\n".join(
            [
                "// This file is part of the ACTS project.",
                "//",
                "// Copyright (C) 2016 CERN for the benefit of the ACTS project",
                "//",
                "// This Source Code Form is subject to the terms of the Mozilla Public",
                "// License, v. 2.0. If a copy of the MPL was not distributed with this",
                "// file, You can obtain one at https://mozilla.org/MPL/2.0/.",
            ]
        )
        self.__put(copy_right)
        self.__lines(2)
        self.__add_header_includes(md.shapes, md.materials, md.acceleration_structs)
        self.__lines(2)
        self.__put("namespace detray {")
        self.__lines(2)

        # The struct is only a template for the generic algebra type
        template_params = (
            f"{self.__template_list([str(md.algebra)])}\n"
            if md.algebra == Algebra.ANY
            else ""
        )
        struct_def = f"{template_params}struct {md.det_name}_metadata {{"
        self.__put(struct_def)
        self.indent = self.indent + 1
        self.__lines(2)

    def __add_header_includes(self, shapes, materials, accel_structs):
        """Write the common includes plus one include per used type.

        Shape, material and accelerator includes are each written in sorted
        order so that the generated header is reproducible.
        """
        self.__put(self.__common_includes)

        shape_names = {
            self.__name_from_specifier(s.specifier)
            for s in itertools.chain(*shapes.values())
        }

        # Both line shapes are replaced by the single "line" include
        line_shapes = shape_names & {"line_circular", "line_square"}
        if line_shapes:
            shape_names -= line_shapes
            shape_names.add("line")

        mat_names = {
            self.__name_from_specifier(m.specifier)
            for m in itertools.chain(*materials.values())
        }

        # Material maps need the include of the shape they are defined on
        if "material_map" in mat_names:
            shape_names.update(
                self.__name_from_specifier(m.param["shape"].specifier)
                for m in itertools.chain(*materials.values())
                if m.param
            )

        accel_names = {
            self.__name_from_specifier(a.specifier)
            for accels in accel_structs.values()
            for a, i, v in accels
        }

        # The same holds for the shapes of spatial grids
        if "spatial_grid" in accel_names:
            shape_names.update(
                self.__name_from_specifier(a.param["shape"].specifier)
                for accels in accel_structs.values()
                for a, i, v in accels
                if a.param
            )

        root_dir = '#include "detray'
        for n in sorted(shape_names):
            self.__put(f'{root_dir}/geometry/shapes/{n}.hpp"\n')

        for n in sorted(mat_names):
            self.__put(f'{root_dir}/material/{n}.hpp"\n')

        for n in sorted(accel_names):
            self.__put(f'{root_dir}/navigation/accelerators/{n}.hpp"\n')

    def __local_typedefs(self, md: metadata):
        """Write the algebra, scalar and navigation link aliases."""
        self.__typedef(
            "algebra_type",
            (
                "algebra_t"
                if md.algebra == Algebra.ANY
                else f"{md.algebra}<{md.precision}>"
            ),
        )
        self.__typedef("scalar_t", "dscalar<algebra_type>")
        self.__lines(1)
        self.__typedef("nav_link", md.nav_link.data_type)

    def __declare_transform_store(self):
        """Write the alias template of the transform store."""
        self.__lines(1)
        tabs = self.__tabs()
        self.__put(
            "template <template <typename...> class vector_t = dvector>\n"
            f"{tabs}using transform_store =\n"
            f"{tabs} single_store<dtransform3D<algebra_type>, vector_t, geometry_context>;"
        )

    def __declare_type_enum(self, specifier, items, base_type, extra_items=None):
        """Write "enum class specifier : base_type" with one entry per type.

        `items` and `extra_items` map an enum value to a list of type names.
        A list entry may also be a list [name, value] that carries its own
        value. A trailing "_t" of a name is dropped and "e_" is prepended.
        Numeric values get the suffix "u", other values are written as is.
        """
        extra_items = {} if extra_items is None else extra_items

        self.__put(f"enum class {specifier} : {base_type} {{\n")

        self.indent = self.indent + 1
        for i, values in itertools.chain(items.items(), extra_items.items()):
            for v in values:
                if isinstance(v, list):
                    item, raw_value = v[0], v[1]
                else:
                    item, raw_value = v, i

                # Use the value that belongs to the entry (not always the key)
                value = (
                    f"{raw_value}u"
                    if isinstance(raw_value, numbers.Number)
                    else raw_value
                )

                sub_specifier = f"e_{item[:-2]}" if item.endswith("_t") else f"e_{item}"
                self.__put(f"{sub_specifier} = {value},\n")

        self.indent = self.indent - 1

        self.__put("};")

    def __define_enum_stream_op(self, specifier, items, extra_items=None):
        """Write the stream operator of the enum `specifier`.

        One case is written per key of `items`/`extra_items` (enum entries
        that share a value cannot get separate cases). The printed name is
        the "/"-joined list of all entries of that key.
        """
        extra_items = {} if extra_items is None else extra_items

        self.__put(
            f"DETRAY_HOST inline friend std::ostream& operator<<(std::ostream& os, {specifier} id) {{\n"
        )

        self.indent = self.indent + 1

        self.__put("switch (id) {\n")

        self.indent = self.indent + 1
        for i, values in itertools.chain(items.items(), extra_items.items()):
            # Nothing to print for an ID without entries
            if not values:
                continue

            value = ""
            item = ""
            sub_specifier = ""

            for v in values:
                item = v[0] if isinstance(v, list) else v
                value = v[1] if isinstance(v, list) else f"{i}u"

                item = f"e_{item[:-2]}" if item.endswith("_t") else f"e_{item}"
                sub_specifier = (
                    item if sub_specifier == "" else f"{sub_specifier}/{item}"
                )

            # The last entry of the group labels the case
            self.__put(f"case {specifier}::{item}:\n")
            self.indent = self.indent + 1
            self.__put(f'os << "{value if isinstance(v, list) else sub_specifier}";\n')
            self.__put("break;\n")
            self.indent = self.indent - 1

        self.__put("default:\n")
        self.indent = self.indent + 1
        self.__put('os << "invalid";\n')
        self.indent = self.indent - 1

        self.indent = self.indent - 1

        self.__put("}\n")
        self.__put("return os;\n")

        self.indent = self.indent - 1

        self.__put("};")

    def __declare_multi_store(
        self, specifier, id_name, types, context="empty_context", is_regular=True
    ):
        """Write the alias template of a (regular) multi store.

        `types` maps a type ID to a list of type names. A regular store holds
        all types in `vector_t`. Otherwise only the first type of every ID is
        used: brute force types become a `brute_force_collection`, map and
        grid types a `grid_collection` and all other types a vector.
        """
        template_params = (
            "template<typename...> class vector_t = dvector"
            if is_regular
            else "typename container_t = host_container_types"
        )
        self.__put(f"{self.__template_list([template_params])}\n")
        self.__put(f"using {specifier} =\n")

        if is_regular:
            self.__put(
                f"\tregular_multi_store<{id_name}, {context}, dtuple, vector_t, {','.join(itertools.chain(*types.values()))}>;"
            )
            return

        first_types = [v[0] for v in types.values()]
        type_list = []

        # The surface brute force collection goes first ...
        if "surface_brute_force_t" in first_types:
            type_list.append("brute_force_collection<surface_type, container_t>")

        for t in first_types:
            if t in ("surface_brute_force_t", "volume_brute_force_t"):
                continue
            if t.endswith(("map_t", "grid_t")):
                type_list.append(f"grid_collection<{t}<container_t>>")
            else:
                type_list.append(f"typename container_t::template vector_type<{t}>")

        # ... and the volume brute force collection goes last
        if "volume_brute_force_t" in first_types:
            type_list.append("brute_force_collection<dindex, container_t>")

        self.__put(
            f"\tmulti_store<{id_name}, {context}, dtuple, {','.join(type_list)}>;"
        )

    def __declare_surface_descriptor(self, md: metadata):
        """Write the link aliases and the surface descriptor type."""
        self.__put("using transform_link = typename transform_store<>::single_link;\n")

        link_type = (
            "single_link" if md.mask_link.link_type == "single" else "range_link"
        )
        self.__put(f"using mask_link = typename mask_store<>::{link_type};\n")

        link_type = (
            "single_link" if md.material_link.link_type == "single" else "range_link"
        )
        self.__put(f"using material_link = typename material_store<>::{link_type};\n")

        self.__put(
            "using surface_type = surface_descriptor<mask_link, material_link, transform_link, nav_link>;"
        )

    def __declare_mask(self, shape, type_id, algebra, link, shape_params=None):
        """Write the mask alias for `shape` and record it under `type_id`.

        The values of `shape_params` (if any) are appended in lower case as
        template arguments of the shape.
        """
        type_specifier = f"{self.__name_from_specifier(shape.specifier)}_t"

        # The line shapes have dedicated mask names
        if type_specifier == "line_square_t":
            type_specifier = "drift_cell_t"
        elif type_specifier == "line_circular_t":
            type_specifier = "straw_tube_t"

        self.shape_specifiers.setdefault(type_id, []).append(type_specifier)

        template_params = ""
        if shape_params:
            params = [str(v).lower() for v in shape_params.values()]
            template_params = f"<{','.join(params)}>"

        self.__put(
            f"using {type_specifier} = mask<{shape.specifier}{template_params}, {algebra}, {link}>;\n"
        )

    def __declare_mask_store(self, md: metadata):
        """Write the mask aliases, the mask ID enum and the mask store."""
        assert md.shapes, "Define at least one geometric shape"

        self.__lines(2)

        # Enumerate anew, so that the generated type IDs are contiguous
        for type_id, (_, shapes) in enumerate(sorted(md.shapes.items())):
            for shape in shapes:
                self.__declare_mask(
                    shape, type_id, "algebra_type", "nav_link", shape.param
                )

        self.__lines(1)
        self.__declare_type_enum("mask_id", self.shape_specifiers, md.id_base)
        self.__lines(2)
        self.__define_enum_stream_op("mask_id", self.shape_specifiers)
        self.__lines(2)

        self.__declare_multi_store("mask_store", "mask_id", self.shape_specifiers)

    def __declare_material(self, mat, type_id, algebra):
        """Write the alias for the material `mat` and record it under `type_id`.

        For material maps the alias template is only written for the first
        map of a type ID; later maps are only recorded.
        """
        type_specifier = f"{self.__name_from_specifier(mat.specifier)}_t"

        if mat is Material.SLAB:
            self.__put(f"using {type_specifier} = material_slab<scalar_t>;\n")
        elif mat is Material.ROD:
            self.__put(f"using {type_specifier} = material_rod<scalar_t>;\n")
        elif mat is Material.RAW:
            type_specifier = "raw_material_t"
            self.__put(f"using {type_specifier} = material<scalar_t>;\n")
        else:
            shape_specifier = mat.param["shape"].specifier
            shape_type = self.__name_from_specifier(shape_specifier)
            type_specifier = f"{shape_type}_map_t"
            template_list = "template <typename container_t>\n"

            if type_id not in self.material_specifiers:
                self.__put(
                    f"{template_list}{self.__tabs()}using {type_specifier} = material_map<{algebra}, {shape_specifier}, container_t>;\n"
                )

        self.material_specifiers.setdefault(type_id, []).append(type_specifier)

    def __declare_material_store(self, md: metadata):
        """Write the material aliases, the material ID enum and the store.

        The enum always gets the additional entry "none" behind the last
        material type.
        """
        if md.materials:
            self.__lines(2)

            # Enumerate anew, so that the generated type IDs are contiguous
            for type_id, (_, materials) in enumerate(sorted(md.materials.items())):
                for mat in materials:
                    self.__declare_material(mat, type_id, "algebra_type")

        self.__lines(1)

        self.__declare_type_enum(
            "material_id",
            self.material_specifiers,
            md.id_base,
            {len(md.materials): ["none"]},
        )
        self.__lines(2)
        self.__define_enum_stream_op(
            "material_id", self.material_specifiers, {len(md.materials): ["none"]}
        )
        self.__lines(2)

        self.__declare_multi_store(
            "material_store", "material_id", self.material_specifiers, is_regular=False
        )

    def __declare_accel(self, obj_type, acc, type_id: int, value_type: str):
        """Declare the accelerator `acc` and record it under `type_id`.

        Non-grid accelerators are only recorded as "<obj_type>_<name>_t". For
        grids, the alias template for the bin/serializer combination is
        written once and then specialised for the grid shape and entry type.
        """
        type_specifier = f"{self.__name_from_specifier(acc.specifier)}_t"

        if type_specifier.endswith("grid_t"):
            shape_specifier = acc.param["shape"].specifier
            shape_name = self.__name_from_specifier(shape_specifier)

            # Surface grids store surface descriptors, all others indices
            entry_type = "surface_type" if value_type == "surface" else "dindex"

            # Bin type (static bins additionally need their capacity)
            grid_bin = acc.param["bin"].specifier
            bin_name = self.__name_from_specifier(grid_bin).removesuffix("_array")
            bin_capacity = ""
            if "static" in grid_bin:
                bin_capacity = acc.param["bin"].param["capacity"]
            bin_type_param = self.__template_list(
                ["bin_entry_t", f"{bin_capacity}"]
            ).removeprefix("template ")

            # Serializer type
            grid_serializer = acc.param["serializer"].specifier
            serializer_name = self.__name_from_specifier(grid_serializer).removesuffix(
                "_serializer"
            )

            grid_specifier = f"{bin_name}_{serializer_name}_grid_t"

            # Declare every bin/serializer combination only once
            if grid_specifier not in self.grid_specifiers:
                self.__lines(2)
                template_list = "template <typename axes_t, typename bin_entry_t, typename container_t>\n"
                self.__put(
                    f"{template_list}{self.__tabs()}using {grid_specifier} = spatial_grid<algebra_type, axes_t, {grid_bin}{bin_type_param}, {grid_serializer}, container_t, false>;"
                )
                self.grid_specifiers.append(grid_specifier)
                self.__lines(2)

            type_specifier = f"{shape_name}_grid_t"
            template_list = "template <typename container_t>\n"

            self.__put(
                f"{template_list}{self.__tabs()}using {obj_type}_{type_specifier} = {grid_specifier}<axes<{shape_specifier}>, {entry_type}, container_t>;\n"
            )

        self.accel_specifiers.setdefault(type_id, []).append(
            f"{obj_type}_{type_specifier}"
        )

    def __declare_accel_store(self, md: metadata):
        """Write the accelerator types, the accelerator ID enum and the store.

        Requires acceleration structures for at least "portal" and "volume".
        """
        # One portal and one volume entry are sufficient
        assert (
            len(md.acceleration_structs) >= 2
        ), "Define at least one default surface(portal) and one default volume acceleration structure"

        assert (
            "portal" in md.acceleration_structs.keys()
        ), "Define at least one portal acceleration structure"

        assert (
            "volume" in md.acceleration_structs.keys()
        ), "Define at least one volume acceleration structure"

        if not md.acceleration_structs:
            return

        self.__lines(2)

        # Collect the distinct accelerators for surfaces and volumes
        unique_accel = []
        for geo_obj, accels in md.acceleration_structs.items():
            obj_type = "volume" if "volume" in geo_obj else "surface"
            for acc, type_id, value_type in accels:
                if (obj_type, acc, type_id, value_type) not in unique_accel:
                    unique_accel.append((obj_type, acc, type_id, value_type))

        # Order the accelerators by the requested type IDs
        tmp_specifiers = {}
        for obj_type, acc, type_id, value_type in unique_accel:
            if type_id >= 0:
                priority = type_id
            elif not tmp_specifiers:
                priority = 0
            else:
                priority = max(len(tmp_specifiers), max(tmp_specifiers) + 1)

            if priority in tmp_specifiers:
                # Earlier entries keep their ID: take the next one and shift
                # all following entries
                priority = priority + 1
                tmp_specifiers = {
                    (k + 1 if k >= priority else k): v
                    for k, v in sorted(tmp_specifiers.items())
                }

            entry = (obj_type, acc, value_type)
            if entry in tmp_specifiers.values():
                # Already known: move it to the new ID if that one is lower.
                # Iterate over a snapshot, since the dictionary is modified.
                for p, v in list(tmp_specifiers.items()):
                    if v == entry and priority < p:
                        tmp_specifiers[priority] = entry
                        del tmp_specifiers[p]
            else:
                tmp_specifiers[priority] = entry

        # Enumerate anew, so that the generated type IDs are contiguous
        for i, (_, values) in enumerate(sorted(tmp_specifiers.items())):
            obj_type, acc, value_type = values
            self.__declare_accel(obj_type, acc, i, value_type)

        self.__lines(1)

        # Enum entries of the default accelerators
        if "grid" in md.default_surface_accel.specifier:
            shape_specifier = md.default_surface_accel.param["shape"].specifier
            surface_default = (
                f"e_surface_{self.__name_from_specifier(shape_specifier)}_grid"
            )
        else:
            surface_default = f"e_surface_{self.__name_from_specifier(md.default_surface_accel.specifier)}"

        if "grid" in md.default_volume_accel.specifier:
            shape_specifier = md.default_volume_accel.param["shape"].specifier
            volume_default = (
                f"e_volume_{self.__name_from_specifier(shape_specifier)}_grid"
            )
        else:
            volume_default = f"e_volume_{self.__name_from_specifier(md.default_volume_accel.specifier)}"

        extra_items = {
            surface_default: ["surface_default"],
            volume_default: ["volume_default"],
        }

        self.__declare_type_enum(
            "accel_id",
            self.accel_specifiers,
            md.id_base,
            extra_items=extra_items,
        )

        self.__lines(2)
        self.__define_enum_stream_op(
            "accel_id",
            self.accel_specifiers,
        )

        self.__lines(2)
        self.__declare_multi_store(
            "accelerator_store", "accel_id", self.accel_specifiers, is_regular=False
        )

    def __declare_geometry_objects(self, md: metadata):
        """Write the `geo_objects` enum, its stream operator and the link type.

        Geometry object types are grouped into accelerator links. All types
        of a group share one enum value, the group that contains "portal"
        gets the value 0.
        """

        def is_default_accel(value_type, accel):
            # 'accel' is a specifier string: compare by value, not identity
            is_surface_default = (
                value_type == "surface" and accel == md.default_surface_accel.specifier
            )
            is_volume_default = (
                value_type == "index" and accel == md.default_volume_accel.specifier
            )

            return is_surface_default or is_volume_default

        # (accel specifier, type ID, value type) -> geometry object types
        accel_to_types = {}
        for t, accels in md.acceleration_structs.items():
            for a, i, v in accels:
                accel_to_types.setdefault((f"{a.specifier}", i, v), set()).add(t)

        # Accelerators that were given the same explicit type ID
        type_id_dict = {}
        for acc, type_id, value_type in accel_to_types.keys():
            if type_id > -1:
                type_id_dict.setdefault((type_id, value_type), []).append(acc)

        # Merge the object types of accelerators that share a type ID under a
        # common (non-default) representative
        accel_to_types_merged = {}
        for acc, type_id, value_type in accel_to_types.keys():
            representative = acc
            if (
                not is_default_accel(value_type, acc)
                and (type_id, value_type) in type_id_dict
            ):
                tmp_repr = type_id_dict[type_id, value_type][0]
                representative = (
                    type_id_dict[type_id, value_type][1]
                    if is_default_accel(value_type, tmp_repr)
                    else tmp_repr
                )

            if (representative, value_type) not in accel_to_types_merged:
                accel_to_types_merged[representative, value_type] = (
                    set(accel_to_types[acc, type_id, value_type]),
                    type_id,
                )
            else:
                obj_set, priority = accel_to_types_merged[representative, value_type]
                obj_set.update(accel_to_types[acc, type_id, value_type])

                # Keep the largest explicit type ID for the merged entry
                if type_id >= 0 and (priority == -1 or priority < type_id):
                    accel_to_types_merged[representative, value_type] = (
                        obj_set,
                        type_id,
                    )

        # Entries whose set of object types occurs more than once
        counted_dict = Counter(
            [tuple(s) for (s, type_id) in accel_to_types_merged.values()]
        )
        duplicates = {
            key: (value, type_id)
            for key, (value, type_id) in accel_to_types_merged.items()
            if counted_dict[tuple(value)] > 1
        }

        flipped_duplicates = {}
        for key, value in duplicates.items():
            obj_set, type_id = value
            flipped_duplicates.setdefault((tuple(obj_set), type_id), []).append(key)

        # NOTE(review): the keys collected here are 3-tuples, while the keys
        # of 'accel_to_types_merged' are (accel, value_type) pairs, so the
        # pop() below never matches. Kept as is, because removing the entries
        # would change the generated links - confirm the intended behaviour.
        removal_keys = []
        for key, value in flipped_duplicates.items():
            type_id_counter = 10000
            _, type_id = key
            for acc, value_type in value:
                if is_default_accel(value_type, acc):
                    continue

                if type_id == -1 or type_id >= type_id_counter:
                    removal_keys.append((acc, value_type, type_id))
                else:
                    type_id_counter = type_id

        for rk in removal_keys:
            accel_to_types_merged.pop(rk, None)

        # NOTE(review): item[1] is the pair (object set, type ID), i.e. the
        # sort key is always 2 and the insertion order is kept. Presumably
        # this was meant to sort by the size of the object set - confirm.
        accel_to_types_sorted = dict(
            sorted(
                accel_to_types_merged.items(),
                key=lambda item: len(item[1]),
                reverse=True,
            )
        )

        # link key -> geometry object types that share the link
        accel_link_types = {}

        def add_link(key, new_objects):
            # Every geometry object type can only belong to one link
            new_link_types = [
                obj
                for obj in new_objects
                if obj not in itertools.chain(*accel_link_types.values())
            ]

            # NOTE(review): an existing entry for 'key' is replaced
            if new_link_types:
                accel_link_types[key] = new_link_types

        # Links of the non-default accelerators first
        for (accel, value_type), (obj_set, type_id) in accel_to_types_sorted.items():
            if is_default_accel(value_type, accel):
                continue

            add_link(key=type_id, new_objects=obj_set)

        # The remaining object types go to the default accelerator links
        surface_default_list = [
            v
            for k, (v, t) in accel_to_types_sorted.items()
            if "surface" in k and md.default_surface_accel.specifier in k
        ]
        add_link(
            key="surface_default",
            new_objects=itertools.chain.from_iterable(surface_default_list),
        )

        volume_default_list = [
            v
            for k, (v, t) in accel_to_types_sorted.items()
            if "index" in k and md.default_volume_accel.specifier in k
        ]
        add_link(
            key="volume_default",
            new_objects=itertools.chain.from_iterable(volume_default_list),
        )

        # Find the portal link before anything is written, so that a missing
        # link is reported instead of raising an IndexError
        portal_link = next(
            ((k, v) for (k, v) in accel_link_types.items() if "portal" in v), None
        )
        if portal_link is None:
            self.logger.error(
                "No acceleration structure link define for portal surfaces!"
            )
            return
        portal_key, portal_group = portal_link

        self.__lines(2)
        self.__put(f"enum geo_objects : {md.id_base} {{\n")

        self.indent = self.indent + 1

        # The portal link always comes first
        for gid in portal_group:
            self.__put(f"e_{gid} = 0u,\n")

        # All object types of a link share the same value
        i = 1
        for key, link_group in accel_link_types.items():
            if key == portal_key:
                continue
            for gid in link_group:
                self.__put(f"e_{gid} = {i}u,\n")
            i = i + 1

        self.__put(f"e_size = {len(accel_link_types)}u,\n")
        self.__put("e_all = e_size,\n")

        self.indent = self.indent - 1

        self.__put("};")

        self.__lines(2)
        self.__define_enum_stream_op(
            "geo_objects",
            accel_link_types,
            extra_items={"e_size": ["size"]},
        )

        self.__lines(2)
        self.__put(
            "using object_link_type = dmulti_index<dtyped_index<accel_id, dindex>, geo_objects::e_size>;"
        )

    def __finish(self):
        """Close the struct, the namespace and the file; format the header."""
        self.__lines(1)
        self.indent = self.indent - 1
        self.__put("};\n\n} // namespace detray\n")
        self.file.close()

        self.logger.debug("Wrote file to disk")

        # Formatting is optional: a missing clang-format is only reported
        if self.format_header and os.path.isfile(self.file.name):
            self.logger.debug("Formatting the header...")
            try:
                subprocess.run(["clang-format", "-i", "-style=file", self.file.name])
            except FileNotFoundError:
                self.logger.error("clang-format not found")
1135 logging.error("clang-format not found")