File indexing completed on 2026-05-27 07:24:04
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
0012 #include "detray/core/detail/container_buffers.hpp"
0013 #include "detray/core/detail/container_views.hpp"
0014 #include "detray/definitions/detail/qualifiers.hpp"
0015 #include "detray/definitions/grid_axis.hpp"
0016 #include "detray/definitions/indexing.hpp"
0017 #include "detray/utils/grid/concepts.hpp"
0018 #include "detray/utils/grid/detail/axis_binning.hpp"
0019 #include "detray/utils/grid/detail/axis_bounds.hpp"
0020 #include "detray/utils/ranges.hpp"
0021 #include "detray/utils/type_list.hpp"
0022 #include "detray/utils/type_registry.hpp"
0023 #include "detray/utils/type_traits.hpp"
0024
0025
0026 #include <vecmem/memory/memory_resource.hpp>
0027
0028
0029 #include <cstddef>
0030 #include <type_traits>
0031
0032 namespace detray::axis {
0033
0034
0035 template <std::size_t DIM, typename index_t = dindex>
0036 struct multi_bin : public dmulti_index<index_t, DIM> {
0037 using base_t = dmulti_index<index_t, DIM>;
0038 using base_t::base_t;
0039 };
0040
0041
0042
0043 using bin_range = darray<int, 2>;
0044
0045
0046 template <std::size_t DIM>
0047 struct multi_bin_range : public dmulti_index<bin_range, DIM> {
0048 using base_t = dmulti_index<bin_range, DIM>;
0049 using base_t::base_t;
0050 };
0051
0052
0053
0054
0055
0056
0057
0058
0059 template <typename bounds_t, typename binning_t>
0060 struct single_axis {
0061
0062 using bounds_type = bounds_t;
0063
0064 using binning_type = binning_t;
0065
0066 using scalar_type = typename binning_type::scalar_type;
0067
0068
0069
0070 using container_types = typename binning_type::container_types;
0071 template <typename T>
0072 using vector_type = typename binning_type::template vector_type<T>;
0073
0074
0075
0076
0077 DETRAY_NO_UNIQUE_ADDRESS bounds_type m_bounds{};
0078
0079 binning_type m_binning{};
0080
0081
0082 constexpr single_axis() = default;
0083
0084
0085 template <typename... Args>
0086 DETRAY_HOST_DEVICE single_axis(const dsized_index_range &indx_range,
0087 const vector_type<scalar_type> *edges)
0088 : m_binning(indx_range, edges) {}
0089
0090
0091
0092
0093
0094
0095 constexpr bool operator==(const single_axis &rhs) const = default;
0096
0097
0098 DETRAY_HOST_DEVICE
0099 constexpr auto label() const -> axis::label { return bounds_type::label; }
0100
0101
0102 DETRAY_HOST_DEVICE
0103 constexpr auto bounds() const -> axis::bounds { return bounds_type::type; }
0104
0105
0106 DETRAY_HOST_DEVICE
0107 constexpr auto binning() const -> axis::binning { return binning_type::type; }
0108
0109
0110 DETRAY_HOST_DEVICE
0111 constexpr dindex nbins() const {
0112
0113
0114 if constexpr (bounds_type::type == axis::bounds::e_open) {
0115 return m_binning.nbins() + 2u;
0116 } else {
0117 return m_binning.nbins();
0118 }
0119 }
0120
0121
0122 template <typename... Args>
0123 DETRAY_HOST_DEVICE constexpr scalar_type bin_width(Args &&...args) const {
0124 return m_binning.bin_width(std::forward<Args>(args)...);
0125 }
0126
0127
0128
0129
0130
0131
0132
0133
0134 DETRAY_HOST_DEVICE
0135 dindex bin(const scalar_type v) const {
0136 int b{m_bounds.map(m_binning.bin(v), m_binning.nbins())};
0137
0138 if constexpr (bounds_type::type == axis::bounds::e_circular) {
0139 b = m_bounds.wrap(b, m_binning.nbins());
0140 }
0141
0142 return static_cast<dindex>(b);
0143 }
0144
0145
0146
0147
0148
0149
0150
0151
0152
0153
0154 template <typename neighbor_t>
0155 DETRAY_HOST_DEVICE bin_range range(const scalar_type v,
0156 const darray<neighbor_t, 2> &nhood) const {
0157 return m_bounds.map(m_binning.range(v, nhood), m_binning.nbins());
0158 }
0159
0160
0161 DETRAY_HOST_DEVICE
0162 darray<scalar_type, 2> bin_edges(const dindex ibin) const {
0163 return m_binning.bin_edges(ibin);
0164 }
0165
0166
0167 DETRAY_HOST_DEVICE
0168 vector_type<scalar_type> bin_edges() const { return m_binning.bin_edges(); }
0169
0170
0171 DETRAY_HOST_DEVICE
0172 darray<scalar_type, 2> span() const { return m_binning.span(); }
0173
0174
0175 DETRAY_HOST_DEVICE
0176 scalar_type min() const { return m_binning.span()[0]; }
0177
0178
0179 DETRAY_HOST_DEVICE
0180 scalar_type max() const { return m_binning.span()[1]; }
0181
0182
0183 DETRAY_HOST
0184 friend std::ostream &operator<<(std::ostream &os, const single_axis &ax) {
0185 os << "label: " << ax.label() << std::endl;
0186 os << "bounds: " << ax.bounds() << std::endl;
0187 os << "binning: " << ax.binning() << std::endl;
0188 os << "n-bins: " << ax.nbins() << std::endl;
0189 os << "span: [" << ax.min() << ", " << ax.max() << "]";
0190
0191 return os;
0192 }
0193 };
0194
0195
0196
0197
0198
0199
0200
0201
0202
0203 template <bool ownership, typename local_frame_t, concepts::axis... axis_ts>
0204 class multi_axis {
0205
0206 using axis_reg = types::registry<axis::label, axis_ts...>;
0207
0208 public:
0209
0210 static constexpr dindex dim = sizeof...(axis_ts);
0211 static constexpr bool is_owning = ownership;
0212
0213
0214
0215 using binnings = types::list<typename axis_ts::binning_type...>;
0216 using bounds = types::list<typename axis_ts::bounds_type...>;
0217 using loc_bin_index = axis::multi_bin<dim>;
0218
0219
0220
0221 using local_frame_type = local_frame_t;
0222 using algebra_type = typename local_frame_type::algebra_type;
0223 using point_type = typename local_frame_type::loc_point;
0224
0225 using scalar_type = typename detray::detail::first_t<axis_ts...>::scalar_type;
0226
0227
0228
0229 using container_types =
0230 typename detray::detail::first_t<axis_ts...>::container_types;
0231 template <typename T>
0232 using vector_type = typename container_types::template vector_type<T>;
0233
0234
0235 private:
0236
0237 using edge_offset_range_t = std::conditional_t<
0238 is_owning, vector_type<dsized_index_range>,
0239 detray::ranges::subrange<const vector_type<dsized_index_range>>>;
0240
0241 using edge_range_t = std::conditional_t<is_owning, vector_type<scalar_type>,
0242 const vector_type<scalar_type> *>;
0243
0244 public:
0245
0246
0247 using edge_offset_container_type = vector_type<dsized_index_range>;
0248 using edges_container_type = vector_type<scalar_type>;
0249
0250
0251
0252 using view_type =
0253 dmulti_view<dvector_view<dsized_index_range>, dvector_view<scalar_type>>;
0254 using const_view_type = dmulti_view<dvector_view<const dsized_index_range>,
0255 dvector_view<const scalar_type>>;
0256
0257 using buffer_type = dmulti_buffer<dvector_buffer<dsized_index_range>,
0258 dvector_buffer<scalar_type>>;
0259
0260
0261 template <bool owning>
0262 using type = multi_axis<owning, local_frame_t, axis_ts...>;
0263
0264
0265 constexpr multi_axis() = default;
0266
0267
0268 template <bool owner = is_owning>
0269 requires owner
0270 DETRAY_HOST explicit multi_axis(vecmem::memory_resource &resource)
0271 : m_edge_offsets(&resource), m_edges(&resource) {}
0272
0273
0274 template <bool owner = is_owning>
0275 requires owner
0276 DETRAY_HOST_DEVICE multi_axis(edge_offset_range_t &&edge_offsets,
0277 edge_range_t &&edges)
0278 : m_edge_offsets(std::move(edge_offsets)), m_edges(std::move(edges)) {}
0279
0280
0281
0282
0283
0284
0285 template <bool owner = is_owning>
0286 requires(!owner)
0287 DETRAY_HOST_DEVICE multi_axis(
0288 const vector_type<dsized_index_range> &edge_offsets,
0289 const vector_type<scalar_type> &edges, const unsigned int offset = 0)
0290 : m_edge_offsets(edge_offsets, dsized_index_range{offset, offset + dim}),
0291 m_edges(&edges) {}
0292
0293
0294
0295
0296 template <concepts::device_view view_t>
0297 DETRAY_HOST_DEVICE explicit multi_axis(const view_t &view)
0298 : m_edge_offsets(detray::detail::get<0>(view.m_view)),
0299 m_edges(detray::detail::get<1>(view.m_view)) {}
0300
0301
0302 DETRAY_HOST_DEVICE
0303 constexpr auto bin_edge_offsets() const -> const edge_offset_range_t & {
0304 return m_edge_offsets;
0305 }
0306
0307
0308 DETRAY_HOST_DEVICE
0309 constexpr auto bin_edges() const -> const vector_type<scalar_type> & {
0310 if constexpr (is_owning) {
0311 return m_edges;
0312 } else {
0313 return *m_edges;
0314 }
0315 }
0316
0317
0318
0319
0320
0321
0322 template <std::size_t I>
0323 DETRAY_HOST_DEVICE types::get<axis_reg, types::id_cast<axis_reg, I>>
0324 get_axis() const {
0325 if constexpr (std::same_as<edge_offset_range_t,
0326 vecmem::vector<dsized_index_range>>) {
0327 #if defined(__CUDACC__)
0328
0329 DETRAY_VERBOSE_DEVICE(
0330 "The host container types must not be called in device code");
0331 assert(false);
0332 return {dsized_index_range{}, &bin_edges()};
0333 #else
0334 return {m_edge_offsets[I], &bin_edges()};
0335 #endif
0336 } else {
0337 return {m_edge_offsets[I], &bin_edges()};
0338 }
0339 }
0340
0341
0342
0343 template <axis::label L>
0344 DETRAY_HOST_DEVICE types::get<axis_reg, L> get_axis() const {
0345 return get_axis<types::index_cast<axis_reg, L>>();
0346 }
0347
0348
0349
0350 template <typename axis_t>
0351 DETRAY_HOST_DEVICE axis_t get_axis() const {
0352 return get_axis<axis_t::bounds_type::label>();
0353 }
0354
0355
0356
0357 DETRAY_HOST_DEVICE constexpr auto nbins_per_axis() const -> loc_bin_index {
0358
0359 loc_bin_index n_bins{};
0360
0361 (get_axis_nbins(get_axis<axis_ts>(), n_bins), ...);
0362
0363 return n_bins;
0364 }
0365
0366
0367 DETRAY_HOST_DEVICE constexpr auto nbins() const -> dindex {
0368 const auto n_bins_per_axis = nbins_per_axis();
0369 dindex n_bins{1u};
0370 for (dindex i = 0u; i < dim; ++i) {
0371 n_bins *= n_bins_per_axis[i];
0372 }
0373 return n_bins;
0374 }
0375
0376
0377
0378
0379
0380
0381
0382
0383 DETRAY_HOST_DEVICE loc_bin_index bins(const point_type &p) const {
0384
0385 loc_bin_index bin_indices{};
0386
0387 (get_axis_bin(get_axis<axis_ts>(), p, bin_indices), ...);
0388
0389 return bin_indices;
0390 }
0391
0392
0393
0394
0395
0396
0397
0398
0399
0400
0401
0402
0403
0404
0405
0406
0407
0408
0409
0410 template <typename neighbor_t>
0411 DETRAY_HOST_DEVICE multi_bin_range<dim> bin_ranges(
0412 const point_type &p, const darray<neighbor_t, 2> &nhood) const {
0413
0414 multi_bin_range<dim> bin_ranges{};
0415
0416 (get_axis_bin_ranges(get_axis<axis_ts>(), p, nhood, bin_ranges), ...);
0417
0418 return bin_ranges;
0419 }
0420
0421
0422 template <bool owner = is_owning>
0423 requires owner
0424 DETRAY_HOST auto get_data() -> view_type {
0425 return view_type{detray::get_data(m_edge_offsets),
0426 detray::get_data(m_edges)};
0427 }
0428
0429
0430
0431 template <bool owner = is_owning>
0432 requires owner
0433 DETRAY_HOST auto get_data() const -> const_view_type {
0434 return const_view_type{detray::get_data(m_edge_offsets),
0435 detray::get_data(m_edges)};
0436 }
0437
0438
0439
0440
0441
0442
0443
0444
0445 DETRAY_HOST_DEVICE constexpr auto operator==(const multi_axis &rhs) const
0446 -> bool {
0447 if constexpr (!std::is_pointer_v<edge_range_t>) {
0448 return m_edge_offsets == rhs.m_edge_offsets && m_edges == rhs.m_edges;
0449 } else {
0450 return m_edge_offsets == rhs.m_edge_offsets && *m_edges == *rhs.m_edges;
0451 }
0452 return false;
0453 }
0454
0455
0456 DETRAY_HOST
0457 friend std::ostream &operator<<(std::ostream &os, const multi_axis &ax) {
0458 os << "Axis 0:\n" << ax.template get_axis<0>();
0459
0460 if constexpr (multi_axis::dim > 1) {
0461 os << "\nAxis 1:\n" << ax.template get_axis<1>();
0462 }
0463 if constexpr (multi_axis::dim > 2) {
0464 os << "\nAxis 2:\n" << ax.template get_axis<2>();
0465 }
0466
0467 return os;
0468 }
0469
0470 private:
0471
0472
0473
0474
0475
0476
0477 template <typename axis_t>
0478 DETRAY_HOST_DEVICE void get_axis_nbins(const axis_t &ax,
0479 loc_bin_index &n_bins) const {
0480
0481 constexpr auto loc_idx{
0482 types::index_cast<axis_reg, axis_t::bounds_type::label>};
0483 n_bins[loc_idx] = ax.nbins();
0484 }
0485
0486
0487
0488
0489
0490
0491
0492
0493
0494
0495 template <typename axis_t>
0496 DETRAY_HOST_DEVICE void get_axis_bin(const axis_t &ax, const point_type &p,
0497 loc_bin_index &bin_indices) const {
0498
0499 constexpr auto loc_idx{
0500 types::index_cast<axis_reg, axis_t::bounds_type::label>};
0501 bin_indices[loc_idx] = ax.bin(p[loc_idx]);
0502 }
0503
0504
0505
0506
0507
0508
0509
0510
0511
0512
0513
0514
0515
0516 template <typename axis_t, typename neighbor_t>
0517 DETRAY_HOST_DEVICE void get_axis_bin_ranges(
0518 const axis_t &ax, const point_type &p, const darray<neighbor_t, 2> &nhood,
0519 multi_bin_range<dim> &bin_ranges) const {
0520
0521 constexpr auto loc_idx{
0522 types::index_cast<axis_reg, axis_t::bounds_type::label>};
0523 bin_ranges[loc_idx] = ax.range(p[loc_idx], nhood);
0524 }
0525
0526
0527 edge_offset_range_t m_edge_offsets{};
0528
0529 edge_range_t m_edges{};
0530 };
0531
0532 namespace detail {
0533
0534
0535 template <bool is_owning, typename containers, typename local_frame, typename,
0536 typename>
0537 struct multi_axis_assembler;
0538
0539
0540 template <bool is_owning, typename containers, typename local_frame,
0541 typename... axis_bounds, typename... binning_ts>
0542 struct multi_axis_assembler<is_owning, containers, local_frame,
0543 types::list<axis_bounds...>,
0544 types::list<binning_ts...>> {
0545 static_assert(sizeof...(axis_bounds) > 0,
0546 "At least one bounds type needs to be defined");
0547 static_assert(sizeof...(axis_bounds) == sizeof...(binning_ts),
0548 "Number of axis bounds for this mask and given binning types "
0549 "don't match!");
0550
0551 using type = axis::multi_axis<is_owning, local_frame,
0552 axis::single_axis<axis_bounds, binning_ts>...>;
0553 };
0554
0555 }
0556
0557 }