File indexing completed on 2026-05-07 08:56:23
0001
#include <stdexcept>
#include <string>

#define CATCH_CONFIG_MAIN
#include <catch.hpp>
#include <JANA/JFactory.h>
#include <JANA/JEventSource.h>
#include <JANA/JEventProcessor.h>
#include <JANA/JFactoryGenerator.h>
#include <JANA/Topology/JTopologyBuilder.h>
#include <JANA/Topology/JSourceArrow.h>
#include <JANA/Topology/JMapArrow.h>
#include <JANA/Topology/JTapArrow.h>
0012
0013
0014
0015
0016
// Minimal payload types handed from factory to factory in this test
// (AFac -> A, BFac -> B, CFac -> C).
// Members carry default member initializers so that a freshly `new`-ed
// object never holds an indeterminate value before its producing factory
// assigns it. The types remain aggregates, so `A{5}`-style init still works.
struct A { int a{}; };
struct B { int b{}; };
struct C { int c{}; };
0020
0021 struct AFac : public JFactory {
0022 Output<A> a_out {this};
0023 void Process(const JEvent& event) override {
0024 LOG_INFO(GetLogger()) << "Running AFac (hopefully on CPU)" << LOG_END;
0025 A* a = new A;
0026 a->a = event.GetEventNumber() + 1;
0027 a_out().push_back(a);
0028 }
0029 };
0030
0031 struct BFac : public JFactory {
0032 Input<A> a_in {this};
0033 Output<B> b_out {this};
0034 void Process(const JEvent&) override {
0035 LOG_INFO(GetLogger()) << "Running BFac (hopefully on GPU)" << LOG_END;
0036 auto* a = a_in->at(0);
0037 B* b = new B;
0038 b->b = a->a * 2;
0039 b_out().push_back(b);
0040 }
0041 };
0042
0043 struct CFac : public JFactory {
0044 Input<B> b_in {this};
0045 Output<C> c_out {this};
0046 void Process(const JEvent&) override {
0047 LOG_INFO(GetLogger()) << "Running CFac (hopefully on CPU)" << LOG_END;
0048 auto* b = b_in->at(0);
0049 C* c = new C;
0050 c->c = b->b + 4;
0051 c_out().push_back(c);
0052 }
0053 };
0054
0055 struct Proc : public JEventProcessor {
0056 Input<C> c_in {this};
0057 Proc() {
0058 SetCallbackStyle(JFactory::CallbackStyle::ExpertMode);
0059 }
0060 void ProcessSequential(const JEvent& event) override {
0061 LOG_INFO(GetLogger()) << "Retrieving C (hopefully on CPU)" << LOG_END;
0062 auto* c = c_in->at(0);
0063 auto evtnr = event.GetEventNumber();
0064 int expected = ((evtnr + 1) * 2) + 4;
0065 LOG_INFO(GetLogger()) << "Evt nr " << evtnr << ": " << "Expected " << expected << ", found " << c->c << std::endl;
0066 REQUIRE(expected == c->c);
0067 }
0068 };
0069
0070
0071 struct TriggerFactoryInputsArrow : public JArrow {
0072 std::string unique_name;
0073
0074 TriggerFactoryInputsArrow(JEventLevel level) {
0075 SetName("trigger");
0076 SetIsParallel(true);
0077 AddPort("in", level);
0078 AddPort("out", level);
0079 }
0080
0081 void Fire(JEvent* event, OutputData& outputs, size_t& output_count, JArrow::FireResult& status) override {
0082 auto* fac = event->GetFactorySet()->GetDatabundle(unique_name)->GetFactory();
0083 for (auto* input : fac->GetInputs()) {
0084 input->TriggerFactoryCreate(*event);
0085 }
0086 for (auto* input : fac->GetVariadicInputs()) {
0087 input->TriggerFactoryCreate(*event);
0088 }
0089 LOG_DEBUG(m_logger) << "Executed arrow " << GetName() << " for event# " << event->GetEventNumber() << LOG_END;
0090 outputs[0] = {event, 1};
0091 output_count = 1;
0092 status = JArrow::FireResult::KeepGoing;
0093 }
0094 };
0095
0096 struct OffloadArrow : public JArrow {
0097 std::string unique_name;
0098 OffloadArrow(JEventLevel level) {
0099 SetName("offload");
0100 SetIsParallel(false);
0101 AddPort("in", level);
0102 AddPort("out", level);
0103 }
0104
0105 ~OffloadArrow() override {}
0106
0107 void Fire(JEvent* event, OutputData& outputs, size_t& output_count, JArrow::FireResult& status) override {
0108
0109 event->GetFactorySet()->GetDatabundle(unique_name)->GetFactory()->Create(*event);
0110
0111 LOG_DEBUG(m_logger) << "Executed arrow " << GetName() << " for event# " << event->GetEventNumber() << LOG_END;
0112 outputs[0] = {event, 1};
0113 output_count = 1;
0114 status = JArrow::FireResult::KeepGoing;
0115 }
0116 };
0117
0118
0119 void configure_topology(JTopologyBuilder& builder, JComponentManager& components) {
0120
0121 auto* src_arrow = new JSourceArrow("src", JEventLevel::PhysicsEvent, components.get_evt_srces());
0122
0123 TriggerFactoryInputsArrow* trigger_inputs_arrow = new TriggerFactoryInputsArrow(JEventLevel::PhysicsEvent);
0124 trigger_inputs_arrow->unique_name = "B";
0125
0126 OffloadArrow* offload_arrow = new OffloadArrow(JEventLevel::PhysicsEvent);
0127 offload_arrow->unique_name = "B";
0128
0129 JMapArrow* map_arrow = new JMapArrow("map", JEventLevel::PhysicsEvent);
0130 for (auto proc : components.get_evt_procs()) {
0131 map_arrow->AddProcessor(proc);
0132 }
0133
0134 JTapArrow* tap_arrow = new JTapArrow("tap", JEventLevel::PhysicsEvent);
0135 for (auto proc : components.get_evt_procs()) {
0136 tap_arrow->AddProcessor(proc);
0137 }
0138
0139 builder.AddArrow(src_arrow);
0140 builder.AddArrow(trigger_inputs_arrow);
0141 builder.AddArrow(offload_arrow);
0142 builder.AddArrow(map_arrow);
0143 builder.AddArrow(tap_arrow);
0144
0145 builder.ConnectPool("src", "in", JEventLevel::PhysicsEvent);
0146 builder.ConnectPool("tap", "out", JEventLevel::PhysicsEvent);
0147
0148 builder.ConnectQueue("src", "out", "trigger", "in");
0149 builder.ConnectQueue("trigger", "out", "offload", "in");
0150 builder.ConnectQueue("offload", "out", "map", "in");
0151 builder.ConnectQueue("map", "out", "tap", "in");
0152 }
0153
0154
0155 TEST_CASE("SimpleOffloading") {
0156 JApplication app;
0157 app.Add(new JFactoryGeneratorT<AFac>());
0158 app.Add(new JFactoryGeneratorT<BFac>());
0159 app.Add(new JFactoryGeneratorT<CFac>());
0160 app.Add(new JEventSource);
0161 app.Add(new Proc);
0162 app.SetParameterValue("jana:nevents", 3);
0163 app.SetParameterValue("nthreads", 2);
0164 app.SetParameterValue("jana:log:show_threadstamp", 1);
0165 app.SetParameterValue("jana:loglevel", "DEBUG");
0166
0167 auto builder = app.GetService<JTopologyBuilder>();
0168 builder->SetConfigureFn(configure_topology);
0169 app.Run();
0170 }
0171
0172
0173