Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-07 08:56:23

0001 
#define CATCH_CONFIG_MAIN
#include <catch.hpp>

#include <stdexcept>
#include <string>

#include <JANA/JFactory.h>
#include <JANA/JEventSource.h>
#include <JANA/JEventProcessor.h>
#include <JANA/JFactoryGenerator.h>
#include <JANA/Topology/JTopologyBuilder.h>
#include <JANA/Topology/JSourceArrow.h>
#include <JANA/Topology/JMapArrow.h>
#include <JANA/Topology/JTapArrow.h>
0012 
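// This integration test exercises a GPU-style offload end to end.
// The custom topology below runs factory B on a dedicated, non-parallel
// "offload" arrow (standing in for the GPU), after a "trigger" arrow has
// created B's inputs. The resulting factory chain is:
// A (cpu) -> B (gpu) -> C (cpu)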
0016 
// Minimal payload types passed along the factory chain A -> B -> C.
struct A {
    int a;
};

struct B {
    int b;
};

struct C {
    int c;
};
0020 
0021 struct AFac : public JFactory {
0022   Output<A> a_out {this};
0023   void Process(const JEvent& event) override {
0024     LOG_INFO(GetLogger()) << "Running AFac (hopefully on CPU)" << LOG_END;
0025     A* a = new A;
0026     a->a = event.GetEventNumber() + 1;
0027     a_out().push_back(a);
0028   }
0029 };
0030 
0031 struct BFac : public JFactory {
0032   Input<A> a_in {this};
0033   Output<B> b_out {this};
0034   void Process(const JEvent&) override {
0035     LOG_INFO(GetLogger()) << "Running BFac (hopefully on GPU)" << LOG_END;
0036     auto* a = a_in->at(0);
0037     B* b = new B;
0038     b->b = a->a * 2;
0039     b_out().push_back(b);
0040   }
0041 };
0042 
0043 struct CFac : public JFactory {
0044   Input<B> b_in {this};
0045   Output<C> c_out {this};
0046   void Process(const JEvent&) override {
0047     LOG_INFO(GetLogger()) << "Running CFac (hopefully on CPU)" << LOG_END;
0048     auto* b = b_in->at(0);
0049     C* c = new C;
0050     c->c = b->b + 4;
0051     c_out().push_back(c);
0052   }
0053 };
0054 
0055 struct Proc : public JEventProcessor {
0056   Input<C> c_in {this};
0057   Proc() {
0058     SetCallbackStyle(JFactory::CallbackStyle::ExpertMode);
0059   }
0060   void ProcessSequential(const JEvent& event) override {
0061     LOG_INFO(GetLogger()) << "Retrieving C (hopefully on CPU)" << LOG_END;
0062     auto* c = c_in->at(0);
0063     auto evtnr = event.GetEventNumber();
0064     int expected = ((evtnr + 1) * 2) + 4;
0065     LOG_INFO(GetLogger()) << "Evt nr " << evtnr << ": " << "Expected " << expected << ", found " << c->c << std::endl;
0066     REQUIRE(expected == c->c);
0067   }
0068 };
0069 
0070 
0071 struct TriggerFactoryInputsArrow : public JArrow {
0072     std::string unique_name;
0073 
0074     TriggerFactoryInputsArrow(JEventLevel level) {
0075         SetName("trigger");
0076         SetIsParallel(true);
0077         AddPort("in", level);
0078         AddPort("out", level);
0079     }
0080 
0081     void Fire(JEvent* event, OutputData& outputs, size_t& output_count, JArrow::FireResult& status) override {
0082         auto* fac = event->GetFactorySet()->GetDatabundle(unique_name)->GetFactory();
0083         for (auto* input : fac->GetInputs()) {
0084             input->TriggerFactoryCreate(*event);
0085         }
0086         for (auto* input : fac->GetVariadicInputs()) {
0087             input->TriggerFactoryCreate(*event);
0088         }
0089         LOG_DEBUG(m_logger) << "Executed arrow " << GetName() << " for event# " << event->GetEventNumber() << LOG_END;
0090         outputs[0] = {event, 1};
0091         output_count = 1;
0092         status = JArrow::FireResult::KeepGoing;
0093     }
0094 };
0095 
0096 struct OffloadArrow : public JArrow {
0097     std::string unique_name;
0098     OffloadArrow(JEventLevel level) {
0099         SetName("offload");
0100         SetIsParallel(false);
0101         AddPort("in", level);
0102         AddPort("out", level);
0103     }
0104 
0105     ~OffloadArrow() override {}
0106 
0107     void Fire(JEvent* event, OutputData& outputs, size_t& output_count, JArrow::FireResult& status) override {
0108 
0109         event->GetFactorySet()->GetDatabundle(unique_name)->GetFactory()->Create(*event);
0110 
0111         LOG_DEBUG(m_logger) << "Executed arrow " << GetName() << " for event# " << event->GetEventNumber() << LOG_END;
0112         outputs[0] = {event, 1};
0113         output_count = 1;
0114         status = JArrow::FireResult::KeepGoing;
0115     }
0116 };
0117 
0118 
0119 void configure_topology(JTopologyBuilder& builder, JComponentManager& components) {
0120 
0121     auto* src_arrow = new JSourceArrow("src", JEventLevel::PhysicsEvent, components.get_evt_srces());
0122 
0123     TriggerFactoryInputsArrow* trigger_inputs_arrow = new TriggerFactoryInputsArrow(JEventLevel::PhysicsEvent);
0124     trigger_inputs_arrow->unique_name = "B";
0125 
0126     OffloadArrow* offload_arrow = new OffloadArrow(JEventLevel::PhysicsEvent);
0127     offload_arrow->unique_name = "B";
0128 
0129     JMapArrow* map_arrow = new JMapArrow("map", JEventLevel::PhysicsEvent);
0130     for (auto proc : components.get_evt_procs()) {
0131         map_arrow->AddProcessor(proc);
0132     }
0133 
0134     JTapArrow* tap_arrow = new JTapArrow("tap", JEventLevel::PhysicsEvent);
0135     for (auto proc : components.get_evt_procs()) {
0136         tap_arrow->AddProcessor(proc);
0137     }
0138 
0139     builder.AddArrow(src_arrow);
0140     builder.AddArrow(trigger_inputs_arrow);
0141     builder.AddArrow(offload_arrow);
0142     builder.AddArrow(map_arrow);
0143     builder.AddArrow(tap_arrow);
0144 
0145     builder.ConnectPool("src", "in", JEventLevel::PhysicsEvent);
0146     builder.ConnectPool("tap", "out", JEventLevel::PhysicsEvent);
0147 
0148     builder.ConnectQueue("src", "out", "trigger", "in");
0149     builder.ConnectQueue("trigger", "out", "offload", "in");
0150     builder.ConnectQueue("offload", "out", "map", "in");
0151     builder.ConnectQueue("map", "out", "tap", "in");
0152 }
0153 
0154 
0155 TEST_CASE("SimpleOffloading") {
0156   JApplication app;
0157   app.Add(new JFactoryGeneratorT<AFac>());
0158   app.Add(new JFactoryGeneratorT<BFac>());
0159   app.Add(new JFactoryGeneratorT<CFac>());
0160   app.Add(new JEventSource);
0161   app.Add(new Proc);
0162   app.SetParameterValue("jana:nevents", 3);
0163   app.SetParameterValue("nthreads", 2);
0164   app.SetParameterValue("jana:log:show_threadstamp", 1);
0165   app.SetParameterValue("jana:loglevel", "DEBUG");
0166 
0167   auto builder = app.GetService<JTopologyBuilder>();
0168   builder->SetConfigureFn(configure_topology);
0169   app.Run();
0170 }
0171 
0172 
0173