Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-07-01 07:56:30

0001 // Copyright 2023, Jefferson Science Associates, LLC.
0002 // Subject to the terms in the LICENSE file found in the top-level directory.
0003 // Created by Nathan Brei
0004 
0005 #pragma once
0006 
/**
 * Omnifactories are a lightweight layer connecting JANA to generic algorithms.
 * An omnifactory may consume multiple input collections, selected by input
 * tags; those tags (and the output tags) can be overridden by user parameters.
 */
0012 
0013 #include <JANA/CLI/JVersion.h>
0014 #include <JANA/JMultifactory.h>
0015 #include <JANA/JEvent.h>
0016 #include <spdlog/spdlog.h>
0017 #include <spdlog/version.h>
0018 #if SPDLOG_VERSION >= 11400 && (!defined(SPDLOG_NO_TLS) || !SPDLOG_NO_TLS)
0019 #include <spdlog/mdc.h>
0020 #endif
0021 
0022 #include "services/io/podio/datamodel_glue.h"
0023 #include "services/log/Log_service.h"
0024 
0025 #include <string>
0026 #include <vector>
0027 
/// Default configuration type for JOmniFactory: it carries no settings and is
/// used by factories whose algorithms need no configuration object.
struct EmptyConfig {
};
0029 
0030 template <typename AlgoT, typename ConfigT = EmptyConfig>
0031 class JOmniFactory : public JMultifactory {
0032 public:
0033   /// ========================
0034   /// Handle input collections
0035   /// ========================
0036 
0037   struct InputBase {
0038     std::string type_name;
0039     std::vector<std::string> collection_names;
0040     bool is_variadic = false;
0041 
0042     virtual void GetCollection(const JEvent& event) = 0;
0043   };
0044 
0045   template <typename T, bool IsOptional = false> class Input : public InputBase {
0046 
0047     std::vector<const T*> m_data;
0048 
0049   public:
0050     Input(JOmniFactory* owner, std::string default_tag = "") {
0051       owner->RegisterInput(this);
0052       this->collection_names.push_back(default_tag);
0053       this->type_name = JTypeInfo::demangle<T>();
0054     }
0055 
0056     const std::vector<const T*>& operator()() { return m_data; }
0057 
0058   private:
0059     friend class JOmniFactory;
0060 
0061     void GetCollection(const JEvent& event) {
0062       try {
0063         m_data = event.Get<T>(this->collection_names[0], !IsOptional);
0064       } catch (const JException& e) {
0065         if constexpr (!IsOptional) {
0066           throw JException("JOmniFactory: Failed to get collection %s: %s",
0067                            this->collection_names[0].c_str(), e.what());
0068         }
0069       }
0070     }
0071   };
0072 
0073   template <typename PodioT, bool IsOptional = false> class PodioInput : public InputBase {
0074 
0075     const typename PodioTypeMap<PodioT>::collection_t* m_data;
0076 
0077   public:
0078     PodioInput(JOmniFactory* owner, std::string default_collection_name = "") {
0079       owner->RegisterInput(this);
0080       this->collection_names.push_back(default_collection_name);
0081       this->type_name = JTypeInfo::demangle<PodioT>();
0082     }
0083 
0084     const typename PodioTypeMap<PodioT>::collection_t* operator()() { return m_data; }
0085 
0086   private:
0087     friend class JOmniFactory;
0088 
0089     void GetCollection(const JEvent& event) {
0090       try {
0091         m_data = event.GetCollection<PodioT>(this->collection_names[0], !IsOptional);
0092       } catch (const JException& e) {
0093         if constexpr (!IsOptional) {
0094           throw JException("JOmniFactory: Failed to get collection %s: %s",
0095                            this->collection_names[0].c_str(), e.what());
0096         }
0097       }
0098     }
0099   };
0100 
0101   template <typename PodioT, bool IsOptional = false> class VariadicPodioInput : public InputBase {
0102 
0103     std::vector<const typename PodioTypeMap<PodioT>::collection_t*> m_data;
0104 
0105   public:
0106     VariadicPodioInput(JOmniFactory* owner, std::vector<std::string> default_names = {}) {
0107       owner->RegisterInput(this);
0108       this->collection_names = default_names;
0109       this->type_name        = JTypeInfo::demangle<PodioT>();
0110       this->is_variadic      = true;
0111     }
0112 
0113     const std::vector<const typename PodioTypeMap<PodioT>::collection_t*> operator()() {
0114       return m_data;
0115     }
0116 
0117   private:
0118     friend class JOmniFactory;
0119 
0120     void GetCollection(const JEvent& event) {
0121       m_data.clear();
0122       for (auto& coll_name : this->collection_names) {
0123         try {
0124           m_data.push_back(event.GetCollection<PodioT>(coll_name, !IsOptional));
0125         } catch (const JException& e) {
0126           if constexpr (!IsOptional) {
0127             throw JException("JOmniFactory: Failed to get collection %s: %s", coll_name.c_str(),
0128                              e.what());
0129           }
0130         }
0131       }
0132     }
0133   };
0134 
0135   void RegisterInput(InputBase* input) { m_inputs.push_back(input); }
0136 
0137   /// =========================
0138   /// Handle output collections
0139   /// =========================
0140 
0141   struct OutputBase {
0142     std::string type_name;
0143     std::vector<std::string> collection_names;
0144     bool is_variadic = false;
0145 
0146     virtual void CreateHelperFactory(JOmniFactory& fac) = 0;
0147     virtual void SetCollection(JOmniFactory& fac)       = 0;
0148     virtual void Reset()                                = 0;
0149   };
0150 
0151   template <typename T> class Output : public OutputBase {
0152     std::vector<T*> m_data;
0153 
0154   public:
0155     Output(JOmniFactory* owner, std::string default_tag_name = "") {
0156       owner->RegisterOutput(this);
0157       this->collection_names.push_back(default_tag_name);
0158       this->type_name = JTypeInfo::demangle<T>();
0159     }
0160 
0161     std::vector<T*>& operator()() { return m_data; }
0162 
0163   private:
0164     friend class JOmniFactory;
0165 
0166     void CreateHelperFactory(JOmniFactory& fac) override {
0167       fac.DeclareOutput<T>(this->collection_names[0]);
0168     }
0169 
0170     void SetCollection(JOmniFactory& fac) override {
0171       fac.SetData<T>(this->collection_names[0], this->m_data);
0172     }
0173 
0174     void Reset() override { m_data.clear(); }
0175   };
0176 
0177   template <typename PodioT> class PodioOutput : public OutputBase {
0178 
0179     std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t> m_data;
0180 
0181   public:
0182     PodioOutput(JOmniFactory* owner, std::string default_collection_name = "") {
0183       owner->RegisterOutput(this);
0184       this->collection_names.push_back(default_collection_name);
0185       this->type_name = JTypeInfo::demangle<PodioT>();
0186     }
0187 
0188     std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>& operator()() { return m_data; }
0189 
0190   private:
0191     friend class JOmniFactory;
0192 
0193     void CreateHelperFactory(JOmniFactory& fac) override {
0194       fac.DeclarePodioOutput<PodioT>(this->collection_names[0]);
0195     }
0196 
0197     void SetCollection(JOmniFactory& fac) override {
0198       if (m_data == nullptr) {
0199         throw JException("JOmniFactory: SetCollection failed due to missing output collection '%s'",
0200                          this->collection_names[0].c_str());
0201         // Otherwise this leads to a PODIO segfault
0202       }
0203       fac.SetCollection<PodioT>(this->collection_names[0], std::move(this->m_data));
0204     }
0205 
0206     void Reset() override {
0207       m_data = std::move(std::make_unique<typename PodioTypeMap<PodioT>::collection_t>());
0208     }
0209   };
0210 
0211   template <typename PodioT> class VariadicPodioOutput : public OutputBase {
0212 
0213     std::vector<std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>> m_data;
0214 
0215   public:
0216     VariadicPodioOutput(JOmniFactory* owner,
0217                         std::vector<std::string> default_collection_names = {}) {
0218       owner->RegisterOutput(this);
0219       this->collection_names = default_collection_names;
0220       this->type_name        = JTypeInfo::demangle<PodioT>();
0221       this->is_variadic      = true;
0222     }
0223 
0224     std::vector<std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>>& operator()() {
0225       return m_data;
0226     }
0227 
0228   private:
0229     friend class JOmniFactory;
0230 
0231     void CreateHelperFactory(JOmniFactory& fac) override {
0232       for (auto& coll_name : this->collection_names) {
0233         fac.DeclarePodioOutput<PodioT>(coll_name);
0234       }
0235     }
0236 
0237     void SetCollection(JOmniFactory& fac) override {
0238       if (m_data.size() != this->collection_names.size()) {
0239         throw JException("JOmniFactory: VariadicPodioOutput SetCollection failed: Declared %d "
0240                          "collections, but provided %d.",
0241                          this->collection_names.size(), m_data.size());
0242         // Otherwise this leads to a PODIO segfault
0243       }
0244       std::size_t i = 0;
0245       for (auto& coll_name : this->collection_names) {
0246         fac.SetCollection<PodioT>(coll_name, std::move(this->m_data[i++]));
0247       }
0248     }
0249 
0250     void Reset() override {
0251       m_data.clear();
0252       for (auto& coll_name [[maybe_unused]] : this->collection_names) {
0253         m_data.push_back(std::make_unique<typename PodioTypeMap<PodioT>::collection_t>());
0254       }
0255     }
0256   };
0257 
0258   void RegisterOutput(OutputBase* output) { m_outputs.push_back(output); }
0259 
0260   // =================
0261   // Handle parameters
0262   // =================
0263 
0264   struct ParameterBase {
0265     std::string m_name;
0266     std::string m_description;
0267     virtual void Configure(JParameterManager& parman, const std::string& prefix) = 0;
0268     virtual void Configure(std::map<std::string, std::string> fields)            = 0;
0269   };
0270 
0271   template <typename T> class ParameterRef : public ParameterBase {
0272 
0273     T* m_data;
0274 
0275   public:
0276     ParameterRef(JOmniFactory* owner, std::string name, T& slot, std::string description = "") {
0277       owner->RegisterParameter(this);
0278       this->m_name        = name;
0279       this->m_description = description;
0280       m_data              = &slot;
0281     }
0282 
0283     const T& operator()() { return *m_data; }
0284 
0285   private:
0286     friend class JOmniFactory;
0287 
0288     void Configure(JParameterManager& parman, const std::string& prefix) override {
0289       parman.SetDefaultParameter(prefix + ":" + this->m_name, *m_data, this->m_description);
0290     }
0291     void Configure(std::map<std::string, std::string> fields) override {
0292       auto it = fields.find(this->m_name);
0293       if (it != fields.end()) {
0294         const auto& value_str = it->second;
0295         if constexpr (10000 * JVersion::major + 100 * JVersion::minor + 1 * JVersion::patch <
0296                       20102) {
0297           *m_data = JParameterManager::Parse<T>(value_str);
0298         } else {
0299           JParameterManager::Parse(value_str, *m_data);
0300         }
0301       }
0302     }
0303   };
0304 
0305   template <typename T> class Parameter : public ParameterBase {
0306 
0307     T m_data;
0308 
0309   public:
0310     Parameter(JOmniFactory* owner, std::string name, T default_value, std::string description) {
0311       owner->RegisterParameter(this);
0312       this->m_name        = name;
0313       this->m_description = description;
0314       m_data              = default_value;
0315     }
0316 
0317     const T& operator()() { return m_data; }
0318 
0319   private:
0320     friend class JOmniFactory;
0321 
0322     void Configure(JParameterManager& parman, const std::string& /* prefix */) override {
0323       parman.SetDefaultParameter(m_prefix + ":" + this->m_name, m_data, this->m_description);
0324     }
0325     void Configure(std::map<std::string, std::string> fields) override {
0326       auto it = fields.find(this->m_name);
0327       if (it != fields.end()) {
0328         const auto& value_str = it->second;
0329         if constexpr (10000 * JVersion::major + 100 * JVersion::minor + 1 * JVersion::patch <
0330                       20102) {
0331           m_data = JParameterManager::Parse<T>(value_str);
0332         } else {
0333           JParameterManager::Parse(value_str, m_data);
0334         }
0335       }
0336     }
0337   };
0338 
0339   void RegisterParameter(ParameterBase* parameter) { m_parameters.push_back(parameter); }
0340 
0341   void ConfigureAllParameters(std::map<std::string, std::string> fields) {
0342     for (auto* parameter : this->m_parameters) {
0343       parameter->Configure(fields);
0344     }
0345   }
0346 
0347   // ===============
0348   // Handle services
0349   // ===============
0350 
0351   struct ServiceBase {
0352     virtual void Init(JApplication* app) = 0;
0353   };
0354 
0355   template <typename ServiceT> class Service : public ServiceBase {
0356 
0357     std::shared_ptr<ServiceT> m_data;
0358 
0359   public:
0360     Service(JOmniFactory* owner) { owner->RegisterService(this); }
0361 
0362     ServiceT& operator()() { return *m_data; }
0363 
0364   private:
0365     friend class JOmniFactory;
0366 
0367     void Init(JApplication* app) { m_data = app->GetService<ServiceT>(); }
0368   };
0369 
0370   void RegisterService(ServiceBase* service) { m_services.push_back(service); }
0371 
0372   // ================
0373   // Handle resources
0374   // ================
0375 
0376   struct ResourceBase {
0377     virtual void ChangeRun(const JEvent& event) = 0;
0378   };
0379 
0380   template <typename ServiceT, typename ResourceT, typename LambdaT>
0381   class Resource : public ResourceBase {
0382     ResourceT m_data;
0383     LambdaT m_lambda;
0384 
0385   public:
0386     Resource(JOmniFactory* owner, LambdaT lambda) : m_lambda(lambda) {
0387       owner->RegisterResource(this);
0388     };
0389 
0390     const ResourceT& operator()() { return m_data; }
0391 
0392   private:
0393     friend class JOmniFactory;
0394 
0395     void ChangeRun(const JEvent& event) {
0396       auto run_nr                       = event.GetRunNumber();
0397       std::shared_ptr<ServiceT> service = event.GetJApplication()->template GetService<ServiceT>();
0398       m_data                            = m_lambda(service, run_nr);
0399     }
0400   };
0401 
0402   void RegisterResource(ResourceBase* resource) { m_resources.push_back(resource); }
0403 
0404 public:
0405   std::vector<InputBase*> m_inputs;
0406   std::vector<OutputBase*> m_outputs;
0407   std::vector<ParameterBase*> m_parameters;
0408   std::vector<ServiceBase*> m_services;
0409   std::vector<ResourceBase*> m_resources;
0410 
0411 private:
0412   // App belongs on JMultifactory, it is just missing temporarily
0413   JApplication* m_app;
0414 
0415   // Plugin name belongs on JMultifactory, it is just missing temporarily
0416   std::string m_plugin_name;
0417 
0418   // Prefix for parameters and loggers, derived from plugin name and tag in PreInit().
0419   std::string m_prefix;
0420 
0421   /// Current logger
0422   std::shared_ptr<spdlog::logger> m_logger;
0423 
0424   /// Configuration
0425   ConfigT m_config;
0426 
0427 public:
0428   std::size_t FindVariadicCollectionCount(std::size_t total_input_count,
0429                                           std::size_t variadic_input_count,
0430                                           std::size_t total_collection_count, bool is_input) {
0431 
0432     std::size_t variadic_collection_count =
0433         total_collection_count - (total_input_count - variadic_input_count);
0434 
0435     if (variadic_input_count == 0) {
0436       // No variadic inputs: check that collection_name count matches input count exactly
0437       if (total_input_count != total_collection_count) {
0438         throw JException(
0439             "JOmniFactory '%s': Wrong number of %s collection names: %d expected, %d found.",
0440             m_prefix.c_str(), (is_input ? "input" : "output"), total_input_count,
0441             total_collection_count);
0442       }
0443     } else {
0444       // Variadic inputs: check that we have enough collection names for the non-variadic inputs
0445       if (total_input_count - variadic_input_count > total_collection_count) {
0446         throw JException("JOmniFactory '%s': Not enough %s collection names: %d needed, %d found.",
0447                          m_prefix.c_str(), (is_input ? "input" : "output"),
0448                          total_input_count - variadic_input_count, total_collection_count);
0449       }
0450 
0451       // Variadic inputs: check that the variadic collection names is evenly divided by the variadic input count
0452       if (variadic_collection_count % variadic_input_count != 0) {
0453         throw JException("JOmniFactory '%s': Wrong number of %s collection names: %d found total, "
0454                          "but %d can't be distributed among %d variadic inputs evenly.",
0455                          m_prefix.c_str(), (is_input ? "input" : "output"), total_collection_count,
0456                          variadic_collection_count, variadic_input_count);
0457       }
0458     }
0459     return variadic_collection_count;
0460   }
0461 
0462   inline void PreInit(std::string tag, std::vector<std::string> default_input_collection_names,
0463                       std::vector<std::string> default_output_collection_names) {
0464 
0465     m_prefix = (this->GetPluginName().empty()) ? tag : this->GetPluginName() + ":" + tag;
0466 
0467     // Obtain collection name overrides if provided.
0468     // Priority = [JParameterManager, JOmniFactoryGenerator]
0469     m_app->SetDefaultParameter(m_prefix + ":InputTags", default_input_collection_names,
0470                                "Input collection names");
0471     m_app->SetDefaultParameter(m_prefix + ":OutputTags", default_output_collection_names,
0472                                "Output collection names");
0473 
0474     // Figure out variadic inputs
0475     std::size_t variadic_input_count = 0;
0476     for (auto* input : m_inputs) {
0477       if (input->is_variadic) {
0478         variadic_input_count += 1;
0479       }
0480     }
0481     std::size_t variadic_input_collection_count = FindVariadicCollectionCount(
0482         m_inputs.size(), variadic_input_count, default_input_collection_names.size(), true);
0483 
0484     // Set input collection names
0485     for (std::size_t i = 0; auto* input : m_inputs) {
0486       input->collection_names.clear();
0487       if (input->is_variadic) {
0488         for (std::size_t j = 0; j < (variadic_input_collection_count / variadic_input_count); ++j) {
0489           input->collection_names.push_back(default_input_collection_names[i++]);
0490         }
0491       } else {
0492         input->collection_names.push_back(default_input_collection_names[i++]);
0493       }
0494     }
0495 
0496     // Figure out variadic outputs
0497     std::size_t variadic_output_count = 0;
0498     for (auto* output : m_outputs) {
0499       if (output->is_variadic) {
0500         variadic_output_count += 1;
0501       }
0502     }
0503     std::size_t variadic_output_collection_count = FindVariadicCollectionCount(
0504         m_outputs.size(), variadic_output_count, default_output_collection_names.size(), true);
0505 
0506     // Set output collection names and create corresponding helper factories
0507     for (std::size_t i = 0; auto* output : m_outputs) {
0508       output->collection_names.clear();
0509       if (output->is_variadic) {
0510         for (std::size_t j = 0; j < (variadic_output_collection_count / variadic_output_count);
0511              ++j) {
0512           output->collection_names.push_back(default_output_collection_names[i++]);
0513         }
0514       } else {
0515         output->collection_names.push_back(default_output_collection_names[i++]);
0516       }
0517       output->CreateHelperFactory(*this);
0518     }
0519 
0520     // Obtain logger (defines the parameter option)
0521     m_logger = m_app->GetService<Log_service>()->logger(m_prefix);
0522   }
0523 
0524   void Init() override {
0525     auto app = GetApplication();
0526     for (auto* parameter : m_parameters) {
0527       parameter->Configure(*(app->GetJParameterManager()), m_prefix);
0528     }
0529     for (auto* service : m_services) {
0530       service->Init(app);
0531     }
0532     static_cast<AlgoT*>(this)->Configure();
0533   }
0534 
0535   void BeginRun(const std::shared_ptr<const JEvent>& event) override {
0536     for (auto* resource : m_resources) {
0537       resource->ChangeRun(*event);
0538     }
0539     static_cast<AlgoT*>(this)->ChangeRun(event->GetRunNumber());
0540   }
0541 
0542   virtual void Process(int32_t /* run_number */, uint64_t /* event_number */){};
0543 
0544   void Process(const std::shared_ptr<const JEvent>& event) override {
0545     try {
0546       for (auto* input : m_inputs) {
0547         input->GetCollection(*event);
0548       }
0549       for (auto* output : m_outputs) {
0550         output->Reset();
0551       }
0552 #if SPDLOG_VERSION >= 11400 && (!defined(SPDLOG_NO_TLS) || !SPDLOG_NO_TLS)
0553       spdlog::mdc::put("e", std::to_string(event->GetEventNumber()));
0554 #endif
0555       static_cast<AlgoT*>(this)->Process(event->GetRunNumber(), event->GetEventNumber());
0556       for (auto* output : m_outputs) {
0557         output->SetCollection(*this);
0558       }
0559     } catch (std::exception& e) {
0560       throw JException(e.what());
0561     }
0562   }
0563 
0564   using ConfigType = ConfigT;
0565 
0566   void SetApplication(JApplication* app) { m_app = app; }
0567 
0568   JApplication* GetApplication() { return m_app; }
0569 
0570   void SetPluginName(std::string plugin_name) { m_plugin_name = plugin_name; }
0571 
0572   std::string GetPluginName() { return m_plugin_name; }
0573 
0574   inline std::string GetPrefix() { return m_prefix; }
0575 
0576   /// Retrieve reference to already-configured logger
0577   std::shared_ptr<spdlog::logger>& logger() { return m_logger; }
0578 
0579   /// Retrieve reference to embedded config object
0580   ConfigT& config() { return m_config; }
0581 };