Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-05 08:12:15

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
#include <cmath>
#include <cstddef>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "TTreeReaderValue.h"
0018 
// Pairs of elements of the same type: shorthand for std::pair<T, T>, used
// below whenever the same kind of object is needed once per compared tree.
template <typename T>
using HomogeneousPair = std::pair<T, T>;
0022 
0023 // === TYPE ERASURE FOR CONCRETE DATA ===
0024 
// Minimal type-erasure wrapper for std::vector<T>. This will be used as a
// workaround to compensate for the absence of C++17's std::any in Cling.
//
// An AnyVector is move-only and owns the heap-allocated vector it wraps: that
// vector is deleted when the AnyVector is destroyed, or when the AnyVector is
// overwritten by a move-assignment.
class AnyVector {
 public:
  // Create a type-erased vector<T>, forwarding `args` to the std::vector<T>
  // constructor. Returns a pair containing the type-erased vector (the owner)
  // and a non-owning pointer to the underlying concrete vector. The pointer
  // stays valid for as long as the vector is owned by some AnyVector.
  template <typename T, typename... Args>
  static std::pair<AnyVector, std::vector<T>*> create(Args&&... args) {
    std::vector<T>* vector = new std::vector<T>(std::forward<Args>(args)...);
    // The deleter remembers the concrete type, which the void* member forgets
    std::function<void()> deleter = [vector] { delete vector; };
    return {AnyVector{static_cast<void*>(vector), std::move(deleter)}, vector};
  }

  // Default-construct a null type-erased vector
  AnyVector() = default;

  // Move-construct a type-erased vector, leaving `other` null
  AnyVector(AnyVector&& other)
      : m_vector{other.m_vector}, m_deleter{std::move(other.m_deleter)} {
    other.m_vector = nullptr;
  }

  // Move-assign a type-erased vector, leaving `other` null. Any vector that
  // was previously owned by *this is deleted first, otherwise it would leak.
  AnyVector& operator=(AnyVector&& other) {
    if (&other != this) {
      destroy();
      m_vector = other.m_vector;
      m_deleter = std::move(other.m_deleter);
      other.m_vector = nullptr;
    }
    return *this;
  }

  // Forbid copies of type-erased vectors
  AnyVector(const AnyVector&) = delete;
  AnyVector& operator=(const AnyVector&) = delete;

  // Delete a type-erased vector
  ~AnyVector() { destroy(); }

 private:
  // Construct a type-erased vector from a concrete vector
  AnyVector(void* vector, std::function<void()>&& deleter)
      : m_vector{vector}, m_deleter{std::move(deleter)} {}

  // Delete the owned vector, if any, and go back to the null state. The
  // deleter is only ever invoked while m_vector is non-null, so a moved-from
  // or default-constructed m_deleter is never called.
  void destroy() {
    if (m_vector != nullptr) {
      m_deleter();
      m_vector = nullptr;
    }
  }

