Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:17:33

0001 // Copyright 2023, Jefferson Science Associates, LLC.
0002 // Subject to the terms in the LICENSE file found in the top-level directory.
0003 // Created by Nathan Brei
0004 
0005 #pragma once
0006 
/**
 * Omnifactories are a lightweight layer connecting JANA to generic algorithms.
 * An omnifactory is assumed to consume multiple input collections (selected by input tags);
 * the input and output collection names may be overridden by user parameters.
 */
0012 
0013 #include "JANA/Services/JParameterManager.h"
0014 #include <JANA/JEvent.h>
0015 #include <JANA/JMultifactory.h>
0016 #include <JANA/JVersion.h>
0017 #include <JANA/Components/JHasInputs.h>
0018 
0019 #include <string>
0020 #include <vector>
0021 
0022 namespace jana::components {
0023 
/// Placeholder configuration type with no fields. It is the default ConfigT
/// of JOmniFactory, for algorithms that need no configuration at all.
struct EmptyConfig {
};
0025 
0026 template <typename AlgoT, typename ConfigT=EmptyConfig>
0027 class JOmniFactory : public JMultifactory, public jana::components::JHasInputs {
0028 public:
0029 
0030     /// =========================
0031     /// Handle output collections
0032     /// =========================
0033 
0034     struct OutputBase {
0035         std::string type_name;
0036         std::vector<std::string> collection_names;
0037         bool is_variadic = false;
0038 
0039         virtual void CreateHelperFactory(JOmniFactory& fac) = 0;
0040         virtual void SetCollection(JOmniFactory& fac) = 0;
0041         virtual void Reset() = 0;
0042     };
0043 
0044     template <typename T>
0045     class Output : public OutputBase {
0046         std::vector<T*> m_data;
0047 
0048     public:
0049         Output(JOmniFactory* owner, std::string default_tag_name="") {
0050             owner->RegisterOutput(this);
0051             this->collection_names.push_back(default_tag_name);
0052             this->type_name = JTypeInfo::demangle<T>();
0053         }
0054 
0055         std::vector<T*>& operator()() { return m_data; }
0056 
0057     private:
0058         friend class JOmniFactory;
0059 
0060         void CreateHelperFactory(JOmniFactory& fac) override {
0061             fac.DeclareOutput<T>(this->collection_names[0]);
0062         }
0063 
0064         void SetCollection(JOmniFactory& fac) override {
0065             fac.SetData<T>(this->collection_names[0], this->m_data);
0066         }
0067 
0068         void Reset() override { }
0069     };
0070 
0071 
0072 #if JANA2_HAVE_PODIO
0073     template <typename PodioT>
0074     class PodioOutput : public OutputBase {
0075 
0076         std::unique_ptr<typename PodioT::collection_type> m_data;
0077 
0078     public:
0079 
0080         PodioOutput(JOmniFactory* owner, std::string default_collection_name="") {
0081             owner->RegisterOutput(this);
0082             this->collection_names.push_back(default_collection_name);
0083             this->type_name = JTypeInfo::demangle<PodioT>();
0084         }
0085 
0086         std::unique_ptr<typename PodioT::collection_type>& operator()() { return m_data; }
0087 
0088     private:
0089         friend class JOmniFactory;
0090 
0091         void CreateHelperFactory(JOmniFactory& fac) override {
0092             fac.DeclarePodioOutput<PodioT>(this->collection_names[0]);
0093         }
0094 
0095         void SetCollection(JOmniFactory& fac) override {
0096             if (m_data == nullptr) {
0097                 throw JException("JOmniFactory: SetCollection failed due to missing output collection '%s'", this->collection_names[0].c_str());
0098                 // Otherwise this leads to a PODIO segfault
0099             }
0100             fac.SetCollection<PodioT>(this->collection_names[0], std::move(this->m_data));
0101         }
0102 
0103         void Reset() override {
0104             m_data = std::move(std::make_unique<typename PodioT::collection_type>());
0105         }
0106     };
0107 
0108 
0109     template <typename PodioT>
0110     class VariadicPodioOutput : public OutputBase {
0111 
0112         std::vector<std::unique_ptr<typename PodioT::collection_type>> m_data;
0113 
0114     public:
0115 
0116         VariadicPodioOutput(JOmniFactory* owner, std::vector<std::string> default_collection_names={}) {
0117             owner->RegisterOutput(this);
0118             this->collection_names = default_collection_names;
0119             this->type_name = JTypeInfo::demangle<PodioT>();
0120             this->is_variadic = true;
0121         }
0122 
0123         std::vector<std::unique_ptr<typename PodioT::collection_type>>& operator()() { return m_data; }
0124 
0125     private:
0126         friend class JOmniFactory;
0127 
0128         void CreateHelperFactory(JOmniFactory& fac) override {
0129             for (auto& coll_name : this->collection_names) {
0130                 fac.DeclarePodioOutput<PodioT>(coll_name);
0131             }
0132         }
0133 
0134         void SetCollection(JOmniFactory& fac) override {
0135             if (m_data.size() != this->collection_names.size()) {
0136                 throw JException("JOmniFactory: VariadicPodioOutput SetCollection failed: Declared %d collections, but provided %d.", this->collection_names.size(), m_data.size());
0137                 // Otherwise this leads to a PODIO segfault
0138             }
0139             size_t i = 0;
0140             for (auto& coll_name : this->collection_names) {
0141                 fac.SetCollection<PodioT>(coll_name, std::move(this->m_data[i++]));
0142             }
0143         }
0144 
0145         void Reset() override {
0146             m_data.clear();
0147             for (auto& coll_name : this->collection_names) {
0148                 m_data.push_back(std::make_unique<typename PodioT::collection_type>());
0149             }
0150         }
0151     };
0152 #endif
0153 
0154     void RegisterOutput(OutputBase* output) {
0155         m_outputs.push_back(output);
0156     }
0157 
0158 
0159 public:
0160     std::vector<OutputBase*> m_outputs;
0161 
0162 private:
0163     /// Current logger
0164     JLogger m_logger;
0165 
0166     /// Configuration
0167     ConfigT m_config;
0168 
0169 public:
0170 
0171     size_t FindVariadicCollectionCount(size_t total_input_count, size_t variadic_input_count, size_t total_collection_count, bool is_input) {
0172 
0173         size_t variadic_collection_count = total_collection_count - (total_input_count - variadic_input_count);
0174 
0175         if (variadic_input_count == 0) {
0176             // No variadic inputs: check that collection_name count matches input count exactly
0177             if (total_input_count != total_collection_count) {
0178                 throw JException("JOmniFactory '%s': Wrong number of %s collection names: %d expected, %d found.",
0179                                 m_prefix.c_str(), (is_input ? "input" : "output"), total_input_count, total_collection_count);
0180             }
0181         }
0182         else {
0183             // Variadic inputs: check that we have enough collection names for the non-variadic inputs
0184             if (total_input_count-variadic_input_count > total_collection_count) {
0185                 throw JException("JOmniFactory '%s': Not enough %s collection names: %d needed, %d found.",
0186                                 m_prefix.c_str(), (is_input ? "input" : "output"), total_input_count-variadic_input_count, total_collection_count);
0187             }
0188 
0189             // Variadic inputs: check that the variadic collection names is evenly divided by the variadic input count
0190             if (variadic_collection_count % variadic_input_count != 0) {
0191                 throw JException("JOmniFactory '%s': Wrong number of %s collection names: %d found total, but %d can't be distributed among %d variadic inputs evenly.",
0192                                 m_prefix.c_str(), (is_input ? "input" : "output"), total_collection_count, variadic_collection_count, variadic_input_count);
0193             }
0194         }
0195         return variadic_collection_count;
0196     }
0197 
0198     inline void PreInit(std::string tag,
0199                         JEventLevel level,
0200                         std::vector<std::string> input_collection_names,
0201                         std::vector<JEventLevel> input_collection_levels,
0202                         std::vector<std::string> output_collection_names ) {
0203 
0204         m_prefix = (this->GetPluginName().empty()) ? tag : this->GetPluginName() + ":" + tag;
0205         m_level = level;
0206 
0207         // Obtain collection name overrides if provided.
0208         // Priority = [JParameterManager, JOmniFactoryGenerator]
0209         m_app->SetDefaultParameter(m_prefix + ":InputTags", input_collection_names, "Input collection names");
0210         m_app->SetDefaultParameter(m_prefix + ":OutputTags", output_collection_names, "Output collection names");
0211 
0212         // Figure out variadic inputs
0213         size_t variadic_input_count = 0;
0214         for (auto* input : m_inputs) {
0215             if (input->is_variadic) {
0216                variadic_input_count += 1;
0217             }
0218         }
0219         size_t variadic_input_collection_count = FindVariadicCollectionCount(m_inputs.size(), variadic_input_count, input_collection_names.size(), true);
0220 
0221         // Set input collection names
0222         size_t i = 0;
0223         for (auto* input : m_inputs) {
0224             input->names.clear();
0225             if (input->is_variadic) {
0226                 for (size_t j = 0; j<(variadic_input_collection_count/variadic_input_count); ++j) {
0227                     input->names.push_back(input_collection_names[i++]);
0228                     if (!input_collection_levels.empty()) {
0229                         input->levels.push_back(input_collection_levels[i++]);
0230                     }
0231                     else {
0232                         input->levels.push_back(level);
0233                     }
0234                 }
0235             }
0236             else {
0237                 input->names.push_back(input_collection_names[i++]);
0238                 if (!input_collection_levels.empty()) {
0239                     input->levels.push_back(input_collection_levels[i++]);
0240                 }
0241                 else {
0242                     input->levels.push_back(level);
0243                 }
0244             }
0245         }
0246 
0247         // Figure out variadic outputs
0248         size_t variadic_output_count = 0;
0249         for (auto* output : m_outputs) {
0250             if (output->is_variadic) {
0251                variadic_output_count += 1;
0252             }
0253         }
0254         size_t variadic_output_collection_count = FindVariadicCollectionCount(m_outputs.size(), variadic_output_count, output_collection_names.size(), true);
0255 
0256         // Set output collection names and create corresponding helper factories
0257         i = 0;
0258         for (auto* output : m_outputs) {
0259             output->collection_names.clear();
0260             if (output->is_variadic) {
0261                 for (size_t j = 0; j<(variadic_output_collection_count/variadic_output_count); ++j) {
0262                     output->collection_names.push_back(output_collection_names[i++]);
0263                 }
0264             }
0265             else {
0266                 output->collection_names.push_back(output_collection_names[i++]);
0267             }
0268             output->CreateHelperFactory(*this);
0269         }
0270 
0271         // Obtain logger
0272         m_logger = m_app->GetService<JParameterManager>()->GetLogger(m_prefix);
0273 
0274         // Configure logger. Priority = [JParameterManager, system log level]
0275         // std::string default_log_level = eicrecon::LogLevelToString(m_logger->level());
0276         // m_app->SetDefaultParameter(m_prefix + ":LogLevel", default_log_level, "LogLevel: trace, debug, info, warn, err, critical, off");
0277         // m_logger->set_level(eicrecon::ParseLogLevel(default_log_level));
0278     }
0279 
0280     void Init() override {
0281         for (auto* parameter : m_parameters) {
0282             parameter->Configure(*(m_app->GetJParameterManager()), m_prefix);
0283         }
0284         for (auto* service : m_services) {
0285             service->Fetch(m_app);
0286         }
0287         static_cast<AlgoT*>(this)->Configure();
0288     }
0289 
0290     void BeginRun(const std::shared_ptr<const JEvent>& event) override {
0291         for (auto* resource : m_resources) {
0292             resource->ChangeRun(event->GetRunNumber(), m_app);
0293         }
0294         static_cast<AlgoT*>(this)->ChangeRun(event->GetRunNumber());
0295     }
0296 
0297     void Process(const std::shared_ptr<const JEvent> &event) override {
0298         try {
0299             for (auto* input : m_inputs) {
0300                 input->GetCollection(*event);
0301             }
0302             for (auto* output : m_outputs) {
0303                 output->Reset();
0304             }
0305             static_cast<AlgoT*>(this)->Execute(event->GetRunNumber(), event->GetEventNumber());
0306             for (auto* output : m_outputs) {
0307                 output->SetCollection(*this);
0308             }
0309         }
0310         catch(std::exception &e) {
0311             throw JException(e.what());
0312         }
0313     }
0314 
0315     using ConfigType = ConfigT;
0316 
0317     /// Retrieve reference to already-configured logger
0318     //std::shared_ptr<spdlog::logger> &logger() { return m_logger; }
0319     JLogger& logger() { return m_logger; }
0320 
0321     /// Retrieve reference to embedded config object
0322     ConfigT& config() { return m_config; }
0323 
0324 
0325     /// Generate summary for UI, inspector
0326     void Summarize(JComponentSummary& summary) const override {
0327 
0328         auto* mfs = new JComponentSummary::Component(
0329             "OmniFactory", GetPrefix(), GetTypeName(), GetLevel(), GetPluginName());
0330 
0331         for (const auto* input : m_inputs) {
0332             size_t subinput_count = input->names.size();
0333             for (size_t i=0; i<subinput_count; ++i) {
0334                 mfs->AddInput(new JComponentSummary::Collection("", input->names[i], input->type_name, input->levels[i]));
0335             }
0336         }
0337         for (const auto* output : m_outputs) {
0338             size_t suboutput_count = output->collection_names.size();
0339             for (size_t i=0; i<suboutput_count; ++i) {
0340                 mfs->AddOutput(new JComponentSummary::Collection("", output->collection_names[i], output->type_name, GetLevel()));
0341             }
0342         }
0343         summary.Add(mfs);
0344     }
0345 
0346 };
0347 
0348 } // namespace jana::components
0349 
0350 using jana::components::JOmniFactory;
0351 
0352