Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-17 07:34:40

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include "Acts/Utilities/RangeXD.hpp"

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
0019 
0020 namespace Acts {
0021 /// @brief A general k-d tree with fast range search.
0022 ///
0023 /// This is a generalized k-d tree, with a configurable number of dimensions,
0024 /// content type, scalar type, vector type, and leaf size. This class is
0025 /// purposefully generalized to support a wide range of use cases.
0026 ///
0027 /// A k-d tree is, in essence, a k-dimensional binary search tree. Each internal
0028 /// node splits the content of the tree in half, with the pivot point being an
0029 /// orthogonal hyperplane in one of the k dimensions. This allows us to
0030 /// efficiently look up points within certain k-dimensional ranges.
0031 ///
0032 /// This particular class is mostly a wrapper class around an underlying node
0033 /// class which does all the actual work.
0034 ///
0035 /// @note This type is completely immutable after construction.
0036 ///
0037 /// @tparam Dims The number of dimensions.
0038 /// @tparam Type The type of value contained in the tree.
0039 /// @tparam Scalar The scalar type used to construct position vectors.
0040 /// @tparam Vector The general vector type used to construct coordinates.
0041 /// @tparam LeafSize The maximum number of elements stored in a leaf node.
0042 template <std::size_t Dims, typename Type, typename Scalar = double,
0043           template <typename, std::size_t> typename Vector = std::array,
0044           std::size_t LeafSize = 4>
0045 class KDTree {
0046  public:
0047   /// @brief The type of value contained in this k-d tree.
0048   using value_t = Type;
0049 
0050   /// @brief The type describing a multi-dimensional orthogonal range.
0051   using range_t = RangeXD<Dims, Scalar>;
0052 
0053   /// @brief The type of coordinates for points.
0054   using coordinate_t = Vector<Scalar, Dims>;
0055 
0056   /// @brief The type of coordinate-value pairs.
0057   using pair_t = std::pair<coordinate_t, Type>;
0058 
0059   /// @brief The type of a vector of coordinate-value pairs.
0060   using vector_t = std::vector<pair_t>;
0061 
0062   /// @brief The type of iterators in our vectors.
0063   using iterator_t = typename vector_t::iterator;
0064 
0065   /// Type alias for const iterator over coordinate-value pairs
0066   using const_iterator_t = typename vector_t::const_iterator;
0067 
0068   // We do not need an empty constructor - this is never useful.
0069   KDTree() = delete;
0070 
0071   /// @brief Construct a k-d tree from a vector of position-value pairs.
0072   ///
0073   /// This constructor takes an r-value reference to a vector of position-value
0074   /// pairs and constructs a k-d tree from those pairs.
0075   ///
0076   /// @param d The vector of position-value pairs to construct the k-d tree
0077   /// from.
0078   explicit KDTree(vector_t &&d) : m_elems(d) {
0079     // To start out, we need to check whether we need to construct a leaf node
0080     // or an internal node. We create a leaf only if we have at most as many
0081     // elements as the number of elements that can fit into a leaf node.
0082     // Hopefully most invocations of this constructor will have more than a few
0083     // elements!
0084     //
0085     // One interesting thing to note is that all of the nodes in the k-d tree
0086     // have a range in the element vector of the outermost node. They simply
0087     // make in-place changes to this array, and they hold no memory of their
0088     // own.
0089     m_root = std::make_unique<KDTreeNode>(m_elems.begin(), m_elems.end(),
0090                                           m_elems.size() > LeafSize
0091                                               ? KDTreeNode::NodeType::Internal
0092                                               : KDTreeNode::NodeType::Leaf,
0093                                           0UL);
0094   }
0095 
0096   /// @brief Perform an orthogonal range search within the k-d tree.
0097   ///
0098   /// A range search operation is one that takes a k-d tree and an orthogonal
0099   /// range, and returns all values associated with coordinates in the k-d tree
0100   /// that lie within the orthogonal range. k-d trees can do this operation
0101   /// quickly.
0102   ///
0103   /// @param r The range to search for.
0104   ///
0105   /// @return The vector of all values that lie within the given range.
0106   std::vector<Type> rangeSearch(const range_t &r) const {
0107     std::vector<Type> out;
0108 
0109     rangeSearch(r, out);
0110 
0111     return out;
0112   }
0113 
0114   /// @brief Perform an orthogonal range search within the k-d tree, returning
0115   /// keys as well as values.
0116   ///
0117   /// Performs the same logic as the keyless version, but includes a copy of
0118   /// the key with each element.
0119   ///
0120   /// @param r The range to search for.
0121   ///
0122   /// @return The vector of all key-value pairs that lie within the given
0123   /// range.
0124   std::vector<pair_t> rangeSearchWithKey(const range_t &r) const {
0125     std::vector<pair_t> out;
0126 
0127     rangeSearchWithKey(r, out);
0128 
0129     return out;
0130   }
0131 
0132   /// @brief Perform an in-place orthogonal range search within the k-d tree.
0133   ///
0134   /// This range search module operates in place, writing its results to the
0135   /// given output vector.
0136   ///
0137   /// @param r The range to search for.
0138   /// @param v The vector to write the output to.
0139   void rangeSearch(const range_t &r, std::vector<Type> &v) const {
0140     rangeSearchInserter(r, std::back_inserter(v));
0141   }
0142 
0143   /// @brief Perform an in-place orthogonal range search within the k-d tree,
0144   /// including keys in the result.
0145   ///
0146   /// Performs the same operation as the keyless version, but includes the keys
0147   /// in the results.
0148   ///
0149   /// @param r The range to search for.
0150   /// @param v The vector to write the output to.
0151   void rangeSearchWithKey(const range_t &r, std::vector<pair_t> &v) const {
0152     rangeSearchInserterWithKey(r, std::back_inserter(v));
0153   }
0154 
0155   /// @brief Perform an orthogonal range search within the k-d tree, writing
0156   /// the resulting values to an output iterator.
0157   ///
0158   /// This method allows the user more control in where the result is written
0159   /// to.
0160   ///
0161   /// @tparam OutputIt The type of the output iterator.
0162   ///
0163   /// @param r The range to search for.
0164   /// @param i The iterator to write the output to.
0165   template <typename OutputIt>
0166   void rangeSearchInserter(const range_t &r, OutputIt i) const {
0167     rangeSearchMapDiscard(
0168         r, [i](const coordinate_t &, const Type &v) mutable { i = v; });
0169   }
0170 
0171   /// @brief Perform an orthogonal range search within the k-d tree, writing
0172   /// the resulting values to an output iterator, including the keys.
0173   ///
0174   /// Performs the same operation as the keyless version, but includes the key
0175   /// in the output.
0176   ///
0177   /// @tparam OutputIt The type of the output iterator.
0178   ///
0179   /// @param r The range to search for.
0180   /// @param i The iterator to write the output to.
0181   template <typename OutputIt>
0182   void rangeSearchInserterWithKey(const range_t &r, OutputIt i) const {
0183     rangeSearchMapDiscard(
0184         r, [i](const coordinate_t &c, const Type &v) mutable { i = {c, v}; });
0185   }
0186 
0187   /// @brief Perform an orthogonal range search within the k-d tree, applying
0188   /// a mapping function to the values found.
0189   ///
0190   /// In some cases, we may wish to transform the values in some way. This
0191   /// method allows the user to provide a mapping function which takes a set of
0192   /// coordinates and a value and transforms them to a new value, which is
0193   /// returned.
0194   ///
0195   /// @note Your compiler may not be able to deduce the result type
0196   /// automatically, in which case you will need to specify it manually.
0197   ///
0198   /// @tparam Result The return type of the map operation.
0199   ///
0200   /// @param r The range to search for.
0201   /// @param f The mapping function to apply to key-value pairs.
0202   ///
0203   /// @return A vector of elements matching the range after the application of
0204   /// the mapping function.
0205   template <typename Result>
0206   std::vector<Result> rangeSearchMap(
0207       const range_t &r,
0208       const std::function<Result(const coordinate_t &, const Type &)> &f)
0209       const {
0210     std::vector<Result> out;
0211 
0212     rangeSearchMapInserter(r, f, std::back_inserter(out));
0213 
0214     return out;
0215   }
0216 
0217   /// @brief Perform an orthogonal range search within the k-d tree, applying a
0218   /// mapping function to the values found, and inserting them into an
0219   /// inserter.
0220   ///
0221   /// Performs the same operation as the interter-less version, but allows the
0222   /// user additional control over the insertion process.
0223   ///
0224   /// @note Your compiler may not be able to deduce the result type
0225   /// automatically, in which case you will need to specify it manually.
0226   ///
0227   /// @tparam Result The return type of the map operation.
0228   /// @tparam OutputIt The type of the output iterator.
0229   ///
0230   /// @param r The range to search for.
0231   /// @param f The mapping function to apply to key-value pairs.
0232   /// @param i The inserter to insert the results into.
0233   template <typename Result, typename OutputIt>
0234   void rangeSearchMapInserter(
0235       const range_t &r,
0236       const std::function<Result(const coordinate_t &, const Type &)> &f,
0237       OutputIt i) const {
0238     rangeSearchMapDiscard(r, [i, f](const coordinate_t &c,
0239                                     const Type &v) mutable { i = f(c, v); });
0240   }
0241 
0242   /// @brief Perform an orthogonal range search within the k-d tree, applying a
0243   /// a void-returning function with side-effects to each key-value pair.
0244   ///
0245   /// This is the most general version of range search in this class, and every
0246   /// other operation can be reduced to this operation as long as we allow
0247   /// arbitrary side-effects.
0248   ///
0249   /// Functional programmers will know this method as mapM_.
0250   ///
0251   /// @param r The range to search for.
0252   /// @param f The mapping function to apply to key-value pairs.
0253   template <typename Callable>
0254   void rangeSearchMapDiscard(const range_t &r, Callable &&f) const {
0255     m_root->rangeSearchMapDiscard(r, std::forward<Callable>(f));
0256   }
0257 
0258   /// @brief Return the number of elements in the k-d tree.
0259   ///
0260   /// We simply defer this method to the root node of the k-d tree.
0261   ///
0262   /// @return The number of elements in the k-d tree.
0263   std::size_t size(void) const { return m_root->size(); }
0264 
0265   /// Get iterator to first element
0266   /// @return Const iterator to the beginning of the tree elements
0267   const_iterator_t begin(void) const { return m_elems.begin(); }
0268 
0269   /// Get iterator to one past the last element
0270   /// @return Const iterator to the end of the tree elements
0271   const_iterator_t end(void) const { return m_elems.end(); }
0272 
0273  private:
0274   static Scalar nextRepresentable(Scalar v) {
0275     // I'm not super happy with this bit of code, but since 1D ranges are
0276     // semi-open, we can't simply incorporate values by setting the maximum to
0277     // them. Instead, what we need to do is get the next representable value.
0278     // For integer values, this means adding one. For floating point types, we
0279     // rely on the nextafter method to get the smallest possible value that is
0280     // larger than the one we requested.
0281     if constexpr (std::is_integral_v<Scalar>) {
0282       return v + 1;
0283     } else if constexpr (std::is_floating_point_v<Scalar>) {
0284       return std::nextafter(v, std::numeric_limits<Scalar>::max());
0285     }
0286   }
0287 
0288   static range_t boundingBox(iterator_t b, iterator_t e) {
0289     // Firstly, we find the minimum and maximum value in each dimension to
0290     // construct a bounding box around this node's values.
0291     std::array<Scalar, Dims> min_v{}, max_v{};
0292 
0293     for (std::size_t i = 0; i < Dims; ++i) {
0294       min_v[i] = std::numeric_limits<Scalar>::max();
0295       max_v[i] = std::numeric_limits<Scalar>::lowest();
0296     }
0297 
0298     for (iterator_t i = b; i != e; ++i) {
0299       for (std::size_t j = 0; j < Dims; ++j) {
0300         min_v[j] = std::min(min_v[j], i->first[j]);
0301         max_v[j] = std::max(max_v[j], i->first[j]);
0302       }
0303     }
0304 
0305     // Then, we construct a k-dimensional range from the given minima and
0306     // maxima, which again is just a bounding box.
0307     range_t r;
0308 
0309     for (std::size_t j = 0; j < Dims; ++j) {
0310       r[j] = Range1D<Scalar>{min_v[j], nextRepresentable(max_v[j])};
0311     }
0312 
0313     return r;
0314   }
0315 
0316   /// @brief An abstract class containing common features of k-d tree node
0317   /// types.
0318   ///
0319   /// A k-d tree consists of two different node types: leaf nodes and inner
0320   /// nodes. These nodes have some common functionality, which is captured by
0321   /// this common parent node type.
0322   class KDTreeNode {
0323    public:
0324     /// @brief Enumeration type for the possible node types (internal and leaf).
0325     enum class NodeType { Internal, Leaf };
0326 
0327     /// @brief Construct the common data for all node types.
0328     ///
0329     /// The node types share a few concepts, like an n-dimensional range, and a
0330     /// begin and end of the range of elements managed. This constructor
0331     /// calculates these things so that the individual child constructors don't
0332     /// have to.
0333     KDTreeNode(iterator_t _b, iterator_t _e, NodeType _t, std::size_t _d)
0334         : m_type(_t),
0335           m_begin_it(_b),
0336           m_end_it(_e),
0337           m_range(boundingBox(m_begin_it, m_end_it)) {
0338       if (m_type == NodeType::Internal) {
0339         // This constant determines the maximum number of elements where we
0340         // still
0341         // calculate the exact median of the values for the purposes of
0342         // splitting. In general, the closer the pivot value is to the true
0343         // median, the more balanced the tree will be. However, calculating the
0344         // median exactly is an O(n log n) operation, while approximating it is
0345         // an O(1) time.
0346         constexpr std::size_t max_exact_median = 128;
0347 
0348         iterator_t pivot;
0349 
0350         // Next, we need to determine the pivot point of this node, that is to
0351         // say the point in the selected pivot dimension along which point we
0352         // will split the range. To do this, we check how large the set of
0353         // elements is. If it is sufficiently small, we use the median.
0354         // Otherwise we use the mean.
0355         if (size() > max_exact_median) {
0356           // In this case, we have a lot of elements, and sorting the range to
0357           // find the true median might be too expensive. Therefore, we will
0358           // just use the middle value between the minimum and maximum. This is
0359           // not nearly as accurate as using the median, but it's a nice cheat.
0360           Scalar mid = static_cast<Scalar>(0.5) *
0361                        (m_range[_d].max() + m_range[_d].min());
0362 
0363           pivot = std::partition(m_begin_it, m_end_it, [=](const pair_t &i) {
0364             return i.first[_d] < mid;
0365           });
0366         } else {
0367           // If the number of elements is fairly small, we will just calculate
0368           // the median exactly. We do this by finding the values in the
0369           // dimension, sorting it, and then taking the middle one.
0370           std::sort(m_begin_it, m_end_it,
0371                     [_d](const typename iterator_t::value_type &a,
0372                          const typename iterator_t::value_type &b) {
0373                       return a.first[_d] < b.first[_d];
0374                     });
0375 
0376           pivot = m_begin_it + (std::distance(m_begin_it, m_end_it) / 2);
0377         }
0378 
0379         // This should never really happen, but in very select cases where there
0380         // are a lot of equal values in the range, the pivot can end up all the
0381         // way at the end of the array and we end up in an infinite loop. We
0382         // check for pivot points which would not split the range, and fix them
0383         // if they occur.
0384         if (pivot == m_begin_it || pivot == std::prev(m_end_it)) {
0385           pivot = std::next(m_begin_it, LeafSize);
0386         }
0387 
0388         // Calculate the number of elements on the left-hand side, as well as
0389         // the right-hand side. We do this by calculating the difference from
0390         // the begin and end of the array to the pivot point.
0391         std::size_t lhs_size = std::distance(m_begin_it, pivot);
0392         std::size_t rhs_size = std::distance(pivot, m_end_it);
0393 
0394         // Next, we check whether the left-hand node should be another internal
0395         // node or a leaf node, and we construct the node recursively.
0396         m_lhs = std::make_unique<KDTreeNode>(
0397             m_begin_it, pivot,
0398             lhs_size > LeafSize ? NodeType::Internal : NodeType::Leaf,
0399             (_d + 1) % Dims);
0400 
0401         // Same on the right hand side.
0402         m_rhs = std::make_unique<KDTreeNode>(
0403             pivot, m_end_it,
0404             rhs_size > LeafSize ? NodeType::Internal : NodeType::Leaf,
0405             (_d + 1) % Dims);
0406       }
0407     }
0408 
0409     /// @brief Perform a range search in the k-d tree, mapping the key-value
0410     /// pairs to a side-effecting function.
0411     ///
0412     /// This is the most powerful range search method we have, assuming that we
0413     /// can use arbitrary side effects, which we can. All other range search
0414     /// methods are implemented in terms of this particular function.
0415     ///
0416     /// @param r The range to search for.
0417     /// @param f The mapping function to apply to matching elements.
0418     template <typename Callable>
0419     void rangeSearchMapDiscard(const range_t &r, Callable &&f) const {
0420       // Determine whether the range completely covers the bounding box of
0421       // this leaf node. If it is, we can copy all values without having to
0422       // check for them being inside the range again.
0423       bool contained = r >= m_range;
0424 
0425       if (m_type == NodeType::Internal) {
0426         // Firstly, we can check if the range completely contains the bounding
0427         // box of this node. If that is the case, we know for certain that any
0428         // value contained below this node should end up in the output, and we
0429         // can stop recursively looking for them.
0430         if (contained) {
0431           // We can also pre-allocate space for the number of elements, since we
0432           // are inserting all of them anyway.
0433           for (iterator_t i = m_begin_it; i != m_end_it; ++i) {
0434             f(i->first, i->second);
0435           }
0436 
0437           return;
0438         }
0439 
0440         assert(m_lhs && m_rhs && "Did not find lhs and rhs");
0441 
0442         // If we have a left-hand node (which we should!), then we check if
0443         // there is any overlap between the target range and the bounding box of
0444         // the left-hand node. If there is, we recursively search in that node.
0445         if (m_lhs->range() && r) {
0446           m_lhs->rangeSearchMapDiscard(r, std::forward<Callable>(f));
0447         }
0448 
0449         // Then, we perform exactly the same procedure for the right hand side.
0450         if (m_rhs->range() && r) {
0451           m_rhs->rangeSearchMapDiscard(r, std::forward<Callable>(f));
0452         }
0453       } else {
0454         // Iterate over all the elements in this leaf node. This should be a
0455         // relatively small number (the LeafSize template parameter).
0456         for (iterator_t i = m_begin_it; i != m_end_it; ++i) {
0457           // We need to check whether the element is actually inside the range.
0458           // In case this node's bounding box is fully contained within the
0459           // range, we don't actually need to check this.
0460           if (contained || r.contains(i->first)) {
0461             f(i->first, i->second);
0462           }
0463         }
0464       }
0465     }
0466 
0467     /// @brief Determine the number of elements managed by this node.
0468     ///
0469     /// Conveniently, this number is always equal to the distance between the
0470     /// begin iterator and the end iterator, so we can simply delegate to the
0471     /// relevant standard library method.
0472     ///
0473     /// @return The number of elements below this node.
0474     std::size_t size() const { return std::distance(m_begin_it, m_end_it); }
0475 
0476     /// @brief The axis-aligned bounding box containing all elements in this
0477     /// node.
0478     ///
0479     /// @return The minimal axis-aligned bounding box that contains all the
0480     /// elements under this node.
0481     const range_t &range() const { return m_range; }
0482 
0483    protected:
0484     NodeType m_type;
0485 
0486     /// @brief The start and end of the range of coordinate-value pairs under
0487     /// this node.
0488     const iterator_t m_begin_it, m_end_it;
0489 
0490     /// @brief The axis-aligned bounding box of the coordinates under this
0491     /// node.
0492     const range_t m_range;
0493 
0494     /// @brief Pointers to the left and right children.
0495     std::unique_ptr<KDTreeNode> m_lhs;
0496     std::unique_ptr<KDTreeNode> m_rhs;
0497   };
0498 
0499   /// @brief Vector containing all of the elements in this k-d tree, including
0500   /// the elements managed by the nodes inside of it.
0501   vector_t m_elems;
0502 
0503   /// @brief Pointer to the root node of this k-d tree.
0504   std::unique_ptr<KDTreeNode> m_root;
0505 };
0506 }  // namespace Acts