  void* m_vector{nullptr};          // Casted std::vector<T>*
  std::function<void()> m_deleter;  // Deletes the underlying vector
};
0077 
0078 // === GENERIC DATA ORDERING ===
0079 
// We want to check, in a single operation, how two pieces of data are ordered.
// This is the three-way result returned by every compare() function below:
// the first operand is SMALLER than, EQUAL to, or GREATER than the second.
enum class Ordering { SMALLER, EQUAL, GREATER };
0082 
0083 // In general, any type which implements comparison operators that behave as a
0084 // mathematical total order can use this comparison function...
0085 template <typename T>
0086 Ordering compare(const T& x, const T& y) {
0087   if (x < y) {
0088     return Ordering::SMALLER;
0089   } else if (x == y) {
0090     return Ordering::EQUAL;
0091   } else {
0092     return Ordering::GREATER;
0093   }
0094 }
0095 
0096 // ...but we'll want to tweak that a little for floats, to handle NaNs better...
0097 template <typename T>
0098 Ordering compareFloat(const T& x, const T& y) {
0099   if (std::isless(x, y)) {
0100     return Ordering::SMALLER;
0101   } else if (std::isgreater(x, y)) {
0102     return Ordering::GREATER;
0103   } else {
0104     return Ordering::EQUAL;
0105   }
0106 }
0107 
0108 template <>
0109 Ordering compare(const float& x, const float& y) {
0110   return compareFloat(x, y);
0111 }
0112 
0113 template <>
0114 Ordering compare(const double& x, const double& y) {
0115   return compareFloat(x, y);
0116 }
0117 
0118 // ...and for vectors, where the default lexicographic comparison cannot
0119 // efficiently tell all of what we want in a single vector iteration pass.
0120 template <typename U>
0121 Ordering compare(const std::vector<U>& v1, const std::vector<U>& v2) {
0122   // First try to order by size...
0123   if (v1.size() < v2.size()) {
0124     return Ordering::SMALLER;
0125   } else if (v1.size() > v2.size()) {
0126     return Ordering::GREATER;
0127   }
0128   // ...if the size is identical...
0129   else {
0130     // ...then try to order by contents of increasing index...
0131     for (std::size_t i = 0; i < v1.size(); ++i) {
0132       if (v1[i] < v2[i]) {
0133         return Ordering::SMALLER;
0134       } else if (v1[i] > v2[i]) {
0135         return Ordering::GREATER;
0136       }
0137     }
0138 
0139     // ...and declare the vectors equal if the contents are equal
0140     return Ordering::EQUAL;
0141   }
0142 }
0143 
// Exchange the elements at indices i and j of a vector.
//
// std::swap does not work with std::vector<bool> because its operator[] does
// not return lvalue references, so booleans are exchanged by value instead.
template <typename U>
void swap(std::vector<U>& vec, std::size_t i, std::size_t j) {
  if constexpr (!std::is_same_v<U, bool>) {
    std::swap(vec[i], vec[j]);
  } else {
    // Read both bits out before writing either of them back
    const bool formerlyAtI = vec[i];
    const bool formerlyAtJ = vec[j];
    vec[i] = formerlyAtJ;
    vec[j] = formerlyAtI;
  }
}
0156 
0157 // === GENERIC SORTING MECHANISM ===
0158 
// The following functions are generic implementations of sorting algorithms,
// which require only a comparison operator, a swapping operator, and an
// inclusive range of indices to be sorted in order to operate.
//
// IndexComparator tells how the elements stored at two indices are ordered
// with respect to one another, and IndexSwapper exchanges the elements stored
// at two indices. Neither functor exposes the data being sorted.
using IndexComparator = std::function<Ordering(std::size_t, std::size_t)>;
using IndexSwapper = std::function<void(std::size_t, std::size_t)>;
0164 
0165 // Selection sort has pertty bad asymptotic scaling, but it is non-recursive
0166 // and in-place, which makes it a good choice for smaller inputs
0167 void selectionSort(const std::size_t firstIndex, const std::size_t lastIndex,
0168                    const IndexComparator& compare, const IndexSwapper& swap) {
0169   for (std::size_t targetIndex = firstIndex; targetIndex < lastIndex;
0170        ++targetIndex) {
0171     std::size_t minIndex = targetIndex;
0172     for (std::size_t readIndex = targetIndex + 1; readIndex <= lastIndex;
0173          ++readIndex) {
0174       if (compare(readIndex, minIndex) == Ordering::SMALLER) {
0175         minIndex = readIndex;
0176       }
0177     }
0178     if (minIndex != targetIndex) {
0179       swap(minIndex, targetIndex);
0180     }
0181   }
0182 }
0183 
0184 // Quick sort is used as the top-level sorting algorithm for our datasets
0185 void quickSort(const std::size_t firstIndex, const std::size_t lastIndex,
0186                const IndexComparator& compare, const IndexSwapper& swap) {
0187   // We switch to non-recursive selection sort when the range becomes too small.
0188   // This optimization voids the need for detection of 0- and 1-element input.
0189   static const std::size_t NON_RECURSIVE_THRESHOLD = 25;
0190   if (lastIndex - firstIndex < NON_RECURSIVE_THRESHOLD) {
0191     selectionSort(firstIndex, lastIndex, compare, swap);
0192     return;
0193   }
0194 
0195   // We'll use the midpoint as a pivot. Later on, we can switch to more
0196   // elaborate pivot selection schemes if their usefulness for our use case
0197   // (pseudorandom events with thread-originated reordering) is demonstrated.
0198   std::size_t pivotIndex = firstIndex + (lastIndex - firstIndex) / 2;
0199 
0200   // Partition the data around the pivot using Hoare's scheme
0201   std::size_t splitIndex = 0;
0202   {
0203     // Start with two indices one step beyond each side of the array
0204     std::size_t i = firstIndex - 1;
0205     std::size_t j = lastIndex + 1;
0206     while (true) {
0207       // Move left index forward at least once, and until an element which is
0208       // greater than or equal to the pivot is detected.
0209       do {
0210         i = i + 1;
0211       } while (compare(i, pivotIndex) == Ordering::SMALLER);
0212 
0213       // Move right index backward at least once, and until an element which is
0214       // smaller than or equal to the pivot is detected
0215       do {
0216         j = j - 1;
0217       } while (compare(j, pivotIndex) == Ordering::GREATER);
0218 
0219       // By transitivity of inequality, the element at location i is greater
0220       // than or equal to the one at location j, and a swap could be required
0221       if (i < j) {
0222         // These elements are in the wrong order, swap them
0223         swap(i, j);
0224 
0225         // Don't forget to keep track the pivot's index along the way, as this
0226         // is currently the only way by which we can refer to the pivot element.
0227         if (i == pivotIndex) {
0228           pivotIndex = j;
0229         } else if (j == pivotIndex) {
0230           pivotIndex = i;
0231         }
0232       } else {
0233         // If i and j went past each other, our partitioning is done
0234         splitIndex = j;
0235         break;
0236       }
0237     }
0238   }
0239 
0240   // Now, we'll recursively sort both partitions using quicksort. We should
0241   // recurse in the smaller range first, so as to leverage compiler tail call
0242   // optimization if available.
0243   if (splitIndex - firstIndex <= lastIndex - splitIndex - 1) {
0244     quickSort(firstIndex, splitIndex, compare, swap);
0245     quickSort(splitIndex + 1, lastIndex, compare, swap);
0246   } else {
0247     quickSort(splitIndex + 1, lastIndex, compare, swap);
0248     quickSort(firstIndex, splitIndex, compare, swap);
0249   }
0250 }
0251 
0252 // === GENERIC TTREE BRANCH MANIPULATION MECHANISM ===
0253 
0254 // When comparing a pair of TTrees, we'll need to set up quite a few facilities
0255 // for each branch. Since this setup is dependent on the branch data type, which
0256 // is only known at runtime, it is quite involved, which is why we extracted it
0257 // to a separate struct and its constructor.
0258 struct BranchComparisonHarness {
0259   // We'll keep track of the branch name for debugging purposes
0260   std::string branchName;
0261 
0262   // Type-erased event data for the current branch, in both trees being compared
0263   HomogeneousPair<AnyVector> eventData;
0264 
0265   // Function which loads the active event data for the current branch. This is
0266   // to be performed for each branch and combined with TTreeReader-based event
0267   // iteration on both trees.
0268   void loadCurrentEvent() { (*m_eventLoaderPtr)(); }
0269 
0270   // Functors which compare two events within a given tree and order them
0271   // with respect to one another, and which swap two events. By combining such
0272   // functionality for each branch, a global tree order can be produced.
0273   HomogeneousPair<std::pair<IndexComparator, IndexSwapper>> sortHarness;
0274 
0275   // Functor which compares the current event data in *both* trees and tells
0276   // whether it is identical. The comparison is order-sensitive, so events
0277   // should previously have been sorted in a canonical order in both trees.
0278   // By combining the results for each branch, global tree equality is defined.
0279   using TreeComparator = std::function<bool()>;
0280   TreeComparator eventDataEqual;
0281 
0282   // Functor which dumps the event data for the active event side by side, in
0283   // two columns. This enables manual comparison during debugging.
0284   std::function<void()> dumpEventData;
0285 
0286   // General metadata about the tree which is identical for every branch
0287   struct TreeMetadata {
0288     TTreeReader& tree1Reader;
0289     TTreeReader& tree2Reader;
0290     const std::size_t entryCount;
0291   };
0292 
0293   // This exception will be thrown if an unsupported branch type is encountered
0294   class UnsupportedBranchType : public std::exception {};
0295 
0296   // Type-erased factory of branch comparison harnesses, taking ROOT run-time
0297   // type information as input in order to select an appropriate C++ constructor
0298   static BranchComparisonHarness create(TreeMetadata& treeMetadata,
0299                                         const std::string& branchName,
0300                                         const EDataType dataType,
0301                                         const std::string& className) {
0302     switch (dataType) {
0303       case kChar_t:
0304         return BranchComparisonHarness::create<char>(treeMetadata, branchName);
0305       case kUChar_t:
0306         return BranchComparisonHarness::create<unsigned char>(treeMetadata,
0307                                                               branchName);
0308       case kShort_t:
0309         return BranchComparisonHarness::create<short>(treeMetadata, branchName);
0310       case kUShort_t:
0311         return BranchComparisonHarness::create<unsigned short>(treeMetadata,
0312                                                                branchName);
0313       case kInt_t:
0314         return BranchComparisonHarness::create<int>(treeMetadata, branchName);
0315       case kUInt_t:
0316         return BranchComparisonHarness::create<unsigned int>(treeMetadata,
0317                                                              branchName);
0318       case kLong_t:
0319         return BranchComparisonHarness::create<long>(treeMetadata, branchName);
0320       case kULong_t:
0321         return BranchComparisonHarness::create<unsigned long>(treeMetadata,
0322                                                               branchName);
0323       case kULong64_t:
0324         return BranchComparisonHarness::create<unsigned long long>(treeMetadata,
0325                                                                    branchName);
0326 
0327       case kFloat_t:
0328         return BranchComparisonHarness::create<float>(treeMetadata, branchName);
0329       case kDouble_t:
0330         return BranchComparisonHarness::create<double>(treeMetadata,
0331                                                        branchName);
0332       case kBool_t:
0333         return BranchComparisonHarness::create<bool>(treeMetadata, branchName);
0334       case kOther_t:
0335         if (className.substr(0, 6) == "vector") {
0336           std::string elementType = className.substr(7, className.size() - 8);
0337           return BranchComparisonHarness::createVector(treeMetadata, branchName,
0338                                                        elementType);
0339         } else {
0340           throw UnsupportedBranchType();
0341         }
0342       default:
0343         throw UnsupportedBranchType();
0344     }
0345   }
0346 
0347  private:
0348   // Under the hood, the top-level factory calls the following function
0349   // template, parametrized with the proper C++ data type
0350   template <typename T>
0351   static BranchComparisonHarness create(TreeMetadata& treeMetadata,
0352                                         const std::string& branchName) {
0353     // Our result will eventually go there
0354     BranchComparisonHarness result;
0355 
0356     // Save the branch name for debugging purposes
0357     result.branchName = branchName;
0358 
0359     // Setup type-erased event data storage
0360     auto tree1DataStorage = AnyVector::create<T>();
0361     auto tree2DataStorage = AnyVector::create<T>();
0362     result.eventData = std::make_pair(std::move(tree1DataStorage.first),
0363                                       std::move(tree2DataStorage.first));
0364     std::vector<T>& tree1Data = *tree1DataStorage.second;
0365     std::vector<T>& tree2Data = *tree2DataStorage.second;
0366 
0367     // Use our advance knowledge of the event count to preallocate storage
0368     tree1Data.reserve(treeMetadata.entryCount);
0369     tree2Data.reserve(treeMetadata.entryCount);
0370 
0371     // Setup event data readout
0372     result.m_eventLoaderPtr.reset(
0373         new EventLoaderT<T>{treeMetadata.tree1Reader, treeMetadata.tree2Reader,
0374                             branchName, tree1Data, tree2Data});
0375 
0376     // Setup event comparison and swapping for each tree
0377     result.sortHarness = std::make_pair(
0378         std::make_pair(
0379             [&tree1Data](std::size_t i, std::size_t j) -> Ordering {
0380               return compare(tree1Data[i], tree1Data[j]);
0381             },
0382             [&tree1Data](std::size_t i, std::size_t j) {
0383               swap(tree1Data, i, j);
0384             }),
0385         std::make_pair(
0386             [&tree2Data](std::size_t i, std::size_t j) -> Ordering {
0387               return compare(tree2Data[i], tree2Data[j]);
0388             },
0389             [&tree2Data](std::size_t i, std::size_t j) {
0390               swap(tree2Data, i, j);
0391             }));
0392 
0393     // Setup order-sensitive tree comparison
0394     result.eventDataEqual = [&tree1Data, &tree2Data]() -> bool {
0395       for (std::size_t i = 0; i < tree1Data.size(); ++i) {
0396         if (compare(tree1Data[i], tree2Data[i]) != Ordering::EQUAL) {
0397           return false;
0398         }
0399       }
0400       return true;
0401     };
0402 
0403     // Add a debugging method to dump event data
0404     result.dumpEventData = [&tree1Data, &tree2Data] {
0405       std::cout << "File 1                \tFile 2" << std::endl;
0406       for (std::size_t i = 0; i < tree1Data.size(); ++i) {
0407         std::cout << toString(tree1Data[i]) << "      \t"
0408                   << toString(tree2Data[i]) << std::endl;
0409       }
0410     };
0411 
0412     // ...and we're good to go!
0413     return result;
0414   }
0415 
0416   // Because the people who created TTreeReaderValue could not bother to make it
0417   // movable (for moving it into a lambda), or even just virtually destructible
0418   // (for moving a unique_ptr into the lambda), loadEventData can only be
0419   // implemented through lots of unpleasant C++98-ish boilerplate.
0420   class IEventLoader {
0421    public:
0422     virtual ~IEventLoader() = default;
0423     virtual void operator()() = 0;
0424   };
0425 
0426   template <typename T>
0427   class EventLoaderT : public IEventLoader {
0428    public:
0429     EventLoaderT(TTreeReader& tree1Reader, TTreeReader& tree2Reader,
0430                  const std::string& branchName, std::vector<T>& tree1Data,
0431                  std::vector<T>& tree2Data)
0432         : branch1Reader{tree1Reader, branchName.c_str()},
0433           branch2Reader{tree2Reader, branchName.c_str()},
0434           branch1Data(tree1Data),
0435           branch2Data(tree2Data) {}
0436 
0437     void operator()() override {
0438       T* data1 = branch1Reader.Get();
0439       T* data2 = branch2Reader.Get();
0440       if (data1 == nullptr || data2 == nullptr) {
0441         throw std::runtime_error{"Corrupt data"};
0442       }
0443       branch1Data.push_back(*data1);
0444       branch2Data.push_back(*data2);
0445     }
0446 
0447    private:
0448     TTreeReaderValue<T> branch1Reader, branch2Reader;
0449     std::vector<T>& branch1Data;
0450     std::vector<T>& branch2Data;
0451   };
0452 
0453   std::unique_ptr<IEventLoader> m_eventLoaderPtr;
0454 
0455 #define CREATE_VECTOR__HANDLE_TYPE(type_name)                       \
0456   if (elemType == #type_name) {                                     \
0457     return BranchComparisonHarness::create<std::vector<type_name>>( \
0458         treeMetadata, branchName);                                  \
0459   }
0460 
0461 // For integer types, we'll want to handle both signed and unsigned versions
0462 #define CREATE_VECTOR__HANDLE_INTEGER_TYPE(integer_type_name) \
0463   CREATE_VECTOR__HANDLE_TYPE(integer_type_name)               \
0464   else CREATE_VECTOR__HANDLE_TYPE(unsigned integer_type_name)
0465 
0466 #define CREATE_VECTOR__HANDLE_INTEGER_TYPE_ROOT(integer_type_name) \
0467   CREATE_VECTOR__HANDLE_TYPE(integer_type_name)                    \
0468   else CREATE_VECTOR__HANDLE_TYPE(U##integer_type_name)
0469 
0470   // This helper factory helps building branches associated with std::vectors
0471   // of data, which are the only STL collection that we support at the moment.
0472   static BranchComparisonHarness createVector(TreeMetadata& treeMetadata,
0473                                               const std::string& branchName,
0474                                               const std::string& elemType) {
0475     // We support vectors of different types by switching across type (strings)
0476 
0477     // clang-format off
0478 
0479     // Handle vectors of booleans
0480     CREATE_VECTOR__HANDLE_TYPE(bool)
0481 
0482     // Handle vectors of all standard floating-point types
0483     else CREATE_VECTOR__HANDLE_TYPE(float)
0484     else CREATE_VECTOR__HANDLE_TYPE(double)
0485 
0486     // Handle vectors of all standard integer types
0487     else CREATE_VECTOR__HANDLE_INTEGER_TYPE(char)
0488     else CREATE_VECTOR__HANDLE_INTEGER_TYPE(short)
0489     else CREATE_VECTOR__HANDLE_INTEGER_TYPE(int)
0490     else CREATE_VECTOR__HANDLE_INTEGER_TYPE(long)
0491     else CREATE_VECTOR__HANDLE_INTEGER_TYPE(long long)
0492     else CREATE_VECTOR__HANDLE_INTEGER_TYPE_ROOT(Char_t)
0493     else CREATE_VECTOR__HANDLE_INTEGER_TYPE_ROOT(Short_t)
0494     else CREATE_VECTOR__HANDLE_INTEGER_TYPE_ROOT(Int_t)
0495     else CREATE_VECTOR__HANDLE_INTEGER_TYPE_ROOT(Long_t)
0496     else CREATE_VECTOR__HANDLE_INTEGER_TYPE_ROOT(Long64_t)
0497 
0498     // Throw an exception if the vector element type is not recognized
0499     else {
0500       std::cerr << "Unsupported vector element type: " << elemType << std::endl;
0501       throw UnsupportedBranchType();
0502     }
0503 
0504     // clang-format on
0505   }
0506 
0507 #undef CREATE_VECTOR__HANDLE_TYPE
0508 #undef CREATE_VECTOR__HANDLE_INTEGER_TYPE
0509 #undef CREATE_VECTOR__HANDLE_INTEGER_TYPE_ROOT
0510 
0511   // This helper method provides general string conversion for all supported
0512   // branch event data types.
0513   template <typename T>
0514   static std::string toString(const T& data) {
0515     std::ostringstream oss;
0516     oss << data;
0517     return oss.str();
0518   }
0519 
0520   template <typename U>
0521   static std::string toString(const std::vector<U>& vector) {
0522     std::ostringstream oss{"{ "};
0523     for (const auto& data : vector) {
0524       oss << data << "  \t";
0525     }
0526     oss << " }";
0527     return oss.str();
0528   }
0529 };