Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-09-27 07:03:09

0001 
0002 #include <JANA/JApplication.h>
0003 #include <JANA/JEvent.h>
0004 #include <JANA/JFactorySet.h>
0005 #include <JANA/JMultifactory.h>
0006 #include <JANA/Services/JComponentManager.h>
0007 #include <JANA/Services/JParameterManager.h>
0008 #include <JANA/Utils/JTypeInfo.h>
0009 #include <catch2/catch_test_macros.hpp>
0010 #include <edm4hep/SimCalorimeterHitCollection.h>
0011 #include <fmt/core.h>
0012 #include <spdlog/logger.h>
0013 #include <stdint.h>
0014 #include <iostream>
0015 #include <map>
0016 #include <memory>
0017 #include <string>
0018 #include <utility>
0019 #include <vector>
0020 
0021 #include "extensions/jana/JOmniFactory.h"
0022 #include "extensions/jana/JOmniFactoryGeneratorT.h"
0023 
/// Plain configuration aggregate shared by every test algorithm in this file.
/// The values below are the defaults the tests expect to see whenever neither
/// the wiring nor the parameter manager supplies a value.
struct BasicTestAlgConfig {
    int bucket_count{42};
    double threshold{7.6};
};
0028 
0029 struct BasicTestAlg : public JOmniFactory<BasicTestAlg, BasicTestAlgConfig> {
0030 
0031     PodioOutput<edm4hep::SimCalorimeterHit> output_hits_left {this, "output_hits_left"};
0032     PodioOutput<edm4hep::SimCalorimeterHit> output_hits_right {this, "output_hits_right"};
0033     Output<edm4hep::SimCalorimeterHit> output_vechits {this, "output_vechits"};
0034 
0035     ParameterRef<int> bucket_count {this, "bucket_count", config().bucket_count, "The total number of buckets [dimensionless]"};
0036     ParameterRef<double> threshold {this, "threshold", config().threshold, "The max cutoff threshold [V * A * kg^-1 * m^-2 * sec^-3]"};
0037 
0038     std::vector<OutputBase*> GetOutputs() { return this->m_outputs; }
0039 
0040     int m_init_call_count = 0;
0041     int m_changerun_call_count = 0;
0042     int m_process_call_count = 0;
0043 
0044     void Configure() {
0045         m_init_call_count++;
0046         logger()->info("Calling BasicTestAlg::Configure");
0047     }
0048 
0049     void ChangeRun(int64_t run_number) {
0050         m_changerun_call_count++;
0051         logger()->info("Calling BasicTestAlg::ChangeRun");
0052     }
0053 
0054     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0055     void Process(int64_t run_number, uint64_t event_number) {
0056         m_process_call_count++;
0057         logger()->info("Calling BasicTestAlg::Process with bucket_count={}, threshold={}", config().bucket_count, config().threshold);
0058         // Provide empty collections (as opposed to nulls) so that PODIO doesn't crash
0059         output_hits_left() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0060         output_hits_right() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0061         output_vechits().push_back(new edm4hep::SimCalorimeterHit());
0062     }
0063 };
0064 
0065 template <typename OutputCollectionT, typename MultifactoryT>
0066 MultifactoryT* RetrieveMultifactory(JFactorySet* facset, std::string output_collection_name) {
0067     auto fac = facset->GetFactory(JTypeInfo::demangle<OutputCollectionT>(), output_collection_name);
0068     REQUIRE(fac != nullptr);
0069     auto helper = dynamic_cast<JMultifactoryHelperPodio<OutputCollectionT>*>(fac);
0070     REQUIRE(helper != nullptr);
0071     auto multifactory = helper->GetMultifactory();
0072     REQUIRE(multifactory != nullptr);
0073     auto typed = dynamic_cast<MultifactoryT*>(multifactory);
0074     REQUIRE(typed != nullptr);
0075     return typed;
0076 }
0077 
0078 TEST_CASE("Registering Podio outputs works") {
0079     BasicTestAlg alg;
0080     REQUIRE(alg.GetOutputs().size() == 3);
0081     REQUIRE(alg.GetOutputs()[0]->collection_names[0] == "output_hits_left");
0082     REQUIRE(alg.GetOutputs()[0]->type_name == "edm4hep::SimCalorimeterHit");
0083     REQUIRE(alg.GetOutputs()[1]->collection_names[0] == "output_hits_right");
0084     REQUIRE(alg.GetOutputs()[1]->type_name == "edm4hep::SimCalorimeterHit");
0085     REQUIRE(alg.GetOutputs()[2]->collection_names[0] == "output_vechits");
0086     REQUIRE(alg.GetOutputs()[2]->type_name == "edm4hep::SimCalorimeterHit");
0087 }
0088 
0089 TEST_CASE("Configuration object is correctly wired from untyped wiring data") {
0090     JApplication app;
0091     app.AddPlugin("log");
0092     app.Initialize();
0093     JOmniFactoryGeneratorT<BasicTestAlg> facgen (&app);
0094     facgen.AddWiring("ECalTestAlg", {}, {"ECalLeftHits", "ECalRightHits", "ECalVecHits"}, {{"threshold", "6.1"}, {"bucket_count", "22"}});
0095 
0096     JFactorySet facset;
0097     facgen.GenerateFactories(&facset);
0098     // for (auto* fac : facset.GetAllFactories()) {
0099         // std::cout << "typename=" << fac->GetFactoryName() << ", tag=" << fac->GetTag() << std::endl;
0100     // }
0101 
0102     auto basictestalg = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "ECalLeftHits");
0103 
0104     REQUIRE(basictestalg->threshold() == 6.1);
0105     REQUIRE(basictestalg->bucket_count() == 22);
0106 
0107     REQUIRE(basictestalg->config().threshold == 6.1);
0108     REQUIRE(basictestalg->config().bucket_count == 22);
0109 
0110     REQUIRE(basictestalg->m_init_call_count == 0);
0111 }
0112 
0113 TEST_CASE("Multiple configuration objects are correctly wired from untyped wiring data") {
0114     JApplication app;
0115     app.AddPlugin("log");
0116     app.Initialize();
0117     JOmniFactoryGeneratorT<BasicTestAlg> facgen (&app);
0118     facgen.AddWiring("BCalTestAlg", {}, {"BCalLeftHits", "BCalRightHits", "BCalVecHits"}, {{"threshold", "6.1"}, {"bucket_count", "22"}});
0119     facgen.AddWiring("CCalTestAlg", {}, {"CCalLeftHits", "CCalRightHits", "CCalVecHits"}, {{"threshold", "9.0"}, {"bucket_count", "27"}});
0120     facgen.AddWiring("ECalTestAlg", {}, {"ECalLeftHits", "ECalRightHits", "ECalVecHits"}, {{"threshold", "16.25"}, {"bucket_count", "49"}});
0121 
0122     JFactorySet facset;
0123     facgen.GenerateFactories(&facset);
0124     // for (auto* fac : facset.GetAllFactories()) {
0125         // std::cout << "typename=" << fac->GetFactoryName() << ", tag=" << fac->GetTag() << std::endl;
0126     // }
0127     auto b = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "BCalLeftHits");
0128     auto c = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "CCalLeftHits");
0129     auto e = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(&facset, "ECalLeftHits");
0130 
0131     REQUIRE(b->threshold() == 6.1);
0132     REQUIRE(b->bucket_count() == 22);
0133     REQUIRE(b->config().threshold == 6.1);
0134     REQUIRE(b->config().bucket_count == 22);
0135 
0136     REQUIRE(c->threshold() == 9.0);
0137     REQUIRE(c->bucket_count() == 27);
0138     REQUIRE(c->config().threshold == 9.0);
0139     REQUIRE(c->config().bucket_count == 27);
0140 
0141     REQUIRE(e->threshold() == 16.25);
0142     REQUIRE(e->bucket_count() == 49);
0143     REQUIRE(e->config().threshold == 16.25);
0144     REQUIRE(e->config().bucket_count == 49);
0145 
0146     REQUIRE(b->m_init_call_count == 0);
0147     REQUIRE(c->m_init_call_count == 0);
0148     REQUIRE(e->m_init_call_count == 0);
0149 }
0150 
0151 TEST_CASE("JParameterManager correctly understands which values are defaulted and which are overridden") {
0152     JApplication app;
0153     app.AddPlugin("log");
0154 
0155     auto facgen = new JOmniFactoryGeneratorT<BasicTestAlg>(&app);
0156     facgen->AddWiring("FunTest", {}, {"BCalLeftHits", "BCalRightHits", "BCalVecHits"}, {{"threshold", "6.1"}, {"bucket_count", "22"}});
0157     app.Add(facgen);
0158 
0159     app.GetJParameterManager()->SetParameter("FunTest:threshold", 12.0);
0160     app.Initialize();
0161 
0162     auto event = std::make_shared<JEvent>();
0163     app.GetService<JComponentManager>()->configure_event(*event);
0164 
0165     // for (auto* fac : event->GetFactorySet()->GetAllFactories()) {
0166         // std::cout << "typename=" << fac->GetFactoryName() << ", tag=" << fac->GetTag() << std::endl;
0167     // }
0168 
0169     // Retrieve multifactory
0170     auto b = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(event->GetFactorySet(), "BCalLeftHits");
0171 
0172     // Overrides won't happen until factory gets Init()ed. However, defaults will be applied immediately
0173     REQUIRE(b->threshold() == 6.1);
0174     REQUIRE(b->config().threshold == 6.1);
0175 
0176     // Trigger JMF::Execute(), in order to trigger Init(), in order to Configure()s all Parameter fields...
0177     auto lefthits = event->Get<edm4hep::SimCalorimeterHit>("BCalLeftHits");
0178 
0179     REQUIRE(b->threshold() == 12.0);
0180     REQUIRE(b->config().threshold == 12.0);
0181 
0182     std::cout << "Showing the full table of config parameters" << std::endl;
0183     app.GetJParameterManager()->PrintParameters(2, 1); // verbosity, strictness
0184 
0185     std::cout << "Showing only overridden config parameters" << std::endl;
0186     app.GetJParameterManager()->PrintParameters(1, 1); // verbosity, strictness
0187 
0188 }
0189 
0190 TEST_CASE("Wiring itself is correctly defaulted") {
0191     JApplication app;
0192     app.AddPlugin("log");
0193 
0194     auto facgen = new JOmniFactoryGeneratorT<BasicTestAlg>(&app);
0195     facgen->AddWiring("FunTest", {}, {"BCalLeftHits", "BCalRightHits", "BCalVecHits"}, {{"threshold", "6.1"}});
0196     app.Add(facgen);
0197     app.Initialize();
0198 
0199     auto event = std::make_shared<JEvent>();
0200     app.GetService<JComponentManager>()->configure_event(*event);
0201 
0202     // Retrieve multifactory
0203     auto b = RetrieveMultifactory<edm4hep::SimCalorimeterHit,BasicTestAlg>(event->GetFactorySet(), "BCalLeftHits");
0204 
0205     // Overrides won't happen until factory gets Init()ed. However, defaults will be applied immediately
0206     REQUIRE(b->bucket_count() == 42);      // Not provided by wiring
0207     REQUIRE(b->config().bucket_count == 42);  // Not provided by wiring
0208 
0209     REQUIRE(b->threshold() == 6.1);        // Provided by wiring
0210     REQUIRE(b->config().threshold == 6.1);    // Provided by wiring
0211 
0212     // Trigger JMF::Execute(), in order to trigger Init(), in order to Configure()s all Parameter fields...
0213     auto lefthits = event->Get<edm4hep::SimCalorimeterHit>("BCalLeftHits");
0214 
0215     // We didn't override the config values via the parameter manager, so all of these should be the same
0216     REQUIRE(b->bucket_count() == 42);      // Not provided by wiring
0217     REQUIRE(b->config().bucket_count == 42);  // Not provided by wiring
0218 
0219     REQUIRE(b->threshold() == 6.1);        // Provided by wiring
0220     REQUIRE(b->config().threshold == 6.1);    // Provided by wiring
0221 
0222 
0223     b->logger()->info("Showing the full table of config parameters");
0224     app.GetJParameterManager()->PrintParameters(2,1); // verbosity, strictness
0225 
0226     b->logger()->info("Showing only overridden config parameters");
0227     // Should be empty because everything is defaulted
0228     app.GetJParameterManager()->PrintParameters(1,1); // verbosity, strictness
0229 }
0230 
0231 struct VariadicTestAlg : public JOmniFactory<VariadicTestAlg, BasicTestAlgConfig> {
0232 
0233     PodioInput<edm4hep::SimCalorimeterHit> m_hits_in {this};
0234     VariadicPodioInput<edm4hep::SimCalorimeterHit> m_variadic_hits_in {this};
0235     PodioOutput<edm4hep::SimCalorimeterHit> m_hits_out {this};
0236 
0237     std::vector<OutputBase*> GetOutputs() { return this->m_outputs; }
0238 
0239     int m_init_call_count = 0;
0240     int m_changerun_call_count = 0;
0241     int m_process_call_count = 0;
0242 
0243     void Configure() {
0244         m_init_call_count++;
0245         logger()->info("Calling VariadicTestAlg::Configure");
0246     }
0247 
0248     void ChangeRun(int64_t run_number) {
0249         m_changerun_call_count++;
0250         logger()->info("Calling VariadicTestAlg::ChangeRun");
0251     }
0252 
0253     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0254     void Process(int64_t run_number, uint64_t event_number) {
0255         m_process_call_count++;
0256         logger()->info("Calling VariadicTestAlg::Process with bucket_count={}, threshold={}", config().bucket_count, config().threshold);
0257 
0258         REQUIRE(m_hits_in()->size() == 3);
0259         REQUIRE(m_variadic_hits_in().size() == 2);
0260         REQUIRE(m_variadic_hits_in()[0]->size() == 1);
0261         REQUIRE(m_variadic_hits_in()[1]->size() == 2);
0262 
0263         m_hits_out() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0264         m_hits_out()->create();
0265         m_hits_out()->create();
0266         m_hits_out()->create();
0267         m_hits_out()->create();
0268     }
0269 };
0270 
0271 TEST_CASE("VariadicOmniFactoryTests") {
0272     VariadicTestAlg alg;
0273     JApplication app;
0274     app.AddPlugin("log");
0275 
0276     auto facgen = new JOmniFactoryGeneratorT<VariadicTestAlg>("VariadicTest", {"main_hits","fun_hits","funner_hits"}, {"processed_hits"}, &app);
0277     app.Add(facgen);
0278     app.Initialize();
0279 
0280     auto event = std::make_shared<JEvent>();
0281     app.GetService<JComponentManager>()->configure_event(*event);
0282 
0283     edm4hep::SimCalorimeterHitCollection mains;
0284     edm4hep::SimCalorimeterHitCollection funs;
0285     edm4hep::SimCalorimeterHitCollection funners;
0286 
0287     mains.create();
0288     mains.create();
0289     mains.create();
0290 
0291     funs.create();
0292     funners.create();
0293     funners.create();
0294 
0295     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(mains), "main_hits");
0296     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(funs), "fun_hits");
0297     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(funners), "funner_hits");
0298 
0299     auto processed = event->GetCollection<edm4hep::SimCalorimeterHit>("processed_hits");
0300     REQUIRE(processed->size() == 4);
0301 }
0302 
0303 struct SubsetTestAlg : public JOmniFactory<SubsetTestAlg, BasicTestAlgConfig> {
0304 
0305     VariadicPodioInput<edm4hep::SimCalorimeterHit> m_left_hits_in {this};
0306     PodioInput<edm4hep::SimCalorimeterHit> m_center_hits_in {this};
0307     VariadicPodioInput<edm4hep::SimCalorimeterHit> m_right_hits_in {this};
0308     PodioOutput<edm4hep::SimCalorimeterHit> m_hits_out {this};
0309 
0310     std::vector<OutputBase*> GetOutputs() { return this->m_outputs; }
0311 
0312     void Configure() {}
0313 
0314     void ChangeRun(int64_t run_number) {}
0315 
0316     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0317     void Process(int64_t run_number, uint64_t event_number) {
0318 
0319         // Variadic collection count constrained to be same size
0320         REQUIRE(m_left_hits_in().size() == 1);
0321         REQUIRE(m_right_hits_in().size() == 1);
0322 
0323         REQUIRE(m_left_hits_in()[0]->size() == 2);
0324         REQUIRE(m_right_hits_in()[0]->size() == 1);
0325 
0326         REQUIRE(m_center_hits_in()->size() == 3);
0327 
0328         m_hits_out() = std::make_unique<edm4hep::SimCalorimeterHitCollection>();
0329         m_hits_out()->setSubsetCollection();
0330 
0331         auto* lhi = m_left_hits_in()[0];
0332         for (const auto& hit : *lhi) {
0333             m_hits_out()->push_back(hit);
0334         }
0335         for (const auto& hit : *m_center_hits_in()) {
0336             m_hits_out()->push_back(hit);
0337         }
0338     }
0339 };
0340 
0341 
0342 TEST_CASE("SubsetOmniFactoryTests") {
0343     JApplication app;
0344     app.AddPlugin("log");
0345 
0346     auto facgen = new JOmniFactoryGeneratorT<SubsetTestAlg>("SubsetTest", {"left","center","right"}, {"processed_hits"}, &app);
0347     app.Add(facgen);
0348     app.Initialize();
0349 
0350     auto event = std::make_shared<JEvent>();
0351     app.GetService<JComponentManager>()->configure_event(*event);
0352 
0353     edm4hep::SimCalorimeterHitCollection left;
0354     edm4hep::SimCalorimeterHitCollection center;
0355     edm4hep::SimCalorimeterHitCollection right;
0356 
0357     left.create();
0358     left.create();
0359     right.create();
0360 
0361     center.create();
0362     center.create();
0363     center.create();
0364 
0365     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(left), "left");
0366     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(center), "center");
0367     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(right), "right");
0368 
0369     auto processed = event->GetCollection<edm4hep::SimCalorimeterHit>("processed_hits");
0370     REQUIRE(processed->size() == 5);
0371 }
0372 
0373 struct VariadicOutputTestAlg : public JOmniFactory<VariadicOutputTestAlg, BasicTestAlgConfig> {
0374 
0375     PodioInput<edm4hep::SimCalorimeterHit> m_hits_in {this};
0376 
0377     VariadicPodioOutput<edm4hep::SimCalorimeterHit> m_hits_out {this};
0378 
0379     void Configure() {}
0380     void ChangeRun(int64_t run_number) {}
0381 
0382     // NOLINTNEXTLINE(bugprone-easily-swappable-parameters)
0383     void Process(int64_t run_number, uint64_t event_number) {
0384 
0385         REQUIRE(m_hits_out().size() == 2);
0386         m_hits_out()[0]->setSubsetCollection();
0387         m_hits_out()[1]->setSubsetCollection();
0388 
0389         int i = 0;
0390         for (const auto& hit : *(m_hits_in())) {
0391             m_hits_out()[i]->push_back(hit);
0392             i = (i == 1) ? 0 : 1;
0393         }
0394     }
0395 };
0396 
0397 
0398 
0399 TEST_CASE("VariadicPodioOutputTests") {
0400     JApplication app;
0401     app.AddPlugin("log");
0402 
0403     auto facgen = new JOmniFactoryGeneratorT<VariadicOutputTestAlg>("VariadicOutputTest", {"all_hits"}, {"left_hits", "right_hits"}, &app);
0404     app.Add(facgen);
0405     app.Initialize();
0406 
0407     auto event = std::make_shared<JEvent>();
0408     app.GetService<JComponentManager>()->configure_event(*event);
0409 
0410     edm4hep::SimCalorimeterHitCollection all_hits;
0411 
0412     all_hits.create();
0413     all_hits.create();
0414     all_hits.create();
0415 
0416     event->InsertCollection<edm4hep::SimCalorimeterHit>(std::move(all_hits), "all_hits");
0417 
0418     auto left_hits = event->GetCollection<edm4hep::SimCalorimeterHit>("left_hits");
0419     auto right_hits = event->GetCollection<edm4hep::SimCalorimeterHit>("right_hits");
0420     REQUIRE(left_hits->size() == 2);
0421     REQUIRE(right_hits->size() == 1);
0422 }