Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-06-17 07:06:30

0001 
0002 #include <JANA/JApplication.h>
0003 #include <JANA/JEvent.h>
0004 #include <JANA/JFactorySet.h>
0005 #include <JANA/JMultifactory.h>
0006 #include <JANA/Services/JComponentManager.h>
0007 #include <JANA/Services/JParameterManager.h>
0008 #include <catch2/catch_test_macros.hpp>
0009 #include <edm4hep/SimCalorimeterHitCollection.h>
0010 #include <fmt/core.h>
0011 #include <spdlog/logger.h>
0012 #include <stdint.h>
0013 #include <iostream>
0014 #include <map>
0015 #include <memory>
0016 #include <string>
0017 #include <utility>
0018 #include <vector>
0019 
0020 #include "extensions/jana/JOmniFactory.h"
0021 #include "extensions/jana/JOmniFactoryGeneratorT.h"
0022 
/// Tunable settings shared by every omnifactory test algorithm in this file.
/// These in-class values are the fallbacks that remain in effect when the
/// wiring does not supply a value (see "Wiring itself is correctly defaulted").
struct BasicTestAlgConfig {
    int bucket_count{42};   // default number of buckets
    double threshold{7.6};  // default cutoff threshold
};
0027 
0028 struct BasicTestAlg : public JOmniFactory<BasicTestAlg, BasicTestAlgConfig> {
0029 
0030     PodioOutput<edm4hep::SimCalorimeterHit> output_hits_left {this, "output_hits_left"};
0031     PodioOutput<edm4hep::SimCalorimeterHit> output_hits_right {this, "output_hits_right"};
0032     Output<edm4hep::SimCalorimeterHit> output_vechits {this, "output_vechits"};
0033 
0034     ParameterRef<int> bucket_count {this, "bucket_count", config().bucket_count, "The total number of buckets [dimensionless]"};
0035     ParameterRef<double> threshold {this, "threshold", config().threshold, "The max cutoff threshold [V * A * kg^-1 * m^-2 * sec^-3]"};
0036 
0037     std::vector<OutputBase*> GetOutputs() { return this->m_outputs; }
0038 
0039     int m_init_call_count = 0;
0040     int m_changerun_call_count = 0;
0041     int m_process_call_count = 0;
0042 
0043     void Configure() {
0044         m_init_call_count++;
0045         logger()->info("Calling BasicTestAlg::Configure");
0046     }
0047 
0048     void ChangeRun(int64_t run_number) {
0049         m_changerun_call_count++;
0050         logger()->info("Calling BasicTestAlg::ChangeRun");
0051     }
0052 
0053     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0054     void Process(int64_t run_number, uint64_t event_number) {
0055         m_process_call_count++;
0056         logger()->info("Calling BasicTestAlg::Process with bucket_count={}, threshold={}", config().bucket_count, config().threshold);
0057         // Provide empty collections (as opposed to nulls) so that PODIO doesn't crash
0058         // TODO: NWB: I though multifactories already took care of this under the hood somewhere
0059         output_hits_left() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0060         output_hits_right() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0061         output_vechits().push_back(new edm4hep::SimCalorimeterHit());
0062     }
0063 };
0064 
0065 template <typename OutputCollectionT, typename MultifactoryT>
0066 MultifactoryT* RetrieveMultifactory(JFactorySet* facset, std::string output_collection_name) {
0067     auto fac = facset->GetFactory<OutputCollectionT>(output_collection_name);
0068     REQUIRE(fac != nullptr);
0069     auto helper = dynamic_cast<JMultifactoryHelperPodio<OutputCollectionT>*>(fac);
0070     REQUIRE(helper != nullptr);
0071     auto multifactory = helper->GetMultifactory();
0072     REQUIRE(multifactory != nullptr);
0073     auto typed = dynamic_cast<MultifactoryT*>(multifactory);
0074     REQUIRE(typed != nullptr);
0075     return typed;
0076 }
0077 
0078 TEST_CASE("Registering Podio outputs works") {
0079     BasicTestAlg alg;
0080     REQUIRE(alg.GetOutputs().size() == 3);
0081     REQUIRE(alg.GetOutputs()[0]->collection_names[0] == "output_hits_left");
0082     REQUIRE(alg.GetOutputs()[0]->type_name == "edm4hep::SimCalorimeterHit");
0083     REQUIRE(alg.GetOutputs()[1]->collection_names[0] == "output_hits_right");
0084     REQUIRE(alg.GetOutputs()[1]->type_name == "edm4hep::SimCalorimeterHit");
0085     REQUIRE(alg.GetOutputs()[2]->collection_names[0] == "output_vechits");
0086     REQUIRE(alg.GetOutputs()[2]->type_name == "edm4hep::SimCalorimeterHit");
0087 }
0088 
0089 TEST_CASE("Configuration object is correctly wired from untyped wiring data") {
0090     JApplication app;
0091     app.AddPlugin("log");
0092     app.Initialize();
0093     JOmniFactoryGeneratorT<BasicTestAlg> facgen (&app);
0094     facgen.AddWiring("ECalTestAlg", {}, {"ECalLeftHits", "ECalRightHits", "ECalVecHits"}, {{"threshold", "6.1"}, {"bucket_count", "22"}});
0095 
0096     JFactorySet facset;
0097     facgen.GenerateFactories(&facset);
0098     // for (auto* fac : facset.GetAllFactories()) {
0099         // std::cout << "typename=" << fac->GetFactoryName() << ", tag=" << fac->GetTag() << std::endl;
0100     // }
0101 
0102     auto basictestalg = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "ECalLeftHits");
0103 
0104     REQUIRE(basictestalg->threshold() == 6.1);
0105     REQUIRE(basictestalg->bucket_count() == 22);
0106 
0107     REQUIRE(basictestalg->config().threshold == 6.1);
0108     REQUIRE(basictestalg->config().bucket_count == 22);
0109 
0110     REQUIRE(basictestalg->m_init_call_count == 0);
0111 }
0112 
0113 TEST_CASE("Multiple configuration objects are correctly wired from untyped wiring data") {
0114     JApplication app;
0115     app.AddPlugin("log");
0116     app.Initialize();
0117     JOmniFactoryGeneratorT<BasicTestAlg> facgen (&app);
0118     facgen.AddWiring("BCalTestAlg", {}, {"BCalLeftHits", "BCalRightHits", "BCalVecHits"}, {{"threshold", "6.1"}, {"bucket_count", "22"}});
0119     facgen.AddWiring("CCalTestAlg", {}, {"CCalLeftHits", "CCalRightHits", "CCalVecHits"}, {{"threshold", "9.0"}, {"bucket_count", "27"}});
0120     facgen.AddWiring("ECalTestAlg", {}, {"ECalLeftHits", "ECalRightHits", "ECalVecHits"}, {{"threshold", "16.25"}, {"bucket_count", "49"}});
0121 
0122     JFactorySet facset;
0123     facgen.GenerateFactories(&facset);
0124     // for (auto* fac : facset.GetAllFactories()) {
0125         // std::cout << "typename=" << fac->GetFactoryName() << ", tag=" << fac->GetTag() << std::endl;
0126     // }
0127     auto b = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "BCalLeftHits");
0128     auto c = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "CCalLeftHits");
0129     auto e = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "ECalLeftHits");
0130 
0131     REQUIRE(b->threshold() == 6.1);
0132     REQUIRE(b->bucket_count() == 22);
0133     REQUIRE(b->config().threshold == 6.1);
0134     REQUIRE(b->config().bucket_count == 22);
0135 
0136     REQUIRE(c->threshold() == 9.0);
0137     REQUIRE(c->bucket_count() == 27);
0138     REQUIRE(c->config().threshold == 9.0);
0139     REQUIRE(c->config().bucket_count == 27);
0140 
0141     REQUIRE(e->threshold() == 16.25);
0142     REQUIRE(e->bucket_count() == 49);
0143     REQUIRE(e->config().threshold == 16.25);
0144     REQUIRE(e->config().bucket_count == 49);
0145 
0146     REQUIRE(b->m_init_call_count == 0);
0147     REQUIRE(c->m_init_call_count == 0);
0148     REQUIRE(e->m_init_call_count == 0);
0149 }
0150 
0151 TEST_CASE("JParameterManager correctly understands which values are defaulted and which are overridden") {
0152     JApplication app;
0153     app.AddPlugin("log");
0154 
0155     auto facgen = new JOmniFactoryGeneratorT<BasicTestAlg>(&app);
0156     facgen->AddWiring("FunTest", {}, {"BCalLeftHits", "BCalRightHits", "BCalVecHits"}, {{"threshold", "6.1"}, {"bucket_count", "22"}});
0157     app.Add(facgen);
0158 
0159     app.GetJParameterManager()->SetParameter("FunTest:threshold", 12.0);
0160     app.Initialize();
0161 
0162     auto event = std::make_shared<JEvent>();
0163     app.GetService<JComponentManager>()->configure_event(*event);
0164 
0165     // for (auto* fac : event->GetFactorySet()->GetAllFactories()) {
0166         // std::cout << "typename=" << fac->GetFactoryName() << ", tag=" << fac->GetTag() << std::endl;
0167     // }
0168 
0169     // Retrieve multifactory
0170     auto b = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(event->GetFactorySet(), "BCalLeftHits");
0171 
0172     // Overrides won't happen until factory gets Init()ed. However, defaults will be applied immediately
0173     REQUIRE(b->threshold() == 6.1);
0174     REQUIRE(b->config().threshold == 6.1);
0175 
0176     // Trigger JMF::Execute(), in order to trigger Init(), in order to Configure()s all Parameter fields...
0177     auto lefthits = event->Get<edm4hep::SimCalorimeterHit>("BCalLeftHits");
0178 
0179     REQUIRE(b->threshold() == 12.0);
0180     REQUIRE(b->config().threshold == 12.0);
0181 
0182     std::cout << "Showing the full table of config parameters" << std::endl;
0183     app.GetJParameterManager()->PrintParameters(true, false, true);
0184 
0185     std::cout << "Showing only overridden config parameters" << std::endl;
0186     app.GetJParameterManager()->PrintParameters(false, false, true);
0187 }
0188 
0189 TEST_CASE("Wiring itself is correctly defaulted") {
0190     JApplication app;
0191     app.AddPlugin("log");
0192 
0193     auto facgen = new JOmniFactoryGeneratorT<BasicTestAlg>(&app);
0194     facgen->AddWiring("FunTest", {}, {"BCalLeftHits", "BCalRightHits", "BCalVecHits"}, {{"threshold", "6.1"}});
0195     app.Add(facgen);
0196     app.Initialize();
0197 
0198     auto event = std::make_shared<JEvent>();
0199     app.GetService<JComponentManager>()->configure_event(*event);
0200 
0201     // Retrieve multifactory
0202     auto b = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(event->GetFactorySet(), "BCalLeftHits");
0203 
0204     // Overrides won't happen until factory gets Init()ed. However, defaults will be applied immediately
0205     REQUIRE(b->bucket_count() == 42);      // Not provided by wiring
0206     REQUIRE(b->config().bucket_count == 42);  // Not provided by wiring
0207 
0208     REQUIRE(b->threshold() == 6.1);        // Provided by wiring
0209     REQUIRE(b->config().threshold == 6.1);    // Provided by wiring
0210 
0211     // Trigger JMF::Execute(), in order to trigger Init(), in order to Configure()s all Parameter fields...
0212     auto lefthits = event->Get<edm4hep::SimCalorimeterHit>("BCalLeftHits");
0213 
0214     // We didn't override the config values via the parameter manager, so all of these should be the same
0215     REQUIRE(b->bucket_count() == 42);      // Not provided by wiring
0216     REQUIRE(b->config().bucket_count == 42);  // Not provided by wiring
0217 
0218     REQUIRE(b->threshold() == 6.1);        // Provided by wiring
0219     REQUIRE(b->config().threshold == 6.1);    // Provided by wiring
0220 
0221 
0222     b->logger()->info("Showing the full table of config parameters");
0223     app.GetJParameterManager()->PrintParameters(true, false, true);
0224 
0225     b->logger()->info("Showing only overridden config parameters");
0226     // Should be empty because everything is defaulted
0227     app.GetJParameterManager()->PrintParameters(false, false, true);
0228 }
0229 
0230 struct VariadicTestAlg : public JOmniFactory<VariadicTestAlg, BasicTestAlgConfig> {
0231 
0232     PodioInput<edm4hep::SimCalorimeterHit> m_hits_in {this};
0233     VariadicPodioInput<edm4hep::SimCalorimeterHit> m_variadic_hits_in {this};
0234     PodioOutput<edm4hep::SimCalorimeterHit> m_hits_out {this};
0235 
0236     std::vector<OutputBase*> GetOutputs() { return this->m_outputs; }
0237 
0238     int m_init_call_count = 0;
0239     int m_changerun_call_count = 0;
0240     int m_process_call_count = 0;
0241 
0242     void Configure() {
0243         m_init_call_count++;
0244         logger()->info("Calling VariadicTestAlg::Configure");
0245     }
0246 
0247     void ChangeRun(int64_t run_number) {
0248         m_changerun_call_count++;
0249         logger()->info("Calling VariadicTestAlg::ChangeRun");
0250     }
0251 
0252     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0253     void Process(int64_t run_number, uint64_t event_number) {
0254         m_process_call_count++;
0255         logger()->info("Calling VariadicTestAlg::Process with bucket_count={}, threshold={}", config().bucket_count, config().threshold);
0256 
0257         REQUIRE(m_hits_in()->size() == 3);
0258         REQUIRE(m_variadic_hits_in().size() == 2);
0259         REQUIRE(m_variadic_hits_in()[0]->size() == 1);
0260         REQUIRE(m_variadic_hits_in()[1]->size() == 2);
0261 
0262         m_hits_out() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0263         m_hits_out()->create();
0264         m_hits_out()->create();
0265         m_hits_out()->create();
0266         m_hits_out()->create();
0267     }
0268 };
0269 
0270 TEST_CASE("VariadicOmniFactoryTests") {
0271     VariadicTestAlg alg;
0272     JApplication app;
0273     app.AddPlugin("log");
0274 
0275     auto facgen = new JOmniFactoryGeneratorT<VariadicTestAlg>("VariadicTest", {"main_hits","fun_hits","funner_hits"}, {"processed_hits"}, &app);
0276     app.Add(facgen);
0277     app.Initialize();
0278 
0279     auto event = std::make_shared<JEvent>();
0280     app.GetService<JComponentManager>()->configure_event(*event);
0281 
0282     edm4hep::SimCalorimeterHitCollection mains;
0283     edm4hep::SimCalorimeterHitCollection funs;
0284     edm4hep::SimCalorimeterHitCollection funners;
0285 
0286     mains.create();
0287     mains.create();
0288     mains.create();
0289 
0290     funs.create();
0291     funners.create();
0292     funners.create();
0293 
0294     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(mains), "main_hits");
0295     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(funs), "fun_hits");
0296     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(funners), "funner_hits");
0297 
0298     auto processed = event->GetCollection<edm4hep::SimCalorimeterHit>("processed_hits");
0299     REQUIRE(processed->size() == 4);
0300 }
0301 
0302 struct SubsetTestAlg : public JOmniFactory<SubsetTestAlg, BasicTestAlgConfig> {
0303 
0304     VariadicPodioInput<edm4hep::SimCalorimeterHit> m_left_hits_in {this};
0305     PodioInput<edm4hep::SimCalorimeterHit> m_center_hits_in {this};
0306     VariadicPodioInput<edm4hep::SimCalorimeterHit> m_right_hits_in {this};
0307     PodioOutput<edm4hep::SimCalorimeterHit> m_hits_out {this};
0308 
0309     std::vector<OutputBase*> GetOutputs() { return this->m_outputs; }
0310 
0311     void Configure() {}
0312 
0313     void ChangeRun(int64_t run_number) {}
0314 
0315     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0316     void Process(int64_t run_number, uint64_t event_number) {
0317 
0318         // Variadic collection count constrained to be same size
0319         REQUIRE(m_left_hits_in().size() == 1);
0320         REQUIRE(m_right_hits_in().size() == 1);
0321 
0322         REQUIRE(m_left_hits_in()[0]->size() == 2);
0323         REQUIRE(m_right_hits_in()[0]->size() == 1);
0324 
0325         REQUIRE(m_center_hits_in()->size() == 3);
0326 
0327         m_hits_out() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0328         m_hits_out()->setSubsetCollection();
0329 
0330         auto* lhi = m_left_hits_in()[0];
0331         for (const auto& hit : *lhi) {
0332             m_hits_out()->push_back(hit);
0333         }
0334         for (const auto& hit : *m_center_hits_in()) {
0335             m_hits_out()->push_back(hit);
0336         }
0337     }
0338 };
0339 
0340 
0341 TEST_CASE("SubsetOmniFactoryTests") {
0342     JApplication app;
0343     app.AddPlugin("log");
0344 
0345     auto facgen = new JOmniFactoryGeneratorT<SubsetTestAlg>("SubsetTest", {"left","center","right"}, {"processed_hits"}, &app);
0346     app.Add(facgen);
0347     app.Initialize();
0348 
0349     auto event = std::make_shared<JEvent>();
0350     app.GetService<JComponentManager>()->configure_event(*event);
0351 
0352     edm4hep::SimCalorimeterHitCollection left;
0353     edm4hep::SimCalorimeterHitCollection center;
0354     edm4hep::SimCalorimeterHitCollection right;
0355 
0356     left.create();
0357     left.create();
0358     right.create();
0359 
0360     center.create();
0361     center.create();
0362     center.create();
0363 
0364     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(left), "left");
0365     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(center), "center");
0366     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(right), "right");
0367 
0368     auto processed = event->GetCollection<edm4hep::SimCalorimeterHit>("processed_hits");
0369     REQUIRE(processed->size() == 5);
0370 }
0371 
0372 struct VariadicOutputTestAlg : public JOmniFactory<VariadicOutputTestAlg, BasicTestAlgConfig> {
0373 
0374     PodioInput<edm4hep::SimCalorimeterHit> m_hits_in {this};
0375 
0376     VariadicPodioOutput<edm4hep::SimCalorimeterHit> m_hits_out {this};
0377 
0378     void Configure() {}
0379     void ChangeRun(int64_t run_number) {}
0380 
0381     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0382     void Process(int64_t run_number, uint64_t event_number) {
0383 
0384         REQUIRE(m_hits_out().size() == 2);
0385         m_hits_out()[0]->setSubsetCollection();
0386         m_hits_out()[1]->setSubsetCollection();
0387 
0388         int i = 0;
0389         for (const auto& hit : *(m_hits_in())) {
0390             m_hits_out()[i]->push_back(hit);
0391             i = (i == 1) ? 0 : 1;
0392         }
0393     }
0394 };
0395 
0396 
0397 
0398 TEST_CASE("VariadicPodioOutputTests") {
0399     JApplication app;
0400     app.AddPlugin("log");
0401 
0402     auto facgen = new JOmniFactoryGeneratorT<VariadicOutputTestAlg>("VariadicOutputTest", {"all_hits"}, {"left_hits", "right_hits"}, &app);
0403     app.Add(facgen);
0404     app.Initialize();
0405 
0406     auto event = std::make_shared<JEvent>();
0407     app.GetService<JComponentManager>()->configure_event(*event);
0408 
0409     edm4hep::SimCalorimeterHitCollection all_hits;
0410 
0411     all_hits.create();
0412     all_hits.create();
0413     all_hits.create();
0414 
0415     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(all_hits), "all_hits");
0416 
0417     auto left_hits = event->GetCollection<edm4hep::SimCalorimeterHit>("left_hits");
0418     auto right_hits = event->GetCollection<edm4hep::SimCalorimeterHit>("right_hits");
0419     REQUIRE(left_hits->size() == 2);
0420     REQUIRE(right_hits->size() == 1);
0421 }