Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:49:12

#include <algorithm>
#include <array>
#include <cassert>
#include <cctype>
#include <cstring>
#include <filesystem>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "G4BooleanSolid.hh"
#include "G4DisplacedSolid.hh"
0005 #include "G4Event.hh"
0006 #include "G4GDMLParser.hh"
0007 #include "G4LogicalVolumeStore.hh"
0008 #include "G4OpBoundaryProcess.hh"
0009 #include "G4OpticalPhoton.hh"
0010 #include "G4PhysicalConstants.hh"
0011 #include "G4PrimaryParticle.hh"
0012 #include "G4PrimaryVertex.hh"
0013 #include "G4SDManager.hh"
0014 #include "G4SubtractionSolid.hh"
0015 #include "G4SystemOfUnits.hh"
0016 #include "G4ThreeVector.hh"
0017 #include "G4Track.hh"
0018 #include "G4TrackStatus.hh"
0019 #include "G4UserEventAction.hh"
0020 #include "G4UserSteppingAction.hh"
0021 #include "G4UserTrackingAction.hh"
0022 #include "G4VPhysicalVolume.hh"
0023 #include "G4VProcess.hh"
0024 #include "G4VUserDetectorConstruction.hh"
0025 #include "G4VUserPrimaryGeneratorAction.hh"
0026 
0027 #include "g4cx/G4CXOpticks.hh"
0028 #include "sysrap/NP.hh"
0029 #include "sysrap/SEvt.hh"
0030 #include "sysrap/STrackInfo.h"
0031 #include "sysrap/spho.h"
0032 #include "sysrap/sphoton.h"
0033 #include "u4/U4Random.hh"
0034 #include "u4/U4StepPoint.hh"
0035 #include "u4/U4Touchable.h"
0036 #include "u4/U4Track.h"
0037 
0038 #include "config.h"
0039 #include "torch.h"
0040 
0041 bool IsSubtractionSolid(G4VSolid *solid)
0042 {
0043     if (!solid)
0044         return false;
0045 
0046     // Check if the solid is directly a G4SubtractionSolid
0047     if (dynamic_cast<G4SubtractionSolid *>(solid))
0048         return true;
0049 
0050     // If the solid is a Boolean solid, check its constituent solids
0051     G4BooleanSolid *booleanSolid = dynamic_cast<G4BooleanSolid *>(solid);
0052     if (booleanSolid)
0053     {
0054         G4VSolid *solidA = booleanSolid->GetConstituentSolid(0);
0055         G4VSolid *solidB = booleanSolid->GetConstituentSolid(1);
0056 
0057         // Recursively check the constituent solids
0058         if (IsSubtractionSolid(solidA) || IsSubtractionSolid(solidB))
0059             return true;
0060     }
0061 
0062     // For other solid types, return false
0063     return false;
0064 }
0065 
/// Returns a lower-cased copy of `s`.
///
/// Each byte goes through std::tolower after conversion to unsigned char, which
/// avoids undefined behaviour for negative `char` values. Locale-dependent, byte-wise;
/// multi-byte (e.g. UTF-8) characters are not handled as units.
std::string str_tolower(std::string s)
{
    for (char &ch : s)
    {
        const unsigned char byte = static_cast<unsigned char>(ch);
        ch = static_cast<char>(std::tolower(byte));
    }
    return s;
}
0071 
0072 struct PhotonHit : public G4VHit
0073 {
0074     PhotonHit() = default;
0075 
0076     PhotonHit(G4double energy, G4double time, G4ThreeVector position, G4ThreeVector direction,
0077               G4ThreeVector polarization)
0078         : photon()
0079     {
0080         photon.pos = {static_cast<float>(position.x()), static_cast<float>(position.y()),
0081                       static_cast<float>(position.z())};
0082         photon.time = time;
0083         photon.mom = {static_cast<float>(direction.x()), static_cast<float>(direction.y()),
0084                       static_cast<float>(direction.z())};
0085         photon.pol = {static_cast<float>(polarization.x()), static_cast<float>(polarization.y()),
0086                       static_cast<float>(polarization.z())};
0087         photon.wavelength = h_Planck * c_light / (energy * CLHEP::eV);
0088     }
0089 
0090     // Print method
0091     void Print() override
0092     {
0093         G4cout << photon << G4endl;
0094     }
0095 
0096     // Member variables
0097     sphoton photon;
0098 };
0099 
0100 using PhotonHitsCollection = G4THitsCollection<PhotonHit>;
0101 
0102 struct PhotonSD : public G4VSensitiveDetector
0103 {
0104     PhotonSD(G4String name) : G4VSensitiveDetector(name), fHCID(-1)
0105     {
0106         G4String HCname = name + "_HC";
0107         collectionName.insert(HCname);
0108         G4cout << collectionName.size() << "   PhotonSD name:  " << name << " collection Name: " << HCname << G4endl;
0109     }
0110 
0111     void Initialize(G4HCofThisEvent *hce) override
0112     {
0113         fPhotonHitsCollection = new PhotonHitsCollection(SensitiveDetectorName, collectionName[0]);
0114         if (fHCID < 0)
0115         {
0116             G4cout << "PhotonSD::Initialize:  " << SensitiveDetectorName << "   " << collectionName[0] << G4endl;
0117             fHCID = G4SDManager::GetSDMpointer()->GetCollectionID(collectionName[0]);
0118         }
0119         hce->AddHitsCollection(fHCID, fPhotonHitsCollection);
0120     }
0121 
0122     G4bool ProcessHits(G4Step *aStep, G4TouchableHistory *) override
0123     {
0124         G4Track *track = aStep->GetTrack();
0125 
0126         // Only process optical photons
0127         if (track->GetDefinition() != G4OpticalPhoton::OpticalPhotonDefinition())
0128             return false;
0129 
0130         // Create a new hit (CopyNr is set to 0 as DetectorID is omitted)
0131         PhotonHit *hit = new PhotonHit(
0132             track->GetTotalEnergy(), track->GetGlobalTime(), aStep->GetPostStepPoint()->GetPosition(),
0133             aStep->GetPostStepPoint()->GetMomentumDirection(), aStep->GetPostStepPoint()->GetPolarization());
0134 
0135         fPhotonHitsCollection->insert(hit);
0136         track->SetTrackStatus(fStopAndKill);
0137 
0138         return true;
0139     }
0140 
0141     void EndOfEvent(G4HCofThisEvent *) override
0142     {
0143         G4int num_hits = fPhotonHitsCollection->entries();
0144         G4cout << "PhotonSD::EndOfEvent Number of PhotonHits: " << num_hits << G4endl;
0145 
0146         NP *hits = NP::Make<float>(num_hits, 4, 4);
0147         int i = 0;
0148 
0149         for (PhotonHit *hit : *fPhotonHitsCollection->GetVector())
0150         {
0151             float *photon_data = reinterpret_cast<float *>(&hit->photon);
0152             std::copy(photon_data, photon_data + 16, hits->values<float>() + (i++) * 16);
0153         }
0154 
0155         hits->save("g_hits.npy");
0156         delete hits;
0157     }
0158 
0159   private:
0160     PhotonHitsCollection *fPhotonHitsCollection{nullptr};
0161     G4int fHCID;
0162 };
0163 
0164 struct DetectorConstruction : G4VUserDetectorConstruction
0165 {
0166     DetectorConstruction(std::filesystem::path gdml_file) : gdml_file_(gdml_file)
0167     {
0168     }
0169 
0170     G4VPhysicalVolume *Construct() override
0171     {
0172         parser_.Read(gdml_file_.string(), false);
0173         G4VPhysicalVolume *world = parser_.GetWorldVolume();
0174 
0175         G4CXOpticks::SetGeometry(world);
0176 
0177         return world;
0178     }
0179 
0180     void ConstructSDandField() override
0181     {
0182         G4cout << "ConstructSDandField is called." << G4endl;
0183         G4SDManager *SDman = G4SDManager::GetSDMpointer();
0184 
0185         const G4GDMLAuxMapType *auxmap = parser_.GetAuxMap();
0186         for (auto const &[logVol, listType] : *auxmap)
0187         {
0188             for (auto const &auxtype : listType)
0189             {
0190                 if (auxtype.type == "SensDet")
0191                 {
0192                     G4cout << "Attaching sensitive detector to logical volume: " << logVol->GetName() << G4endl;
0193                     G4String name = logVol->GetName() + "_PhotonDetector";
0194                     PhotonSD *aPhotonSD = new PhotonSD(name);
0195                     SDman->AddNewDetector(aPhotonSD);
0196                     logVol->SetSensitiveDetector(aPhotonSD);
0197                 }
0198             }
0199         }
0200     }
0201 
0202   private:
0203     std::filesystem::path gdml_file_;
0204     G4GDMLParser parser_;
0205 };
0206 
0207 struct PrimaryGenerator : G4VUserPrimaryGeneratorAction
0208 {
0209     gphox::Config cfg;
0210     SEvt *sev;
0211 
0212     PrimaryGenerator(const gphox::Config& cfg, SEvt *sev) : cfg(cfg), sev(sev)
0213     {
0214     }
0215 
0216     void GeneratePrimaries(G4Event *event) override
0217     {
0218         std::vector<sphoton> sphotons = generate_photons(cfg.torch);
0219 
0220         size_t num_floats = sphotons.size()*4*4;
0221         float* data = reinterpret_cast<float*>(sphotons.data());
0222         NP* photons = NP::MakeFromValues<float>(data, num_floats);
0223 
0224         photons->reshape({ static_cast<int64_t>(sphotons.size()), 4, 4});
0225 
0226         for (const sphoton& p : sphotons)
0227         {
0228             G4ThreeVector position_mm(p.pos.x, p.pos.y, p.pos.z);
0229             G4double time_ns = p.time;
0230             G4ThreeVector direction(p.mom.x, p.mom.y, p.mom.z);
0231             // direction = direction.unit();
0232             G4double wavelength_nm = p.wavelength;
0233             G4ThreeVector polarization(p.pol.x, p.pol.y, p.pol.z);
0234 
0235             G4PrimaryVertex *vertex = new G4PrimaryVertex(position_mm, time_ns);
0236             G4double kineticEnergy = h_Planck * c_light / (wavelength_nm * nm);
0237 
0238             G4PrimaryParticle *particle = new G4PrimaryParticle(G4OpticalPhoton::Definition());
0239             particle->SetKineticEnergy(kineticEnergy);
0240             particle->SetMomentumDirection(direction);
0241             particle->SetPolarization(polarization);
0242 
0243             vertex->SetPrimary(particle);
0244             event->AddPrimaryVertex(vertex);
0245         }
0246 
0247         sev->SetInputPhoton(photons);
0248     }
0249 };
0250 
0251 struct EventAction : G4UserEventAction
0252 {
0253     SEvt *sev;
0254 
0255     EventAction(SEvt *sev) : sev(sev)
0256     {
0257     }
0258 
0259     void BeginOfEventAction(const G4Event *event) override
0260     {
0261         sev->beginOfEvent(event->GetEventID());
0262     }
0263 
0264     void EndOfEventAction(const G4Event *event) override
0265     {
0266         int eventID = event->GetEventID();
0267         sev->addEventConfigArray();
0268         sev->gather();
0269         sev->endOfEvent(eventID);
0270 
0271         // GPU-based simulation
0272         G4CXOpticks *gx = G4CXOpticks::Get();
0273 
0274         gx->simulate(eventID, false);
0275         cudaDeviceSynchronize();
0276 
0277         unsigned int num_hits = SEvt::GetNumHit(SEvt::EGPU);
0278 
0279         std::cout << "Opticks: NumHits:  " << num_hits << std::endl;
0280 
0281         SEvt *sev = SEvt::Get_EGPU();
0282         NP *hits = NP::Make<float>(num_hits, 4, 4);
0283 
0284         for (unsigned idx = 0; idx < num_hits; idx++)
0285         {
0286             sphoton *photon = reinterpret_cast<sphoton *>(hits->values<float>() + idx * 16);
0287             sev->getHit(*photon, idx);
0288         }
0289 
0290         hits->save("o_hits.npy");
0291         delete hits;
0292 
0293         gx->reset(eventID);
0294     }
0295 };
0296 
0297 void get_label(spho &ulabel, const G4Track *track)
0298 {
0299     spho *label = STrackInfo::GetRef(track);
0300     assert(label && label->isDefined() && "all photons are expected to be labelled");
0301 
0302     std::array<int, spho::N> a_label;
0303     label->serialize(a_label);
0304 
0305     ulabel.load(a_label);
0306 }
0307 
0308 struct SteppingAction : G4UserSteppingAction
0309 {
0310     SEvt *sev;
0311 
0312     SteppingAction(SEvt *sev) : sev(sev)
0313     {
0314     }
0315 
0316     void UserSteppingAction(const G4Step *step)
0317     {
0318         if (step->GetTrack()->GetDefinition() != G4OpticalPhoton::OpticalPhotonDefinition())
0319             return;
0320 
0321         const G4VProcess *process = step->GetPreStepPoint()->GetProcessDefinedStep();
0322 
0323         if (process == nullptr)
0324             return;
0325 
0326         const G4Track *track = step->GetTrack();
0327         G4VPhysicalVolume *pv = track->GetVolume();
0328         const G4VTouchable *touch = track->GetTouchable();
0329 
0330         spho ulabel = {};
0331         get_label(ulabel, track);
0332 
0333         const G4StepPoint *pre = step->GetPreStepPoint();
0334         const G4StepPoint *post = step->GetPostStepPoint();
0335 
0336         sev->checkPhotonLineage(ulabel);
0337 
0338         sphoton &current_photon = sev->current_ctx.p;
0339 
0340         if (current_photon.flagmask_count() == 1)
0341         {
0342             U4StepPoint::Update(current_photon, pre); // populate current_photon with pos, mom, pol, time, wavelength
0343             sev->pointPhoton(ulabel);                 // copying current into buffers
0344         }
0345 
0346         bool tir;
0347         unsigned flag = U4StepPoint::Flag<G4OpBoundaryProcess>(post, true, tir);
0348         bool is_detect_flag = OpticksPhoton::IsSurfaceDetectFlag(flag);
0349 
0350         current_photon.hitcount_iindex =
0351             is_detect_flag ? U4Touchable::ImmediateReplicaNumber(touch) : U4Touchable::AncestorReplicaNumber(touch);
0352 
0353         U4StepPoint::Update(current_photon, post);
0354 
0355         current_photon.set_flag(flag);
0356 
0357         sev->pointPhoton(ulabel);
0358     }
0359 };
0360 
0361 struct TrackingAction : G4UserTrackingAction
0362 {
0363     const G4Track *transient_fSuspend_track = nullptr;
0364     SEvt *sev;
0365 
0366     TrackingAction(SEvt *sev) : sev(sev)
0367     {
0368     }
0369 
0370     void PreUserTrackingAction_Optical_FabricateLabel(const G4Track *track)
0371     {
0372         U4Track::SetFabricatedLabel(track);
0373         spho *label = STrackInfo::GetRef(track);
0374         assert(label);
0375     }
0376 
0377     void PreUserTrackingAction(const G4Track *track) override
0378     {
0379         spho *label = STrackInfo::GetRef(track);
0380 
0381         if (label == nullptr)
0382         {
0383             PreUserTrackingAction_Optical_FabricateLabel(track);
0384             label = STrackInfo::GetRef(track);
0385         }
0386 
0387         assert(label && label->isDefined());
0388 
0389         std::array<int, spho::N> a_label;
0390         label->serialize(a_label);
0391 
0392         spho ulabel = {};
0393         ulabel.load(a_label);
0394 
0395         U4Random::SetSequenceIndex(ulabel.id);
0396 
0397         bool resume_fSuspend = track == transient_fSuspend_track;
0398 
0399         if (ulabel.gen() == 0)
0400         {
0401             if (resume_fSuspend == false)
0402                 sev->beginPhoton(ulabel);
0403             else
0404                 sev->resumePhoton(ulabel);
0405         }
0406         else if (ulabel.gen() > 0)
0407         {
0408             if (resume_fSuspend == false)
0409                 sev->rjoinPhoton(ulabel);
0410             else
0411                 sev->rjoin_resumePhoton(ulabel);
0412         }
0413     }
0414 
0415     void PostUserTrackingAction(const G4Track *track) override
0416     {
0417         G4TrackStatus tstat = track->GetTrackStatus();
0418 
0419         bool is_fStopAndKill = tstat == fStopAndKill;
0420         bool is_fSuspend = tstat == fSuspend;
0421         bool is_fStopAndKill_or_fSuspend = is_fStopAndKill || is_fSuspend;
0422 
0423         assert(is_fStopAndKill_or_fSuspend);
0424 
0425         spho ulabel = {};
0426         get_label(ulabel, track);
0427 
0428         if (is_fStopAndKill)
0429         {
0430             U4Random::SetSequenceIndex(-1);
0431             sev->finalPhoton(ulabel);
0432             transient_fSuspend_track = nullptr;
0433         }
0434         else if (is_fSuspend)
0435         {
0436             transient_fSuspend_track = track;
0437         }
0438     }
0439 };
0440 
0441 struct G4App
0442 {
0443     G4App(const gphox::Config& cfg, std::filesystem::path gdml_file)
0444         : sev(SEvt::CreateOrReuse_ECPU()), det_cons_(new DetectorConstruction(gdml_file)),
0445           prim_gen_(new PrimaryGenerator(cfg, sev)), event_act_(new EventAction(sev)), stepping_(new SteppingAction(sev)),
0446           tracking_(new TrackingAction(sev))
0447     {
0448     }
0449 
0450     //~G4App(){ G4CXOpticks::Finalize();}
0451 
0452     // Create "global" event
0453     SEvt *sev;
0454 
0455     G4VUserDetectorConstruction *det_cons_;
0456     G4VUserPrimaryGeneratorAction *prim_gen_;
0457     EventAction *event_act_;
0458     SteppingAction *stepping_;
0459     TrackingAction *tracking_;
0460 };