Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-03 08:57:24

0001 // Copyright 2024, Jefferson Science Associates, LLC.
0002 // Subject to the terms in the LICENSE file found in the top-level directory.
0003 // Created by Nathan Brei
0004 
0005 #pragma once
0006 #include "JANA/Components/JComponentSummary.h"
0007 #if JANA2_HAVE_PODIO
0008 #include "JANA/Podio/JFactoryPodioT.h"
0009 #endif
0010 #include <JANA/JEvent.h>
0011 
0012 
0013 namespace jana::components {
0014 
0015 struct JHasInputs {
0016 protected:
0017 
0018     class InputBase;
0019     class VariadicInputBase;
0020 
0021     std::vector<InputBase*> m_inputs;
0022     std::vector<VariadicInputBase*> m_variadic_inputs;
0023 
0024     void RegisterInput(InputBase* input) {
0025         m_inputs.push_back(input);
0026     }
0027 
0028     void RegisterInput(VariadicInputBase* input) {
0029         m_variadic_inputs.push_back(input);
0030     }
0031 
0032     struct InputOptions {
0033         std::string name {""};
0034         JEventLevel level {JEventLevel::None};
0035         bool is_optional {false};
0036     };
0037 
0038     struct VariadicInputOptions {
0039         std::vector<std::string> names {""};
0040         JEventLevel level {JEventLevel::None};
0041         bool is_optional {false};
0042     };
0043 
0044     class InputBase {
0045     protected:
0046         std::string m_type_name;
0047         std::string m_databundle_name;
0048         JEventLevel m_level = JEventLevel::None;
0049         bool m_is_optional = false;
0050 
0051     public:
0052 
0053         void SetOptional(bool isOptional) {
0054             m_is_optional = isOptional;
0055         }
0056 
0057         void SetLevel(JEventLevel level) {
0058             m_level = level;
0059         }
0060 
0061         void SetDatabundleName(std::string name) {
0062             m_databundle_name = name;
0063         }
0064 
0065         const std::string& GetTypeName() const {
0066             return m_type_name;
0067         }
0068 
0069         const std::string& GetDatabundleName() const {
0070             return m_databundle_name;
0071         }
0072 
0073         JEventLevel GetLevel() const {
0074             return m_level;
0075         }
0076 
0077         void Configure(const InputOptions& options) {
0078             m_databundle_name = options.name;
0079             m_level = options.level;
0080             m_is_optional = options.is_optional;
0081         }
0082 
0083         virtual void GetCollection(const JEvent& event) = 0;
0084         virtual void PrefetchCollection(const JEvent& event) = 0;
0085     };
0086 
0087     class VariadicInputBase {
0088     public:
0089         enum class EmptyInputPolicy { IncludeNothing, IncludeEverything };
0090 
0091     protected:
0092         std::string m_type_name;
0093         std::vector<std::string> m_requested_databundle_names;
0094         std::vector<std::string> m_realized_databundle_names;
0095         JEventLevel m_level = JEventLevel::None;
0096         bool m_is_optional = false;
0097         EmptyInputPolicy m_empty_input_policy = EmptyInputPolicy::IncludeNothing;
0098 
0099     public:
0100 
0101         void SetOptional(bool isOptional) {
0102             m_is_optional = isOptional;
0103         }
0104 
0105         void SetLevel(JEventLevel level) {
0106             m_level = level;
0107         }
0108 
0109         void SetRequestedDatabundleNames(std::vector<std::string> names) {
0110             m_requested_databundle_names = names;
0111             m_realized_databundle_names = names;
0112         }
0113 
0114         void SetEmptyInputPolicy(EmptyInputPolicy policy) {
0115             m_empty_input_policy = policy;
0116         }
0117 
0118         const std::string& GetTypeName() const {
0119             return m_type_name;
0120         }
0121 
0122         const std::vector<std::string>& GetRequestedDatabundleNames() const {
0123             return m_requested_databundle_names;
0124         }
0125 
0126         const std::vector<std::string>& GetRealizedDatabundleNames() const {
0127             return m_realized_databundle_names;
0128         }
0129 
0130         JEventLevel GetLevel() const {
0131             return m_level;
0132         }
0133 
0134         void Configure(const VariadicInputOptions& options) {
0135             m_requested_databundle_names = options.names;
0136             m_level = options.level;
0137             m_is_optional = options.is_optional;
0138         }
0139 
0140         virtual void GetCollection(const JEvent& event) = 0;
0141         virtual void PrefetchCollection(const JEvent& event) = 0;
0142     };
0143 
0144 
0145     template <typename T>
0146     class Input : public InputBase {
0147 
0148         std::vector<const T*> m_data;
0149         std::string m_tag;
0150 
0151     public:
0152 
0153         Input(JHasInputs* owner) {
0154             owner->RegisterInput(this);
0155             m_type_name = JTypeInfo::demangle<T>();
0156             m_databundle_name = m_type_name;
0157             m_level = JEventLevel::None;
0158         }
0159 
0160         Input(JHasInputs* owner, const InputOptions& options) {
0161             owner->RegisterInput(this);
0162             m_type_name = JTypeInfo::demangle<T>();
0163             Configure(options);
0164         }
0165 
0166         void SetTag(std::string tag) {
0167             m_tag = tag;
0168             m_databundle_name = m_type_name + ":" + tag;
0169         }
0170 
0171         const std::vector<const T*>& operator()() { return m_data; }
0172         const std::vector<const T*>& operator*() { return m_data; }
0173         const std::vector<const T*>* operator->() { return &m_data; }
0174 
0175 
0176     private:
0177         friend class JComponentT;
0178 
0179         void GetCollection(const JEvent& event) {
0180             auto& level = m_level;
0181             m_data.clear();
0182             if (level == event.GetLevel() || level == JEventLevel::None) {
0183                 event.Get<T>(m_data, m_tag, !m_is_optional);
0184             }
0185             else {
0186                 if (m_is_optional && !event.HasParent(level)) return;
0187                 event.GetParent(level).template Get<T>(m_data, m_tag, !m_is_optional);
0188             }
0189         }
0190         void PrefetchCollection(const JEvent& event) {
0191             if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0192                 auto fac = event.GetFactory<T>(m_tag, !m_is_optional);
0193                 if (fac != nullptr) {
0194                     fac->Create(event);
0195                 }
0196             }
0197             else {
0198                 if (m_is_optional && !event.HasParent(m_level)) return;
0199                 auto fac = event.GetParent(m_level).template GetFactory<T>(m_tag, !m_is_optional);
0200                 if (fac != nullptr) {
0201                     fac->Create(event);
0202                 }
0203             }
0204         }
0205     };
0206 
0207 #if JANA2_HAVE_PODIO
0208     template <typename PodioT>
0209     class PodioInput : public InputBase {
0210 
0211         const typename PodioT::collection_type* m_data;
0212 
0213     public:
0214 
0215         PodioInput(JHasInputs* owner) {
0216             owner->RegisterInput(this);
0217             m_type_name = JTypeInfo::demangle<PodioT>();
0218             m_databundle_name = m_type_name;
0219             m_level = JEventLevel::None;
0220         }
0221 
0222         PodioInput(JHasInputs* owner, const InputOptions& options) {
0223             owner->RegisterInput(this);
0224             m_type_name = JTypeInfo::demangle<PodioT>();
0225             m_databundle_name = m_type_name;
0226             Configure(options);
0227         }
0228 
0229         const typename PodioT::collection_type* operator()() {
0230             return m_data;
0231         }
0232         const typename PodioT::collection_type& operator*() {
0233             return *m_data;
0234         }
0235         const typename PodioT::collection_type* operator->() {
0236             return m_data;
0237         }
0238 
0239         void SetCollectionName(std::string name) {
0240             m_databundle_name = name;
0241         }
0242 
0243         void SetTag(std::string tag) {
0244             m_databundle_name = m_type_name + ":" + tag;
0245         }
0246 
0247         void GetCollection(const JEvent& event) {
0248             if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0249                 m_data = event.GetCollection<PodioT>(m_databundle_name, !m_is_optional);
0250             }
0251             else {
0252                 if (m_is_optional && !event.HasParent(m_level)) return;
0253                 m_data = event.GetParent(m_level).template GetCollection<PodioT>(m_databundle_name, !m_is_optional);
0254             }
0255         }
0256 
0257         void PrefetchCollection(const JEvent& event) {
0258             if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0259                 event.GetCollection<PodioT>(m_databundle_name, !m_is_optional);
0260             }
0261             else {
0262                 if (m_is_optional && !event.HasParent(m_level)) return;
0263                 event.GetParent(m_level).template GetCollection<PodioT>(m_databundle_name, !m_is_optional);
0264             }
0265         }
0266     };
0267 #endif
0268 
0269 
0270     template <typename T>
0271     class VariadicInput : public VariadicInputBase {
0272 
0273         std::vector<std::vector<const T*>> m_datas;
0274 
0275     public:
0276 
0277         VariadicInput(JHasInputs* owner) {
0278             owner->RegisterInput(this);
0279             m_type_name = JTypeInfo::demangle<T>();
0280             m_level = JEventLevel::None;
0281         }
0282 
0283         VariadicInput(JHasInputs* owner, const VariadicInputOptions& options) {
0284             owner->RegisterInput(this);
0285             m_type_name = JTypeInfo::demangle<T>();
0286             Configure(options);
0287         }
0288 
0289         void SetTags(std::vector<std::string> tags) {
0290             m_requested_databundle_names = tags;
0291         }
0292 
0293         const std::vector<std::vector<const T*>>& operator()() { return m_datas; }
0294         const std::vector<std::vector<const T*>>& operator*() { return m_datas; }
0295         const std::vector<std::vector<const T*>>* operator->() { return &m_datas; }
0296 
0297         const std::vector<const T*>& operator()(size_t index) { return m_datas.at(index); }
0298 
0299 
0300     private:
0301         friend class JComponentT;
0302 
0303         void GetCollection(const JEvent& event) {
0304             m_datas.clear();
0305             if (!m_requested_databundle_names.empty()) {
0306                 // We have a nonempty input, so we provide the user exactly the inputs they asked for (some of these may be null IF is_optional=true)
0307                 if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0308                     size_t i=0;
0309                     for (auto& tag : m_requested_databundle_names) {
0310                         m_datas.push_back({});
0311                         event.Get<T>(m_datas.at(i++), tag, !m_is_optional);
0312                     }
0313                 }
0314                 else {
0315                     if (m_is_optional && !event.HasParent(m_level)) return;
0316                     auto& parent = event.GetParent(m_level);
0317                     size_t i=0;
0318                     for (auto& tag : m_requested_databundle_names) {
0319                         m_datas.push_back({});
0320                         parent.template Get<T>(m_datas.at(i++), tag, !m_is_optional);
0321                     }
0322                 }
0323             }
0324             else if (m_empty_input_policy == EmptyInputPolicy::IncludeEverything) {
0325                 // We have an empty input and a nontrivial empty input policy
0326                 m_realized_databundle_names.clear();
0327 
0328                 if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0329                     // We are fetching from the JEvent we already have
0330                     auto facs = event.GetFactorySet()->template GetAllFactories<T>();
0331                     size_t i=0;
0332                     for (auto* fac : facs) {
0333                         m_datas.push_back({});                                   // Create a destination for this factory's data
0334                         auto iters = fac->CreateAndGetData(event);
0335                         auto& dest = m_datas.at(i);
0336                         dest.insert(dest.end(), iters.first, iters.second);
0337                         m_realized_databundle_names.push_back(fac->GetTag());
0338                         i += 1;
0339                     }
0340                 }
0341                 else {
0342                     // We are fetching from a parent event
0343                     if (m_is_optional && !event.HasParent(m_level)) return;      // Short-circuit if optional and parent missing
0344                     auto& parent = event.GetParent(m_level);                     // GetParent throws if parent missing
0345                     auto facs = parent.GetFactorySet()->template GetAllFactories<T>();
0346                     size_t i=0;
0347                     for (auto* fac : facs) {
0348                         m_datas.push_back({});                                   // Create a destination for this factory's data
0349                         auto iters = fac->CreateAndGetData(event);
0350                         auto& dest = m_datas.at(i);
0351                         dest.insert(dest.end(), iters.first, iters.second);
0352                         m_realized_databundle_names.push_back(fac->GetTag());
0353                         i += 1;
0354                     }
0355                 }
0356             }
0357         }
0358         void PrefetchCollection(const JEvent& event) {
0359             if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0360                 for (auto& tag : m_requested_databundle_names) {
0361                     event.GetFactory<T>(tag, !m_is_optional)->Create(event);
0362                 }
0363             }
0364             else {
0365                 if (m_is_optional && !event.HasParent(m_level)) return;
0366                 auto& parent = event.GetParent(m_level);
0367                 for (auto& tag : m_requested_databundle_names) {
0368                     parent.template GetFactory<T>(tag, !m_is_optional)->Create(event);
0369                 }
0370             }
0371         }
0372     };
0373 
0374 
0375 
0376 #if JANA2_HAVE_PODIO
0377     template <typename PodioT>
0378     class VariadicPodioInput : public VariadicInputBase {
0379 
0380         std::vector<const typename PodioT::collection_type*> m_datas;
0381 
0382     public:
0383 
0384         VariadicPodioInput(JHasInputs* owner) {
0385             owner->RegisterInput(this);
0386             m_type_name = JTypeInfo::demangle<PodioT>();
0387         }
0388 
0389         VariadicPodioInput(JHasInputs* owner, const VariadicInputOptions& options) {
0390             owner->RegisterInput(this);
0391             m_type_name = JTypeInfo::demangle<PodioT>();
0392             Configure(options);
0393         }
0394 
0395         const std::vector<const typename PodioT::collection_type*> operator()() {
0396             return m_datas;
0397         }
0398 
0399         void SetRequestedCollectionNames(std::vector<std::string> names) {
0400             m_requested_databundle_names = names;
0401             m_realized_databundle_names = std::move(names);
0402         }
0403 
0404         const std::vector<std::string>& GetRealizedCollectionNames() {
0405             return GetRealizedDatabundleNames();
0406         }
0407 
0408         void GetCollection(const JEvent& event) {
0409             bool need_dynamic_realized_databundle_names = (m_requested_databundle_names.empty()) && (m_empty_input_policy != EmptyInputPolicy::IncludeNothing);
0410             if (need_dynamic_realized_databundle_names) {
0411                 m_realized_databundle_names.clear();
0412             }
0413             m_datas.clear();
0414             if (!m_requested_databundle_names.empty()) {
0415                 for (auto& coll_name : m_requested_databundle_names) {
0416                     if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0417                         auto coll = event.GetCollection<PodioT>(coll_name, !m_is_optional);
0418                         m_datas.push_back(coll);
0419                     }
0420                     else {
0421                         if (m_is_optional && !event.HasParent(m_level)) return;
0422                         auto coll = event.GetParent(m_level).template GetCollection<PodioT>(coll_name, !m_is_optional);
0423                         m_datas.push_back(coll);
0424                     }
0425                 }
0426             }
0427             else if (m_empty_input_policy == EmptyInputPolicy::IncludeEverything) {
0428                 auto facs = event.GetFactorySet()->GetAllFactories<PodioT>();
0429                 for (auto* fac : facs) {
0430                     JFactoryPodioT<PodioT>* podio_fac = dynamic_cast<JFactoryPodioT<PodioT>*>(fac);
0431                     if (podio_fac == nullptr) {
0432                         throw JException("Found factory which is NOT a podio factory!");
0433                     }
0434                     auto typed_collection = dynamic_cast<const typename PodioT::collection_type*>(podio_fac->GetCollection());
0435                     m_datas.push_back(typed_collection);
0436                     if (need_dynamic_realized_databundle_names) {
0437                         m_realized_databundle_names.push_back(podio_fac->GetTag());
0438                     }
0439                 }
0440             }
0441         }
0442 
0443         void PrefetchCollection(const JEvent& event) {
0444             if (!m_requested_databundle_names.empty()) {
0445                 for (auto& coll_name : m_requested_databundle_names) {
0446                     if (m_level == event.GetLevel() || m_level == JEventLevel::None) {
0447                         event.GetCollection<PodioT>(coll_name, !m_is_optional);
0448                     }
0449                     else {
0450                         if (m_is_optional && !event.HasParent(m_level)) return;
0451                         event.GetParent(m_level).template GetCollection<PodioT>(coll_name, !m_is_optional);
0452                     }
0453                 }
0454             }
0455             else if (m_empty_input_policy == EmptyInputPolicy::IncludeEverything) {
0456                 auto facs = event.GetFactorySet()->GetAllFactories<PodioT>();
0457                 for (auto* fac : facs) {
0458                     fac->Create(event);
0459                 }
0460             }
0461         }
0462     };
0463 #endif
0464     void WireInputs(JEventLevel component_level,
0465                     const std::vector<JEventLevel>& single_input_levels,
0466                     const std::vector<std::string>& single_input_databundle_names,
0467                     const std::vector<JEventLevel>& variadic_input_levels,
0468                     const std::vector<std::vector<std::string>>& variadic_input_databundle_names) {
0469 
0470         // Validate that we have the correct number of input databundle names
0471         if (single_input_databundle_names.size() != m_inputs.size()) {
0472             throw JException("Wrong number of (nonvariadic) input databundle names! Expected %d, found %d", m_inputs.size(), single_input_databundle_names.size());
0473         }
0474 
0475         if (variadic_input_databundle_names.size() != m_variadic_inputs.size()) {
0476             throw JException("Wrong number of variadic input databundle names! Expected %d, found %d", m_variadic_inputs.size(), variadic_input_databundle_names.size());
0477         }
0478 
0479         size_t i = 0;
0480         for (auto* input : m_inputs) {
0481             input->SetDatabundleName(single_input_databundle_names.at(i));
0482             if (single_input_levels.empty()) {
0483                 input->SetLevel(component_level);
0484             }
0485             else {
0486                 input->SetLevel(single_input_levels.at(i));
0487             }
0488             i += 1;
0489         }
0490 
0491         i = 0;
0492         for (auto* variadic_input : m_variadic_inputs) {
0493             variadic_input->SetRequestedDatabundleNames(variadic_input_databundle_names.at(i));
0494             if (variadic_input_levels.empty()) {
0495                 variadic_input->SetLevel(component_level);
0496             }
0497             else {
0498                 variadic_input->SetLevel(variadic_input_levels.at(i));
0499             }
0500             i += 1;
0501         }
0502     }
0503 
0504     void SummarizeInputs(JComponentSummary::Component& summary) const {
0505         for (const auto* input : m_inputs) {
0506             summary.AddInput(new JComponentSummary::Collection("", input->GetDatabundleName(), input->GetTypeName(), input->GetLevel()));
0507         }
0508         for (const auto* input : m_variadic_inputs) {
0509             for (auto& databundle_name : input->GetRequestedDatabundleNames()) {
0510                 summary.AddInput(new JComponentSummary::Collection("", databundle_name, input->GetTypeName(), input->GetLevel()));
0511             }
0512         }
0513     }
0514 };
0515 
0516 } // namespace jana::components
0517