Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-06-26 07:05:46

0001 // Copyright 2023, Jefferson Science Associates, LLC.
0002 // Subject to the terms in the LICENSE file found in the top-level directory.
0003 // Created by Nathan Brei
0004 
0005 #pragma once
0006 
/**
 * Omnifactories are a lightweight layer connecting JANA to generic algorithms.
 * An omnifactory is assumed to take multiple input collections (selected by
 * input tags), which may be changed through user parameters.
 */
0012 
0013 #include <JANA/CLI/JVersion.h>
0014 #include <JANA/JMultifactory.h>
0015 #include <JANA/JEvent.h>
0016 #include <spdlog/spdlog.h>
0017 
0018 #include "services/io/podio/datamodel_glue.h"
0019 #include "services/log/Log_service.h"
0020 
0021 #include <string>
0022 #include <vector>
0023 
/// Default configuration type used by JOmniFactory when no ConfigT is supplied.
/// It has no fields.
struct EmptyConfig {
};
0025 
0026 template <typename AlgoT, typename ConfigT=EmptyConfig>
0027 class JOmniFactory : public JMultifactory {
0028 public:
0029 
0030     /// ========================
0031     /// Handle input collections
0032     /// ========================
0033 
0034     struct InputBase {
0035         std::string type_name;
0036         std::vector<std::string> collection_names;
0037         bool is_variadic = false;
0038 
0039         virtual void GetCollection(const JEvent& event) = 0;
0040     };
0041 
0042     template <typename T>
0043     class Input : public InputBase {
0044 
0045         std::vector<const T*> m_data;
0046 
0047     public:
0048         Input(JOmniFactory* owner, std::string default_tag="") {
0049             owner->RegisterInput(this);
0050             this->collection_names.push_back(default_tag);
0051             this->type_name = JTypeInfo::demangle<T>();
0052         }
0053 
0054         const std::vector<const T*>& operator()() { return m_data; }
0055 
0056     private:
0057         friend class JOmniFactory;
0058 
0059         void GetCollection(const JEvent& event) {
0060             m_data = event.Get<T>(this->collection_names[0]);
0061         }
0062     };
0063 
0064 
0065     template <typename PodioT>
0066     class PodioInput : public InputBase {
0067 
0068         const typename PodioTypeMap<PodioT>::collection_t* m_data;
0069 
0070     public:
0071 
0072         PodioInput(JOmniFactory* owner, std::string default_collection_name="") {
0073             owner->RegisterInput(this);
0074             this->collection_names.push_back(default_collection_name);
0075             this->type_name = JTypeInfo::demangle<PodioT>();
0076         }
0077 
0078         const typename PodioTypeMap<PodioT>::collection_t* operator()() {
0079             return m_data;
0080         }
0081 
0082     private:
0083         friend class JOmniFactory;
0084 
0085         void GetCollection(const JEvent& event) {
0086             m_data = event.GetCollection<PodioT>(this->collection_names[0]);
0087         }
0088     };
0089 
0090 
0091     template <typename PodioT>
0092     class VariadicPodioInput : public InputBase {
0093 
0094         std::vector<const typename PodioTypeMap<PodioT>::collection_t*> m_data;
0095 
0096     public:
0097 
0098         VariadicPodioInput(JOmniFactory* owner, std::vector<std::string> default_names = {}) {
0099             owner->RegisterInput(this);
0100             this->collection_names = default_names;
0101             this->type_name = JTypeInfo::demangle<PodioT>();
0102             this->is_variadic = true;
0103         }
0104 
0105         const std::vector<const typename PodioTypeMap<PodioT>::collection_t*> operator()() {
0106             return m_data;
0107         }
0108 
0109     private:
0110         friend class JOmniFactory;
0111 
0112         void GetCollection(const JEvent& event) {
0113             m_data.clear();
0114             for (auto& coll_name : this->collection_names) {
0115                 m_data.push_back(event.GetCollection<PodioT>(coll_name));
0116             }
0117         }
0118     };
0119 
0120     void RegisterInput(InputBase* input) {
0121         m_inputs.push_back(input);
0122     }
0123 
0124 
0125     /// =========================
0126     /// Handle output collections
0127     /// =========================
0128 
0129     struct OutputBase {
0130         std::string type_name;
0131         std::vector<std::string> collection_names;
0132         bool is_variadic = false;
0133 
0134         virtual void CreateHelperFactory(JOmniFactory& fac) = 0;
0135         virtual void SetCollection(JOmniFactory& fac) = 0;
0136         virtual void Reset() = 0;
0137     };
0138 
0139     template <typename T>
0140     class Output : public OutputBase {
0141         std::vector<T*> m_data;
0142 
0143     public:
0144         Output(JOmniFactory* owner, std::string default_tag_name="") {
0145             owner->RegisterOutput(this);
0146             this->collection_names.push_back(default_tag_name);
0147             this->type_name = JTypeInfo::demangle<T>();
0148         }
0149 
0150         std::vector<T*>& operator()() { return m_data; }
0151 
0152     private:
0153         friend class JOmniFactory;
0154 
0155         void CreateHelperFactory(JOmniFactory& fac) override {
0156             fac.DeclareOutput<T>(this->collection_names[0]);
0157         }
0158 
0159         void SetCollection(JOmniFactory& fac) override {
0160             fac.SetData<T>(this->collection_names[0], this->m_data);
0161         }
0162 
0163         void Reset() override { }
0164     };
0165 
0166 
0167     template <typename PodioT>
0168     class PodioOutput : public OutputBase {
0169 
0170         std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t> m_data;
0171 
0172     public:
0173 
0174         PodioOutput(JOmniFactory* owner, std::string default_collection_name="") {
0175             owner->RegisterOutput(this);
0176             this->collection_names.push_back(default_collection_name);
0177             this->type_name = JTypeInfo::demangle<PodioT>();
0178         }
0179 
0180         std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>& operator()() { return m_data; }
0181 
0182     private:
0183         friend class JOmniFactory;
0184 
0185         void CreateHelperFactory(JOmniFactory& fac) override {
0186             fac.DeclarePodioOutput<PodioT>(this->collection_names[0]);
0187         }
0188 
0189         void SetCollection(JOmniFactory& fac) override {
0190             if (m_data == nullptr) {
0191                 throw JException("JOmniFactory: SetCollection failed due to missing output collection '%s'", this->collection_names[0].c_str());
0192                 // Otherwise this leads to a PODIO segfault
0193             }
0194             fac.SetCollection<PodioT>(this->collection_names[0], std::move(this->m_data));
0195         }
0196 
0197         void Reset() override {
0198             m_data = std::move(std::make_unique<typename PodioTypeMap<PodioT>::collection_t>());
0199         }
0200     };
0201 
0202 
0203     template <typename PodioT>
0204     class VariadicPodioOutput : public OutputBase {
0205 
0206         std::vector<std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>> m_data;
0207 
0208     public:
0209 
0210         VariadicPodioOutput(JOmniFactory* owner, std::vector<std::string> default_collection_names={}) {
0211             owner->RegisterOutput(this);
0212             this->collection_names = default_collection_names;
0213             this->type_name = JTypeInfo::demangle<PodioT>();
0214             this->is_variadic = true;
0215         }
0216 
0217         std::vector<std::unique_ptr<typename PodioTypeMap<PodioT>::collection_t>>& operator()() { return m_data; }
0218 
0219     private:
0220         friend class JOmniFactory;
0221 
0222         void CreateHelperFactory(JOmniFactory& fac) override {
0223             for (auto& coll_name : this->collection_names) {
0224                 fac.DeclarePodioOutput<PodioT>(coll_name);
0225             }
0226         }
0227 
0228         void SetCollection(JOmniFactory& fac) override {
0229             if (m_data.size() != this->collection_names.size()) {
0230                 throw JException("JOmniFactory: VariadicPodioOutput SetCollection failed: Declared %d collections, but provided %d.", this->collection_names.size(), m_data.size());
0231                 // Otherwise this leads to a PODIO segfault
0232             }
0233             size_t i = 0;
0234             for (auto& coll_name : this->collection_names) {
0235                 fac.SetCollection<PodioT>(coll_name, std::move(this->m_data[i++]));
0236             }
0237         }
0238 
0239         void Reset() override {
0240             m_data.clear();
0241             for (auto& coll_name : this->collection_names) {
0242                 m_data.push_back(std::make_unique<typename PodioTypeMap<PodioT>::collection_t>());
0243             }
0244         }
0245     };
0246 
0247     void RegisterOutput(OutputBase* output) {
0248         m_outputs.push_back(output);
0249     }
0250 
0251 
0252     // =================
0253     // Handle parameters
0254     // =================
0255 
0256     struct ParameterBase {
0257         std::string m_name;
0258         std::string m_description;
0259         virtual void Configure(JParameterManager& parman, const std::string& prefix) = 0;
0260         virtual void Configure(std::map<std::string, std::string> fields) = 0;
0261     };
0262 
0263     template <typename T>
0264     class ParameterRef : public ParameterBase {
0265 
0266         T* m_data;
0267 
0268     public:
0269         ParameterRef(JOmniFactory* owner, std::string name, T& slot, std::string description="") {
0270             owner->RegisterParameter(this);
0271             this->m_name = name;
0272             this->m_description = description;
0273             m_data = &slot;
0274         }
0275 
0276         const T& operator()() { return *m_data; }
0277 
0278     private:
0279         friend class JOmniFactory;
0280 
0281         void Configure(JParameterManager& parman, const std::string& prefix) override {
0282             parman.SetDefaultParameter(prefix + ":" + this->m_name, *m_data, this->m_description);
0283         }
0284         void Configure(std::map<std::string, std::string> fields) override {
0285             auto it = fields.find(this->m_name);
0286             if (it != fields.end()) {
0287                 const auto& value_str = it->second;
0288                 if constexpr (10000 * JVersion::major
0289                               + 100 * JVersion::minor
0290                                 + 1 * JVersion::patch < 20102) {
0291                     *m_data = JParameterManager::Parse<T>(value_str);
0292                 } else {
0293                     JParameterManager::Parse(value_str, *m_data);
0294                 }
0295             }
0296         }
0297     };
0298 
0299     template <typename T>
0300     class Parameter : public ParameterBase {
0301 
0302         T m_data;
0303 
0304     public:
0305         Parameter(JOmniFactory* owner, std::string name, T default_value, std::string description) {
0306             owner->RegisterParameter(this);
0307             this->m_name = name;
0308             this->m_description = description;
0309             m_data = default_value;
0310         }
0311 
0312         const T& operator()() { return m_data; }
0313 
0314     private:
0315         friend class JOmniFactory;
0316 
0317         void Configure(JParameterManager& parman, const std::string& prefix) override {
0318             parman.SetDefaultParameter(m_prefix + ":" + this->m_name, m_data, this->m_description);
0319         }
0320         void Configure(std::map<std::string, std::string> fields) override {
0321             auto it = fields.find(this->m_name);
0322             if (it != fields.end()) {
0323                 const auto& value_str = it->second;
0324                 if constexpr (10000 * JVersion::major
0325                               + 100 * JVersion::minor
0326                                 + 1 * JVersion::patch < 20102) {
0327                     m_data = JParameterManager::Parse<T>(value_str);
0328                 } else {
0329                     JParameterManager::Parse(value_str, m_data);
0330                 }
0331             }
0332         }
0333     };
0334 
0335     void RegisterParameter(ParameterBase* parameter) {
0336         m_parameters.push_back(parameter);
0337     }
0338 
0339     void ConfigureAllParameters(std::map<std::string, std::string> fields) {
0340         for (auto* parameter : this->m_parameters) {
0341             parameter->Configure(fields);
0342         }
0343     }
0344 
0345     // ===============
0346     // Handle services
0347     // ===============
0348 
0349     struct ServiceBase {
0350         virtual void Init(JApplication* app) = 0;
0351     };
0352 
0353     template <typename ServiceT>
0354     class Service : public ServiceBase {
0355 
0356         std::shared_ptr<ServiceT> m_data;
0357 
0358     public:
0359 
0360         Service(JOmniFactory* owner) {
0361             owner->RegisterService(this);
0362         }
0363 
0364         ServiceT& operator()() {
0365             return *m_data;
0366         }
0367 
0368     private:
0369 
0370         friend class JOmniFactory;
0371 
0372         void Init(JApplication* app) {
0373             m_data = app->GetService<ServiceT>();
0374         }
0375 
0376     };
0377 
0378     void RegisterService(ServiceBase* service) {
0379         m_services.push_back(service);
0380     }
0381 
0382 
0383     // ================
0384     // Handle resources
0385     // ================
0386 
0387     struct ResourceBase {
0388         virtual void ChangeRun(const JEvent& event) = 0;
0389     };
0390 
0391     template <typename ServiceT, typename ResourceT, typename LambdaT>
0392     class Resource : public ResourceBase {
0393         ResourceT m_data;
0394         LambdaT m_lambda;
0395 
0396     public:
0397 
0398         Resource(JOmniFactory* owner, LambdaT lambda) : m_lambda(lambda) {
0399             owner->RegisterResource(this);
0400         };
0401 
0402         const ResourceT& operator()() { return m_data; }
0403 
0404     private:
0405         friend class JOmniFactory;
0406 
0407         void ChangeRun(const JEvent& event) {
0408             auto run_nr = event.GetRunNumber();
0409             std::shared_ptr<ServiceT> service = event.GetJApplication()->template GetService<ServiceT>();
0410             m_data = m_lambda(service, run_nr);
0411         }
0412     };
0413 
0414     void RegisterResource(ResourceBase* resource) {
0415         m_resources.push_back(resource);
0416     }
0417 
0418 
0419 public:
0420     std::vector<InputBase*> m_inputs;
0421     std::vector<OutputBase*> m_outputs;
0422     std::vector<ParameterBase*> m_parameters;
0423     std::vector<ServiceBase*> m_services;
0424     std::vector<ResourceBase*> m_resources;
0425 
0426 private:
0427 
0428     // App belongs on JMultifactory, it is just missing temporarily
0429     JApplication* m_app;
0430 
0431     // Plugin name belongs on JMultifactory, it is just missing temporarily
0432     std::string m_plugin_name;
0433 
0434     // Prefix for parameters and loggers, derived from plugin name and tag in PreInit().
0435     std::string m_prefix;
0436 
0437     /// Current logger
0438     std::shared_ptr<spdlog::logger> m_logger;
0439 
0440     /// Configuration
0441     ConfigT m_config;
0442 
0443 public:
0444 
0445     size_t FindVariadicCollectionCount(size_t total_input_count, size_t variadic_input_count, size_t total_collection_count, bool is_input) {
0446 
0447         size_t variadic_collection_count = total_collection_count - (total_input_count - variadic_input_count);
0448 
0449         if (variadic_input_count == 0) {
0450             // No variadic inputs: check that collection_name count matches input count exactly
0451             if (total_input_count != total_collection_count) {
0452                 throw JException("JOmniFactory '%s': Wrong number of %s collection names: %d expected, %d found.",
0453                                 m_prefix.c_str(), (is_input ? "input" : "output"), total_input_count, total_collection_count);
0454             }
0455         }
0456         else {
0457             // Variadic inputs: check that we have enough collection names for the non-variadic inputs
0458             if (total_input_count-variadic_input_count > total_collection_count) {
0459                 throw JException("JOmniFactory '%s': Not enough %s collection names: %d needed, %d found.",
0460                                 m_prefix.c_str(), (is_input ? "input" : "output"), total_input_count-variadic_input_count, total_collection_count);
0461             }
0462 
0463             // Variadic inputs: check that the variadic collection names is evenly divided by the variadic input count
0464             if (variadic_collection_count % variadic_input_count != 0) {
0465                 throw JException("JOmniFactory '%s': Wrong number of %s collection names: %d found total, but %d can't be distributed among %d variadic inputs evenly.",
0466                                 m_prefix.c_str(), (is_input ? "input" : "output"), total_collection_count, variadic_collection_count, variadic_input_count);
0467             }
0468         }
0469         return variadic_collection_count;
0470     }
0471 
0472     inline void PreInit(std::string tag,
0473                         std::vector<std::string> default_input_collection_names,
0474                         std::vector<std::string> default_output_collection_names ) {
0475 
0476         // TODO: NWB: JMultiFactory::GetTag,SetTag are not currently usable
0477         m_prefix = (this->GetPluginName().empty()) ? tag : this->GetPluginName() + ":" + tag;
0478 
0479         // Obtain collection name overrides if provided.
0480         // Priority = [JParameterManager, JOmniFactoryGenerator]
0481         m_app->SetDefaultParameter(m_prefix + ":InputTags", default_input_collection_names, "Input collection names");
0482         m_app->SetDefaultParameter(m_prefix + ":OutputTags", default_output_collection_names, "Output collection names");
0483 
0484         // Figure out variadic inputs
0485         size_t variadic_input_count = 0;
0486         for (auto* input : m_inputs) {
0487             if (input->is_variadic) {
0488                variadic_input_count += 1;
0489             }
0490         }
0491         size_t variadic_input_collection_count = FindVariadicCollectionCount(m_inputs.size(), variadic_input_count, default_input_collection_names.size(), true);
0492 
0493         // Set input collection names
0494         for (size_t i = 0; auto* input : m_inputs) {
0495             input->collection_names.clear();
0496             if (input->is_variadic) {
0497                 for (size_t j = 0; j<(variadic_input_collection_count/variadic_input_count); ++j) {
0498                     input->collection_names.push_back(default_input_collection_names[i++]);
0499                 }
0500             }
0501             else {
0502                 input->collection_names.push_back(default_input_collection_names[i++]);
0503             }
0504         }
0505 
0506         // Figure out variadic outputs
0507         size_t variadic_output_count = 0;
0508         for (auto* output : m_outputs) {
0509             if (output->is_variadic) {
0510                variadic_output_count += 1;
0511             }
0512         }
0513         size_t variadic_output_collection_count = FindVariadicCollectionCount(m_outputs.size(), variadic_output_count, default_output_collection_names.size(), true);
0514 
0515         // Set output collection names and create corresponding helper factories
0516         for (size_t i = 0; auto* output : m_outputs) {
0517             output->collection_names.clear();
0518             if (output->is_variadic) {
0519                 for (size_t j = 0; j<(variadic_output_collection_count/variadic_output_count); ++j) {
0520                     output->collection_names.push_back(default_output_collection_names[i++]);
0521                 }
0522             }
0523             else {
0524                 output->collection_names.push_back(default_output_collection_names[i++]);
0525             }
0526             output->CreateHelperFactory(*this);
0527         }
0528 
0529         // Obtain logger (defines the parameter option)
0530         m_logger = m_app->GetService<Log_service>()->logger(m_prefix);
0531     }
0532 
0533     void Init() override {
0534         auto app = GetApplication();
0535         for (auto* parameter : m_parameters) {
0536             parameter->Configure(*(app->GetJParameterManager()), m_prefix);
0537         }
0538         for (auto* service : m_services) {
0539             service->Init(app);
0540         }
0541         static_cast<AlgoT*>(this)->Configure();
0542     }
0543 
0544     void BeginRun(const std::shared_ptr<const JEvent>& event) override {
0545         for (auto* resource : m_resources) {
0546             resource->ChangeRun(*event);
0547         }
0548         static_cast<AlgoT*>(this)->ChangeRun(event->GetRunNumber());
0549     }
0550 
0551     void Process(const std::shared_ptr<const JEvent> &event) override {
0552         try {
0553             for (auto* input : m_inputs) {
0554                 input->GetCollection(*event);
0555             }
0556             for (auto* output : m_outputs) {
0557                 output->Reset();
0558             }
0559             static_cast<AlgoT*>(this)->Process(event->GetRunNumber(), event->GetEventNumber());
0560             for (auto* output : m_outputs) {
0561                 output->SetCollection(*this);
0562             }
0563         }
0564         catch(std::exception &e) {
0565             throw JException(e.what());
0566         }
0567     }
0568 
0569     using ConfigType = ConfigT;
0570 
0571     void SetApplication(JApplication* app) { m_app = app; }
0572 
0573     JApplication* GetApplication() { return m_app; }
0574 
0575     void SetPluginName(std::string plugin_name) { m_plugin_name = plugin_name; }
0576 
0577     std::string GetPluginName() { return m_plugin_name; }
0578 
0579     inline std::string GetPrefix() { return m_prefix; }
0580 
0581     /// Retrieve reference to already-configured logger
0582     std::shared_ptr<spdlog::logger> &logger() { return m_logger; }
0583 
0584     /// Retrieve reference to embedded config object
0585     ConfigT& config() { return m_config; }
0586 
0587 };