Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-09-28 07:03:47

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2022 Chao Peng, Sylvester Joosten, Whitney Armstrong
0003 
0004 /*
0005  *  A hits-level data combiner to combine two datasets into one for machine learning
0006  *
0007  *  Author: Chao Peng (ANL), 05/04/2021
0008  */
0009 #include <algorithm>
0010 #include <bitset>
0011 #include <fmt/format.h>
0012 #include <unordered_map>
0013 
0014 #include "Gaudi/Property.h"
0015 #include "GaudiAlg/GaudiAlgorithm.h"
0016 #include "GaudiAlg/GaudiTool.h"
0017 #include "GaudiAlg/Transformer.h"
0018 #include "GaudiKernel/PhysicalConstants.h"
0019 #include "GaudiKernel/RndmGenerators.h"
0020 #include "GaudiKernel/ToolHandle.h"
0021 
0022 #include "DDRec/CellIDPositionConverter.h"
0023 #include "DDRec/Surface.h"
0024 #include "DDRec/SurfaceManager.h"
0025 
0026 #include <k4FWCore/DataHandle.h>
0027 
0028 // Event Model related classes
0029 #include "edm4eic/CalorimeterHitCollection.h"
0030 #include "edm4hep/utils/vector_utils.h"
0031 
0032 using namespace Gaudi::Units;
0033 
0034 namespace Jug::Reco {
0035 
0036 /** Hits combiner for ML algorithm input.
0037  *
0038  * A hits-level data combiner that merges two input hit collections (inputHits1, inputHits2) into one output collection (outputHits) for machine learning
0039  * It accepts inputs from data sorter, i.e. the hits are expected to be sorted by layers (the "interlayer" rule relies on this ordering)
0040  * The two datasets are combined following the configured "rule"; every output hit is a copy whose layer is shifted by layerIncrement * (input index, 0 or 1)
0041  * Supported rules: "concatenate" (all hits of input 1 followed by all hits of input 2), "interlayer" (alternate between the inputs, one run of equal-layer hits at a time)
0042  *
0043  * \ingroup reco
0044  */
0045 class ImagingPixelDataCombiner : public GaudiAlgorithm {
0046 private:
0047   Gaudi::Property<int> m_layerIncrement{this, "layerIncrement", 0}; // layer offset: multiplied by the input index (0 or 1) and added to each copied hit's layer
0048   Gaudi::Property<std::string> m_rule{this, "rule", "concatenate"}; // combination rule; must be one of supported_rules (checked in initialize())
0049   DataHandle<edm4eic::CalorimeterHitCollection> m_inputHits1{"inputHits1", Gaudi::DataHandle::Reader, this}; // first input (input index 0)
0050   DataHandle<edm4eic::CalorimeterHitCollection> m_inputHits2{"inputHits2", Gaudi::DataHandle::Reader, this}; // second input (input index 1)
0051   DataHandle<edm4eic::CalorimeterHitCollection> m_outputHits{"outputHits", Gaudi::DataHandle::Writer, this}; // combined output
0052   std::vector<std::string> supported_rules{"concatenate", "interlayer"}; // order matters: execute() selects its branch via supported_rules[0] / supported_rules[1]
0053 
0054 public:
0055   ImagingPixelDataCombiner(const std::string& name, ISvcLocator* svcLoc) // declares the three data handles as properties of the algorithm
0056       : GaudiAlgorithm(name, svcLoc) {
0057     declareProperty("inputHits1", m_inputHits1, "");
0058     declareProperty("inputHits2", m_inputHits2, "");
0059     declareProperty("outputHits", m_outputHits, "");
0060   }
0061 
0062   StatusCode initialize() override { // initializes the base class, then rejects a "rule" value that is not listed in supported_rules
0063     if (GaudiAlgorithm::initialize().isFailure()) {
0064       return StatusCode::FAILURE;
0065     }
0066 
0067     if (std::find(supported_rules.begin(), supported_rules.end(), m_rule.value()) == supported_rules.end()) { // configured rule is unknown
0068       error() << fmt::format("unsupported rule: {}, please choose one from [{}]", m_rule.value(),
0069                              fmt::join(supported_rules, ", "))
0070               << endmsg;
0071       return StatusCode::FAILURE;
0072     }
0073 
0074     return StatusCode::SUCCESS;
0075   }
0076 
0077   StatusCode execute() override { // copies the hits of both inputs into the output collection according to m_rule; always returns SUCCESS
0078     // input collections (the returned pointers are dereferenced below without a null check)
0079     const auto* const hits1 = m_inputHits1.get();
0080     const auto* const hits2 = m_inputHits2.get();
0081     std::vector<const edm4eic::CalorimeterHitCollection*> inputs{hits1, hits2}; // position in this vector is the "input index" used for the layer shift
0082     // Create output collections
0083     auto* mhits = m_outputHits.createAndPut();
0084 
0085     // concatenate: append every hit of input 0, then every hit of input 1
0086     if (m_rule.value() == supported_rules[0]) {
0087       for (int i = 0; i < (int)inputs.size(); ++i) { // i is the input index
0088         const auto* const coll = inputs[i];
0089         for (auto hit : *coll) {
0090           edm4eic::CalorimeterHit h2{ // new hit built from the source hit's getters; only the layer differs (shifted by m_layerIncrement * i)
0091               hit.getCellID(),    hit.getEnergy(),   hit.getEnergyError(), hit.getTime(),
0092               hit.getTimeError(), hit.getPosition(), hit.getDimension(),   hit.getLayer() + m_layerIncrement * i,
0093               hit.getSector(),    hit.getLocal(),
0094           };
0095           mhits->push_back(h2);
0096         }
0097       }
0098       // interlayer: alternate between the inputs, copying one run of consecutive equal-layer hits per turn
0099       // @NOTE: it assumes the input hits are sorted by layers
0100     } else if (m_rule.value() == supported_rules[1]) {
0101       std::vector<int> indices{0, 0}; // per-input read cursor: index of the next hit not yet copied
0102       int curr_coll   = 0; // index of the input currently being read; may temporarily equal inputs.size() until wrapped below
0103       bool init_layer = false; // false until curr_layer has been taken from the first hit of the current run
0104       int curr_layer  = 0; // layer value of the run currently being copied
0105       // int curr_ihit = 0;
0106       while (indices[0] < (int)hits1->size() || indices[1] < (int)hits2->size()) { // loop until both inputs are fully consumed
0107         // cyclic index: wrap curr_coll back into [0, inputs.size())
0108         if (curr_coll >= (int)inputs.size()) {
0109           curr_coll -= (int)inputs.size();
0110         }
0111 
0112         // merge hits
0113         int& i                 = indices[curr_coll]; // reference, so the i++ at the end of the loop body advances the stored cursor
0114         const auto* const coll = inputs[curr_coll];
0115 
0116         // reach this collection's end: switch to the other input and start a new run there
0117         if (i >= (int)coll->size()) {
0118           curr_coll++;
0119           init_layer = false;
0120           // curr_ihit = 0;
0121           // info() << "collection end" << endmsg;
0122           continue;
0123         }
0124 
0125         auto hit = (*coll)[i];
0126         if (!init_layer) { // first hit of a new run defines the layer of that run
0127           curr_layer = hit.getLayer();
0128           init_layer = true;
0129         }
0130 
0131         // reach this layer's end: switch input without consuming this hit (the cursor is not advanced)
0132         if (curr_layer != hit.getLayer()) {
0133           curr_coll++;
0134           init_layer = false;
0135           // curr_ihit = 0;
0136           // info() << "layer end : " << curr_layer << " != " << hit.getLayer() << endmsg;
0137           continue;
0138         }
0139 
0140         // push hit, increment of index (layer shifted by m_layerIncrement * curr_coll; curr_coll is already wrapped here)
0141         edm4eic::CalorimeterHit h2{
0142             hit.getCellID(),    hit.getEnergy(),   hit.getEnergyError(), hit.getTime(),
0143             hit.getTimeError(), hit.getPosition(), hit.getDimension(),   hit.getLayer() + m_layerIncrement * curr_coll,
0144             hit.getSector(),    hit.getLocal()};
0145         mhits->push_back(h2);
0146         i++;
0147         // info() << curr_coll << ": " << curr_ihit ++ << endmsg;
0148       }
0149       // info() << hits1->size() << ", " << hits2->size() << endmsg;
0150     }
0151 
0152     return StatusCode::SUCCESS;
0153   }
0154 
0155 }; // class ImagingPixelDataCombiner
0156 
0157 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
0158 DECLARE_COMPONENT(ImagingPixelDataCombiner) // presumably registers the algorithm with Gaudi's component factory so it can be instantiated by name -- confirm against the Gaudi macro definition
0159 
0160 } // namespace Jug::Reco