Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-03-31 07:48:27

0001 // Copyright 2023, Jefferson Science Associates, LLC.
0002 // Subject to the terms in the LICENSE file found in the top-level directory.
0003 // Created by Nathan Brei
0004 
0005 #pragma once
0006 
/**
 * Omnifactories are a lightweight layer connecting JANA to generic algorithms.
 * They are assumed to take multiple inputs, whose collection names are given by
 * input tags that may be overridden by user parameters.
 */
0012 
#include <JANA/JEvent.h>
#include <JANA/JMultifactory.h>
#include <spdlog/spdlog.h>
#include <spdlog/version.h>
#if SPDLOG_VERSION >= 11400 && (!defined(SPDLOG_NO_TLS) || !SPDLOG_NO_TLS)
#include <spdlog/mdc.h>
#endif

#include "services/io/podio/datamodel_glue_compat.h"
#include "services/log/Log_service.h"

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
0026 
/// Placeholder configuration type, used as the default ConfigT when an
/// algorithm has no configuration options.
struct EmptyConfig {
};
0028 
0029 template <typename AlgoT, typename ConfigT = EmptyConfig>
0030 class JOmniFactory : public JMultifactory {
0031 public:
0032   /// ========================
0033   /// Handle input collections
0034   /// ========================
0035 
0036   struct InputBase {
0037     std::string type_name;
0038     std::vector<std::string> collection_names;
0039     bool is_variadic = false;
0040 
0041     virtual void GetCollection(const JEvent& event) = 0;
0042   };
0043 
0044   template <typename T, bool IsOptional = false> class Input : public InputBase {
0045 
0046     std::vector<const T*> m_data;
0047 
0048   public:
0049     Input(JOmniFactory* owner, std::string default_tag = "") {
0050       owner->RegisterInput(this);
0051       this->collection_names.push_back(default_tag);
0052       this->type_name = JTypeInfo::demangle<T>();
0053     }
0054 
0055     const std::vector<const T*>& operator()() { return m_data; }
0056 
0057   private:
0058     friend class JOmniFactory;
0059 
0060     void GetCollection(const JEvent& event) {
0061       try {
0062         m_data = event.Get<T>(this->collection_names[0], !IsOptional);
0063       } catch (const JException& e) {
0064         if constexpr (!IsOptional) {
0065           throw JException("JOmniFactory: Failed to get collection %s: %s",
0066                            this->collection_names[0].c_str(), e.what());
0067         }
0068       }
0069     }
0070   };
0071 
0072   template <typename PodioT, bool IsOptional = false> class PodioInput : public InputBase {
0073 
0074     const typename PodioTypeMap<PodioT>::collection_t* m_data;
0075 
0076   public:
0077     PodioInput(JOmniFactory* owner, std::string default_collection_name = "") {
0078       owner->RegisterInput(this);
0079       this->collection_names.push_back(default_collection_name);
0080       this->type_name = JTypeInfo::demangle<PodioT>();
0081     }
0082 
0083     const typename PodioTypeMap<PodioT>::collection_t* operator()() { return m_data; }
0084 
0085   private:
0086     friend class JOmniFactory;
0087 
0088     void GetCollection(const JEvent& event) {
0089       try {
0090         m_data = event.GetCollection<PodioT>(this->collection_names[0], !IsOptional);
0091       } catch (const JException& e) {
0092         if constexpr (!IsOptional) {
0093           throw JException("JOmniFactory: Failed to get collection %s: %s",
0094                            this->collection_names[0].c_str(), e.what());
0095         }
0096       }
0097     }
0098   };
0099 
0100   template <typename PodioT, bool IsOptional = false> class VariadicPodioInput : public InputBase {
0101 
0102     std::vector<const typename PodioTypeMap<PodioT>::collection_t*> m_data;
0103 
0104   public:
0105     VariadicPodioInput(JOmniFactory* owner, std::vector<std::string> default_names = {}) {
0106       owner->RegisterInput(this);
0107       this->collection_names = default_names;
0108       this->type_name        = JTypeInfo::demangle<PodioT>();
0109       this->is_variadic      = true;
0110     }
0111 
0112     const std::vector<const typename PodioTypeMap<PodioT>::collection_t*> operator()() {
0113       return m_data;
0114     }
0115 
0116   private:
0117     friend class JOmniFactory;
0118 
0119     void GetCollection(const JEvent& event) {
0120       m_data.clear();
0121       for (auto& coll_name : this->collection_names) {
0122         try {
0123           m_data.push_back(event.GetCollection<PodioT>(coll_name, !IsOptional));
0124         } catch (const JException& e) {
0125           if constexpr (!IsOptional) {
0126             throw JException("JOmniFactory: Failed to get collection %s: %s", coll_name.c_str(),
0127                              e.what());
0128           }
0129         }
0130       }
0131     }
0132   };
0133 
0134   void RegisterInput(InputBase* input) { m_inputs.push_back(input); }
0135 
0136   /// =========================
0137   /// Handle output collections
0138   /// =========================
0139 
0140   struct OutputBase {
0141     std::string type_name;
0142     std::vector<std::string> collection_names;
0143     bool is_variadic = false;
0144 
0145     virtual void CreateHelperFactory(JOmniFactory& fac) = 0;
0146     virtual void SetCollection(JOmniFactory& fac)       = 0;
0147     virtual void Reset()                                = 0;
0148   };
0149 
0150   template <typename T> class Output : public OutputBase {
0151     std::vector<T*> m_data;
0152 
0153   public:
0154     Output(JOmniFactory* owner, std::string default_tag_name = "") {
0155       owner->RegisterOutput(this);
0156       this->collection_names.push_back(default_tag_name);
0157       this->type_name = JTypeInfo::demangle<T>();
0158     }
0159 
0160     std::vector<T*>& operator()() { return m_data; }
0161 
0162   private:
0163     friend class JOmniFactory;
0164 
0165     void CreateHelperFactory(JOmniFactory& fac) override {
0166       fac.DeclareOutput<T>(this->collection_names[0]);
0167     }
0168 
0169     void SetCollection(JOmniFactory& fac) override {
0170       fac.SetData<T>(this->collection_names[0], this->m_data);
0171     }
0172 
0173     void Reset() override { m_data.clear(); }
0174   };
0175 
0176   template <typename PodioT> class PodioOutput : public OutputBase {
0177 
0178     std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t> m_data;
0179 
0180   public:
0181     PodioOutput(JOmniFactory* owner, std::string default_collection_name = "") {
0182       owner->RegisterOutput(this);
0183       this->collection_names.push_back(default_collection_name);
0184       this->type_name = JTypeInfo::demangle<PodioT>();
0185     }
0186 
0187     std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>& operator()() { return m_data; }
0188 
0189   private:
0190     friend class JOmniFactory;
0191 
0192     void CreateHelperFactory(JOmniFactory& fac) override {
0193       fac.DeclarePodioOutput<PodioT>(this->collection_names[0]);
0194     }
0195 
0196     void SetCollection(JOmniFactory& fac) override {
0197       if (m_data == nullptr) {
0198         throw JException("JOmniFactory: SetCollection failed due to missing output collection '%s'",
0199                          this->collection_names[0].c_str());
0200         // Otherwise this leads to a PODIO segfault
0201       }
0202       fac.SetCollection<PodioT>(this->collection_names[0], std::move(this->m_data));
0203     }
0204 
0205     void Reset() override {
0206       m_data = std::move(std::make_unique<typename PodioTypeMap<PodioT>::collection_t>());
0207     }
0208   };
0209 
0210   template <typename PodioT> class VariadicPodioOutput : public OutputBase {
0211 
0212     std::vector<std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>> m_data;
0213 
0214   public:
0215     VariadicPodioOutput(JOmniFactory* owner,
0216                         std::vector<std::string> default_collection_names = {}) {
0217       owner->RegisterOutput(this);
0218       this->collection_names = default_collection_names;
0219       this->type_name        = JTypeInfo::demangle<PodioT>();
0220       this->is_variadic      = true;
0221     }
0222 
0223     std::vector<std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>>& operator()() {
0224       return m_data;
0225     }
0226 
0227   private:
0228     friend class JOmniFactory;
0229 
0230     void CreateHelperFactory(JOmniFactory& fac) override {
0231       for (auto& coll_name : this->collection_names) {
0232         fac.DeclarePodioOutput<PodioT>(coll_name);
0233       }
0234     }
0235 
0236     void SetCollection(JOmniFactory& fac) override {
0237       if (m_data.size() != this->collection_names.size()) {
0238         throw JException("JOmniFactory: VariadicPodioOutput SetCollection failed: Declared %d "
0239                          "collections, but provided %d.",
0240                          this->collection_names.size(), m_data.size());
0241         // Otherwise this leads to a PODIO segfault
0242       }
0243       std::size_t i = 0;
0244       for (auto& coll_name : this->collection_names) {
0245         fac.SetCollection<PodioT>(coll_name, std::move(this->m_data[i++]));
0246       }
0247     }
0248 
0249     void Reset() override {
0250       m_data.clear();
0251       for (auto& coll_name [[maybe_unused]] : this->collection_names) {
0252         m_data.push_back(std::make_unique<typename PodioTypeMap<PodioT>::collection_t>());
0253       }
0254     }
0255   };
0256 
0257   void RegisterOutput(OutputBase* output) { m_outputs.push_back(output); }
0258 
0259   // =================
0260   // Handle parameters
0261   // =================
0262 
0263   struct ParameterBase {
0264     std::string m_name;
0265     std::string m_description;
0266     virtual void Configure(JParameterManager& parman, const std::string& prefix) = 0;
0267     virtual void Configure(std::map<std::string, std::string> fields)            = 0;
0268   };
0269 
0270   template <typename T> class ParameterRef : public ParameterBase {
0271 
0272     T* m_data;
0273 
0274   public:
0275     ParameterRef(JOmniFactory* owner, std::string name, T& slot, std::string description = "") {
0276       owner->RegisterParameter(this);
0277       this->m_name        = name;
0278       this->m_description = description;
0279       m_data              = &slot;
0280     }
0281 
0282     const T& operator()() { return *m_data; }
0283 
0284   private:
0285     friend class JOmniFactory;
0286 
0287     void Configure(JParameterManager& parman, const std::string& prefix) override {
0288       parman.SetDefaultParameter(prefix + ":" + this->m_name, *m_data, this->m_description);
0289     }
0290     void Configure(std::map<std::string, std::string> fields) override {
0291       auto it = fields.find(this->m_name);
0292       if (it != fields.end()) {
0293         const auto& value_str = it->second;
0294         JParameterManager::Parse(value_str, *m_data);
0295       }
0296     }
0297   };
0298 
0299   template <typename T> class Parameter : public ParameterBase {
0300 
0301     T m_data;
0302 
0303   public:
0304     Parameter(JOmniFactory* owner, std::string name, T default_value, std::string description) {
0305       owner->RegisterParameter(this);
0306       this->m_name        = name;
0307       this->m_description = description;
0308       m_data              = default_value;
0309     }
0310 
0311     const T& operator()() { return m_data; }
0312 
0313   private:
0314     friend class JOmniFactory;
0315 
0316     void Configure(JParameterManager& parman, const std::string& /* prefix */) override {
0317       parman.SetDefaultParameter(m_prefix + ":" + this->m_name, m_data, this->m_description);
0318     }
0319     void Configure(std::map<std::string, std::string> fields) override {
0320       auto it = fields.find(this->m_name);
0321       if (it != fields.end()) {
0322         const auto& value_str = it->second;
0323         if constexpr (10000 * JVersion::major + 100 * JVersion::minor + 1 * JVersion::patch <
0324                       20102) {
0325           m_data = JParameterManager::Parse<T>(value_str);
0326         } else {
0327           JParameterManager::Parse(value_str, m_data);
0328         }
0329       }
0330     }
0331   };
0332 
0333   void RegisterParameter(ParameterBase* parameter) { m_parameters.push_back(parameter); }
0334 
0335   void ConfigureAllParameters(std::map<std::string, std::string> fields) {
0336     for (auto* parameter : this->m_parameters) {
0337       parameter->Configure(fields);
0338     }
0339   }
0340 
0341   // ===============
0342   // Handle services
0343   // ===============
0344 
0345   struct ServiceBase {
0346     virtual void Init(JApplication* app) = 0;
0347   };
0348 
0349   template <typename ServiceT> class Service : public ServiceBase {
0350 
0351     std::shared_ptr<ServiceT> m_data;
0352 
0353   public:
0354     Service(JOmniFactory* owner) { owner->RegisterService(this); }
0355 
0356     ServiceT& operator()() { return *m_data; }
0357 
0358   private:
0359     friend class JOmniFactory;
0360 
0361     void Init(JApplication* app) { m_data = app->GetService<ServiceT>(); }
0362   };
0363 
0364   void RegisterService(ServiceBase* service) { m_services.push_back(service); }
0365 
0366   // ================
0367   // Handle resources
0368   // ================
0369 
0370   struct ResourceBase {
0371     virtual void ChangeRun(const JEvent& event) = 0;
0372   };
0373 
0374   template <typename ServiceT, typename ResourceT, typename LambdaT>
0375   class Resource : public ResourceBase {
0376     ResourceT m_data;
0377     LambdaT m_lambda;
0378 
0379   public:
0380     Resource(JOmniFactory* owner, LambdaT lambda) : m_lambda(lambda) {
0381       owner->RegisterResource(this);
0382     };
0383 
0384     const ResourceT& operator()() { return m_data; }
0385 
0386   private:
0387     friend class JOmniFactory;
0388 
0389     void ChangeRun(const JEvent& event) {
0390       auto run_nr                       = event.GetRunNumber();
0391       std::shared_ptr<ServiceT> service = event.GetJApplication()->template GetService<ServiceT>();
0392       m_data                            = m_lambda(service, run_nr);
0393     }
0394   };
0395 
0396   void RegisterResource(ResourceBase* resource) { m_resources.push_back(resource); }
0397 
0398 public:
0399   std::vector<InputBase*> m_inputs;
0400   std::vector<OutputBase*> m_outputs;
0401   std::vector<ParameterBase*> m_parameters;
0402   std::vector<ServiceBase*> m_services;
0403   std::vector<ResourceBase*> m_resources;
0404 
0405 private:
0406   // App belongs on JMultifactory, it is just missing temporarily
0407   JApplication* m_app;
0408 
0409   // Plugin name belongs on JMultifactory, it is just missing temporarily
0410   std::string m_plugin_name;
0411 
0412   // Prefix for parameters and loggers, derived from plugin name and tag in PreInit().
0413   std::string m_prefix;
0414 
0415   /// Current logger
0416   std::shared_ptr<spdlog::logger> m_logger;
0417 
0418   /// Configuration
0419   ConfigT m_config;
0420 
0421 public:
0422   std::size_t FindVariadicCollectionCount(std::size_t total_input_count,
0423                                           std::size_t variadic_input_count,
0424                                           std::size_t total_collection_count, bool is_input) {
0425 
0426     std::size_t variadic_collection_count =
0427         total_collection_count - (total_input_count - variadic_input_count);
0428 
0429     if (variadic_input_count == 0) {
0430       // No variadic inputs: check that collection_name count matches input count exactly
0431       if (total_input_count != total_collection_count) {
0432         throw JException(
0433             "JOmniFactory '%s': Wrong number of %s collection names: %d expected, %d found.",
0434             m_prefix.c_str(), (is_input ? "input" : "output"), total_input_count,
0435             total_collection_count);
0436       }
0437     } else {
0438       // Variadic inputs: check that we have enough collection names for the non-variadic inputs
0439       if (total_input_count - variadic_input_count > total_collection_count) {
0440         throw JException("JOmniFactory '%s': Not enough %s collection names: %d needed, %d found.",
0441                          m_prefix.c_str(), (is_input ? "input" : "output"),
0442                          total_input_count - variadic_input_count, total_collection_count);
0443       }
0444 
0445       // Variadic inputs: check that the variadic collection names is evenly divided by the variadic input count
0446       if (variadic_collection_count % variadic_input_count != 0) {
0447         throw JException("JOmniFactory '%s': Wrong number of %s collection names: %d found total, "
0448                          "but %d can't be distributed among %d variadic inputs evenly.",
0449                          m_prefix.c_str(), (is_input ? "input" : "output"), total_collection_count,
0450                          variadic_collection_count, variadic_input_count);
0451       }
0452     }
0453     return variadic_collection_count;
0454   }
0455 
0456   inline void PreInit(std::string tag, std::vector<std::string> default_input_collection_names,
0457                       std::vector<std::string> default_output_collection_names) {
0458 
0459     m_prefix = (this->GetPluginName().empty()) ? tag : this->GetPluginName() + ":" + tag;
0460 
0461     // Obtain collection name overrides if provided.
0462     // Priority = [JParameterManager, JOmniFactoryGenerator]
0463     m_app->SetDefaultParameter(m_prefix + ":InputTags", default_input_collection_names,
0464                                "Input collection names");
0465     m_app->SetDefaultParameter(m_prefix + ":OutputTags", default_output_collection_names,
0466                                "Output collection names");
0467 
0468     // Figure out variadic inputs
0469     std::size_t variadic_input_count = 0;
0470     for (auto* input : m_inputs) {
0471       if (input->is_variadic) {
0472         variadic_input_count += 1;
0473       }
0474     }
0475     std::size_t variadic_input_collection_count = FindVariadicCollectionCount(
0476         m_inputs.size(), variadic_input_count, default_input_collection_names.size(), true);
0477 
0478     // Set input collection names
0479     for (std::size_t i = 0; auto* input : m_inputs) {
0480       input->collection_names.clear();
0481       if (input->is_variadic) {
0482         for (std::size_t j = 0; j < (variadic_input_collection_count / variadic_input_count); ++j) {
0483           input->collection_names.push_back(default_input_collection_names[i++]);
0484         }
0485       } else {
0486         input->collection_names.push_back(default_input_collection_names[i++]);
0487       }
0488     }
0489 
0490     // Figure out variadic outputs
0491     std::size_t variadic_output_count = 0;
0492     for (auto* output : m_outputs) {
0493       if (output->is_variadic) {
0494         variadic_output_count += 1;
0495       }
0496     }
0497     std::size_t variadic_output_collection_count = FindVariadicCollectionCount(
0498         m_outputs.size(), variadic_output_count, default_output_collection_names.size(), true);
0499 
0500     // Set output collection names and create corresponding helper factories
0501     for (std::size_t i = 0; auto* output : m_outputs) {
0502       output->collection_names.clear();
0503       if (output->is_variadic) {
0504         for (std::size_t j = 0; j < (variadic_output_collection_count / variadic_output_count);
0505              ++j) {
0506           output->collection_names.push_back(default_output_collection_names[i++]);
0507         }
0508       } else {
0509         output->collection_names.push_back(default_output_collection_names[i++]);
0510       }
0511       output->CreateHelperFactory(*this);
0512     }
0513 
0514     // Obtain logger (defines the parameter option)
0515     m_logger = m_app->GetService<Log_service>()->logger(m_prefix);
0516   }
0517 
0518   void Init() override {
0519     auto app = GetApplication();
0520     for (auto* parameter : m_parameters) {
0521       parameter->Configure(*(app->GetJParameterManager()), m_prefix);
0522     }
0523     for (auto* service : m_services) {
0524       service->Init(app);
0525     }
0526     static_cast<AlgoT*>(this)->Configure();
0527   }
0528 
0529   void BeginRun(const std::shared_ptr<const JEvent>& event) override {
0530     for (auto* resource : m_resources) {
0531       resource->ChangeRun(*event);
0532     }
0533     static_cast<AlgoT*>(this)->ChangeRun(event->GetRunNumber());
0534   }
0535 
0536   virtual void ChangeRun(int32_t /* run_number */) override {};
0537 
0538   virtual void Process(int32_t /* run_number */, uint64_t /* event_number */) {};
0539 
0540   void Process(const std::shared_ptr<const JEvent>& event) override {
0541     try {
0542       for (auto* input : m_inputs) {
0543         input->GetCollection(*event);
0544       }
0545       for (auto* output : m_outputs) {
0546         output->Reset();
0547       }
0548 #if SPDLOG_VERSION >= 11400 && (!defined(SPDLOG_NO_TLS) || !SPDLOG_NO_TLS)
0549       spdlog::mdc::put("e", std::to_string(event->GetEventNumber()));
0550 #endif
0551       static_cast<AlgoT*>(this)->Process(event->GetRunNumber(), event->GetEventNumber());
0552       for (auto* output : m_outputs) {
0553         output->SetCollection(*this);
0554       }
0555     } catch (std::exception& e) {
0556       throw JException(e.what());
0557     }
0558   }
0559 
0560   using ConfigType = ConfigT;
0561 
0562   void SetApplication(JApplication* app) { m_app = app; }
0563 
0564   JApplication* GetApplication() { return m_app; }
0565 
0566   void SetPluginName(std::string plugin_name) { m_plugin_name = plugin_name; }
0567 
0568   std::string GetPluginName() { return m_plugin_name; }
0569 
0570   inline std::string GetPrefix() { return m_prefix; }
0571 
0572   /// Retrieve reference to already-configured logger
0573   std::shared_ptr<spdlog::logger>& logger() { return m_logger; }
0574 
0575   /// Retrieve reference to embedded config object
0576   ConfigT& config() { return m_config; }
0577 };