Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-09-28 07:02:24

0001 // SPDX-License-Identifier: LGPL-3.0-or-later
0002 // Copyright (C) 2022 Chao Peng, Sylvester Joosten, Whitney Armstrong
0003 
0004 /*
0005  *  A hits-level data combiner to combine two datasets into one for machine learning
0006  *
0007  *  Author: Chao Peng (ANL), 05/04/2021
0008  */
0009 #include <algorithm>
0010 #include <bitset>
0011 #include <fmt/format.h>
0012 #include <unordered_map>
0013 
0014 #include "Gaudi/Property.h"
0015 #include "GaudiAlg/GaudiAlgorithm.h"
0016 #include "GaudiAlg/GaudiTool.h"
0017 #include "GaudiAlg/Transformer.h"
0018 #include "GaudiKernel/PhysicalConstants.h"
0019 #include "GaudiKernel/RndmGenerators.h"
0020 #include "GaudiKernel/ToolHandle.h"
0021 
0022 #include "DDRec/CellIDPositionConverter.h"
0023 #include "DDRec/Surface.h"
0024 #include "DDRec/SurfaceManager.h"
0025 
0026 #include "JugBase/DataHandle.h"
0027 #include "JugBase/IGeoSvc.h"
0028 #include "JugBase/Utilities/Utils.hpp"
0029 
0030 // Event Model related classes
0031 #include "edm4eic/CalorimeterHitCollection.h"
0032 #include "edm4hep/utils/vector_utils.h"
0033 
0034 using namespace Gaudi::Units;
0035 
0036 namespace Jug::Reco {
0037 
0038 /** Hits combiner for ML algorithm input.
0039  *
0040  * A hits-level data combiner to combine two datasets into one for machine learning
0041  * It accepts inputs from data sorter that hits are sorted by layers
0042  * Two different datasets will be combined together following specified rules in handling the layers
0043  * Supported rules: concatenate, interlayer
0044  *
0045  * \ingroup reco
0046  */
0047 class ImagingPixelDataCombiner : public GaudiAlgorithm {
0048 private:
0049   Gaudi::Property<int> m_layerIncrement{this, "layerIncrement", 0};
0050   Gaudi::Property<std::string> m_rule{this, "rule", "concatenate"};
0051   DataHandle<edm4eic::CalorimeterHitCollection> m_inputHits1{"inputHits1", Gaudi::DataHandle::Reader, this};
0052   DataHandle<edm4eic::CalorimeterHitCollection> m_inputHits2{"inputHits2", Gaudi::DataHandle::Reader, this};
0053   DataHandle<edm4eic::CalorimeterHitCollection> m_outputHits{"outputHits", Gaudi::DataHandle::Writer, this};
0054   std::vector<std::string> supported_rules{"concatenate", "interlayer"};
0055 
0056 public:
0057   ImagingPixelDataCombiner(const std::string& name, ISvcLocator* svcLoc)
0058       : GaudiAlgorithm(name, svcLoc) {
0059     declareProperty("inputHits1", m_inputHits1, "");
0060     declareProperty("inputHits2", m_inputHits2, "");
0061     declareProperty("outputHits", m_outputHits, "");
0062   }
0063 
0064   StatusCode initialize() override {
0065     if (GaudiAlgorithm::initialize().isFailure()) {
0066       return StatusCode::FAILURE;
0067     }
0068 
0069     if (std::find(supported_rules.begin(), supported_rules.end(), m_rule.value()) == supported_rules.end()) {
0070       error() << fmt::format("unsupported rule: {}, please choose one from [{}]", m_rule.value(),
0071                              fmt::join(supported_rules, ", "))
0072               << endmsg;
0073       return StatusCode::FAILURE;
0074     }
0075 
0076     return StatusCode::SUCCESS;
0077   }
0078 
0079   StatusCode execute() override {
0080     // input collections
0081     const auto* const hits1 = m_inputHits1.get();
0082     const auto* const hits2 = m_inputHits2.get();
0083     std::vector<const edm4eic::CalorimeterHitCollection*> inputs{hits1, hits2};
0084     // Create output collections
0085     auto* mhits = m_outputHits.createAndPut();
0086 
0087     // concatenate
0088     if (m_rule.value() == supported_rules[0]) {
0089       for (int i = 0; i < (int)inputs.size(); ++i) {
0090         const auto* const coll = inputs[i];
0091         for (auto hit : *coll) {
0092           edm4eic::CalorimeterHit h2{
0093               hit.getCellID(),    hit.getEnergy(),   hit.getEnergyError(), hit.getTime(),
0094               hit.getTimeError(), hit.getPosition(), hit.getDimension(),   hit.getLayer() + m_layerIncrement * i,
0095               hit.getSector(),    hit.getLocal(),
0096           };
0097           mhits->push_back(h2);
0098         }
0099       }
0100       // interlayer
0101       // @NOTE: it assumes the input hits are sorted by layers
0102     } else if (m_rule.value() == supported_rules[1]) {
0103       std::vector<int> indices{0, 0};
0104       int curr_coll   = 0;
0105       bool init_layer = false;
0106       int curr_layer  = 0;
0107       // int curr_ihit = 0;
0108       while (indices[0] < (int)hits1->size() || indices[1] < (int)hits2->size()) {
0109         // cyclic index
0110         if (curr_coll >= (int)inputs.size()) {
0111           curr_coll -= (int)inputs.size();
0112         }
0113 
0114         // merge hits
0115         int& i                 = indices[curr_coll];
0116         const auto* const coll = inputs[curr_coll];
0117 
0118         // reach this collection's end
0119         if (i >= (int)coll->size()) {
0120           curr_coll++;
0121           init_layer = false;
0122           // curr_ihit = 0;
0123           // info() << "collection end" << endmsg;
0124           continue;
0125         }
0126 
0127         auto hit = (*coll)[i];
0128         if (!init_layer) {
0129           curr_layer = hit.getLayer();
0130           init_layer = true;
0131         }
0132 
0133         // reach this layer's end
0134         if (curr_layer != hit.getLayer()) {
0135           curr_coll++;
0136           init_layer = false;
0137           // curr_ihit = 0;
0138           // info() << "layer end : " << curr_layer << " != " << hit.getLayer() << endmsg;
0139           continue;
0140         }
0141 
0142         // push hit, increment of index
0143         edm4eic::CalorimeterHit h2{
0144             hit.getCellID(),    hit.getEnergy(),   hit.getEnergyError(), hit.getTime(),
0145             hit.getTimeError(), hit.getPosition(), hit.getDimension(),   hit.getLayer() + m_layerIncrement * curr_coll,
0146             hit.getSector(),    hit.getLocal()};
0147         mhits->push_back(h2);
0148         i++;
0149         // info() << curr_coll << ": " << curr_ihit ++ << endmsg;
0150       }
0151       // info() << hits1->size() << ", " << hits2->size() << endmsg;
0152     }
0153 
0154     return StatusCode::SUCCESS;
0155   }
0156 
0157 }; // class ImagingPixelDataCombiner
0158 
0159 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
0160 DECLARE_COMPONENT(ImagingPixelDataCombiner)
0161 
0162 } // namespace Jug::Reco