Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:49:12

0001 #include <filesystem>
0002 #include <fstream>
0003 #include <vector>
0004 
0005 #include "G4BooleanSolid.hh"
0006 #include "G4Event.hh"
0007 #include "G4GDMLParser.hh"
0008 #include "G4LogicalVolumeStore.hh"
0009 #include "G4OpBoundaryProcess.hh"
0010 #include "G4OpticalPhoton.hh"
0011 #include "G4PhysicalConstants.hh"
0012 #include "G4PrimaryParticle.hh"
0013 #include "G4PrimaryVertex.hh"
0014 #include "G4SDManager.hh"
0015 #include "G4SubtractionSolid.hh"
0016 #include "G4SystemOfUnits.hh"
0017 #include "G4ThreeVector.hh"
0018 #include "G4Track.hh"
0019 #include "G4TrackStatus.hh"
0020 #include "G4UserEventAction.hh"
0021 #include "G4UserRunAction.hh"
0022 #include "G4UserSteppingAction.hh"
0023 #include "G4UserTrackingAction.hh"
0024 #include "G4VPhysicalVolume.hh"
0025 #include "G4VProcess.hh"
0026 #include "G4VUserDetectorConstruction.hh"
0027 #include "G4VUserPrimaryGeneratorAction.hh"
0028 
0029 #include "g4cx/G4CXOpticks.hh"
0030 #include "sysrap/NP.hh"
0031 #include "sysrap/SEvt.hh"
0032 #include "sysrap/STrackInfo.h"
0033 #include "sysrap/spho.h"
0034 #include "sysrap/sphoton.h"
0035 #include "u4/U4Random.hh"
0036 #include "u4/U4StepPoint.hh"
0037 #include "u4/U4Touchable.h"
0038 #include "u4/U4Track.h"
0039 
0040 #include "config.h"
0041 #include "torch.h"
0042 
0043 struct PhotonHit : public G4VHit
0044 {
0045     PhotonHit() = default;
0046     PhotonHit(G4double energy, G4double time, G4ThreeVector position, G4ThreeVector direction,
0047               G4ThreeVector polarization)
0048         : photon()
0049     {
0050         photon.pos = {static_cast<float>(position.x()), static_cast<float>(position.y()),
0051                       static_cast<float>(position.z())};
0052         photon.time = time;
0053         photon.mom = {static_cast<float>(direction.x()), static_cast<float>(direction.y()),
0054                       static_cast<float>(direction.z())};
0055         photon.pol = {static_cast<float>(polarization.x()), static_cast<float>(polarization.y()),
0056                       static_cast<float>(polarization.z())};
0057         photon.wavelength = h_Planck * c_light / (energy * CLHEP::eV);
0058     }
0059 
0060     void Print() override
0061     {
0062         G4cout << photon << G4endl;
0063     }
0064 
0065     sphoton photon;
0066 };
0067 
0068 using PhotonHitsCollection = G4THitsCollection<PhotonHit>;
0069 
0070 struct PhotonSD : public G4VSensitiveDetector
0071 {
0072     PhotonSD(G4String name) : G4VSensitiveDetector(name), fHCID(-1)
0073     {
0074         G4String HCname = name + "_HC";
0075         collectionName.insert(HCname);
0076         G4cout << collectionName.size() << "   PhotonSD name:  " << name << " collection Name: " << HCname << G4endl;
0077     }
0078 
0079     void Initialize(G4HCofThisEvent *hce) override
0080     {
0081         fPhotonHitsCollection = new PhotonHitsCollection(SensitiveDetectorName, collectionName[0]);
0082         if (fHCID < 0)
0083         {
0084             G4cout << "PhotonSD::Initialize:  " << SensitiveDetectorName << "   " << collectionName[0] << G4endl;
0085             fHCID = G4SDManager::GetSDMpointer()->GetCollectionID(collectionName[0]);
0086         }
0087         hce->AddHitsCollection(fHCID, fPhotonHitsCollection);
0088     }
0089 
0090     G4bool ProcessHits(G4Step *aStep, G4TouchableHistory *) override
0091     {
0092         G4Track *track = aStep->GetTrack();
0093 
0094         if (track->GetDefinition() != G4OpticalPhoton::OpticalPhotonDefinition())
0095             return false;
0096 
0097         PhotonHit *hit = new PhotonHit(
0098             track->GetTotalEnergy(), track->GetGlobalTime(), aStep->GetPostStepPoint()->GetPosition(),
0099             aStep->GetPostStepPoint()->GetMomentumDirection(), aStep->GetPostStepPoint()->GetPolarization());
0100 
0101         fPhotonHitsCollection->insert(hit);
0102         track->SetTrackStatus(fStopAndKill);
0103 
0104         return true;
0105     }
0106 
0107     void EndOfEvent(G4HCofThisEvent *) override
0108     {
0109         G4int NbHits = fPhotonHitsCollection->entries();
0110         G4cout << "PhotonSD::EndOfEvent Number of PhotonHits: " << NbHits << G4endl;
0111     }
0112 
0113   private:
0114     PhotonHitsCollection *fPhotonHitsCollection{nullptr};
0115     G4int fHCID;
0116 };
0117 
0118 struct DetectorConstruction : G4VUserDetectorConstruction
0119 {
0120     DetectorConstruction(std::filesystem::path gdml_file) : gdml_file_(gdml_file)
0121     {
0122     }
0123 
0124     G4VPhysicalVolume *Construct() override
0125     {
0126         parser_.Read(gdml_file_.string(), false);
0127         G4VPhysicalVolume *world = parser_.GetWorldVolume();
0128 
0129         G4CXOpticks::SetGeometry(world);
0130 
0131         return world;
0132     }
0133 
0134     void ConstructSDandField() override
0135     {
0136         G4cout << "ConstructSDandField is called." << G4endl;
0137         G4SDManager *SDman = G4SDManager::GetSDMpointer();
0138 
0139         const G4GDMLAuxMapType *auxmap = parser_.GetAuxMap();
0140         for (auto const &[logVol, listType] : *auxmap)
0141         {
0142             for (auto const &auxtype : listType)
0143             {
0144                 if (auxtype.type == "SensDet")
0145                 {
0146                     G4cout << "Attaching sensitive detector to logical volume: " << logVol->GetName() << G4endl;
0147                     G4String name = logVol->GetName() + "_PhotonDetector";
0148                     PhotonSD *aPhotonSD = new PhotonSD(name);
0149                     SDman->AddNewDetector(aPhotonSD);
0150                     logVol->SetSensitiveDetector(aPhotonSD);
0151                 }
0152             }
0153         }
0154     }
0155 
0156   private:
0157     std::filesystem::path gdml_file_;
0158     G4GDMLParser parser_;
0159 };
0160 
0161 struct PrimaryGenerator : G4VUserPrimaryGeneratorAction
0162 {
0163     gphox::Config cfg;
0164     SEvt *sev;
0165 
0166     PrimaryGenerator(const gphox::Config &cfg, SEvt *sev) : cfg(cfg), sev(sev)
0167     {
0168     }
0169 
0170     void GeneratePrimaries(G4Event *event) override
0171     {
0172         std::vector<sphoton> sphotons = generate_photons(cfg.torch);
0173 
0174         size_t num_floats = sphotons.size() * 4 * 4;
0175         float *data = reinterpret_cast<float *>(sphotons.data());
0176         NP *photons = NP::MakeFromValues<float>(data, num_floats);
0177 
0178         photons->reshape({static_cast<int64_t>(sphotons.size()), 4, 4});
0179 
0180         for (const sphoton &p : sphotons)
0181         {
0182             G4ThreeVector position_mm(p.pos.x, p.pos.y, p.pos.z);
0183             G4double time_ns = p.time;
0184             G4ThreeVector direction(p.mom.x, p.mom.y, p.mom.z);
0185             G4double wavelength_nm = p.wavelength;
0186             G4ThreeVector polarization(p.pol.x, p.pol.y, p.pol.z);
0187 
0188             G4PrimaryVertex *vertex = new G4PrimaryVertex(position_mm, time_ns);
0189             G4double kineticEnergy = h_Planck * c_light / (wavelength_nm * nm);
0190 
0191             G4PrimaryParticle *particle = new G4PrimaryParticle(G4OpticalPhoton::Definition());
0192             particle->SetKineticEnergy(kineticEnergy);
0193             particle->SetMomentumDirection(direction);
0194             particle->SetPolarization(polarization);
0195 
0196             vertex->SetPrimary(particle);
0197             event->AddPrimaryVertex(vertex);
0198         }
0199 
0200         sev->SetInputPhoton(photons);
0201     }
0202 };
0203 
0204 struct EventAction : G4UserEventAction
0205 {
0206     SEvt *sev;
0207     G4int fTotalG4Hits{0};
0208     unsigned int fTotalOpticksHits{0};
0209 
0210     EventAction(SEvt *sev) : sev(sev)
0211     {
0212     }
0213 
0214     void BeginOfEventAction(const G4Event *event) override
0215     {
0216         sev->beginOfEvent(event->GetEventID());
0217     }
0218 
0219     void EndOfEventAction(const G4Event *event) override
0220     {
0221         int eventID = event->GetEventID();
0222         sev->addEventConfigArray();
0223         sev->gather();
0224         sev->endOfEvent(eventID);
0225 
0226         // GPU-based simulation
0227         G4CXOpticks *gx = G4CXOpticks::Get();
0228 
0229         gx->simulate(eventID, false);
0230         cudaDeviceSynchronize();
0231 
0232         SEvt *sev_gpu = SEvt::Get_EGPU();
0233         unsigned int num_hits = sev_gpu->GetNumHit(0);
0234 
0235         std::cout << "Opticks: NumHits:  " << num_hits << std::endl;
0236         fTotalOpticksHits += num_hits;
0237 
0238         std::ofstream outFile("opticks_hits_output.txt");
0239         if (!outFile.is_open())
0240         {
0241             std::cerr << "Error opening output file!" << std::endl;
0242         }
0243         else
0244         {
0245             for (int idx = 0; idx < int(num_hits); idx++)
0246             {
0247                 sphoton hit;
0248                 sev_gpu->getHit(hit, idx);
0249                 G4ThreeVector position = G4ThreeVector(hit.pos.x, hit.pos.y, hit.pos.z);
0250                 G4ThreeVector direction = G4ThreeVector(hit.mom.x, hit.mom.y, hit.mom.z);
0251                 G4ThreeVector polarization = G4ThreeVector(hit.pol.x, hit.pol.y, hit.pol.z);
0252                 outFile << hit.time << " " << hit.wavelength << "  " << "(" << position.x() << ", " << position.y()
0253                         << ", " << position.z() << ")  " << "(" << direction.x() << ", " << direction.y() << ", "
0254                         << direction.z() << ")  " << "(" << polarization.x() << ", " << polarization.y() << ", "
0255                         << polarization.z() << ")" << std::endl;
0256             }
0257             outFile.close();
0258         }
0259 
0260         gx->reset(eventID);
0261 
0262         // Accumulate and write G4 hits
0263         G4HCofThisEvent *hce = event->GetHCofThisEvent();
0264         if (hce)
0265         {
0266             std::ofstream g4OutFile("g4_hits_output.txt");
0267             if (!g4OutFile.is_open())
0268             {
0269                 std::cerr << "Error opening G4 output file!" << std::endl;
0270             }
0271 
0272             for (G4int i = 0; i < hce->GetNumberOfCollections(); i++)
0273             {
0274                 PhotonHitsCollection *hc = dynamic_cast<PhotonHitsCollection *>(hce->GetHC(i));
0275                 if (hc)
0276                 {
0277                     fTotalG4Hits += hc->entries();
0278                     if (g4OutFile.is_open())
0279                     {
0280                         for (size_t j = 0; j < hc->entries(); j++)
0281                         {
0282                             const sphoton &p = (*hc)[j]->photon;
0283                             G4ThreeVector position(p.pos.x, p.pos.y, p.pos.z);
0284                             G4ThreeVector direction(p.mom.x, p.mom.y, p.mom.z);
0285                             G4ThreeVector polarization(p.pol.x, p.pol.y, p.pol.z);
0286                             g4OutFile << p.time << " " << p.wavelength << "  " << "(" << position.x() << ", "
0287                                       << position.y() << ", " << position.z() << ")  " << "(" << direction.x() << ", "
0288                                       << direction.y() << ", " << direction.z() << ")  " << "(" << polarization.x()
0289                                       << ", " << polarization.y() << ", " << polarization.z() << ")" << std::endl;
0290                         }
0291                     }
0292                 }
0293             }
0294 
0295             if (g4OutFile.is_open())
0296             {
0297                 g4OutFile.close();
0298             }
0299         }
0300     }
0301 
0302     G4int GetTotalG4Hits() const
0303     {
0304         return fTotalG4Hits;
0305     }
0306 
0307     unsigned int GetTotalOpticksHits() const
0308     {
0309         return fTotalOpticksHits;
0310     }
0311 };
0312 
0313 void get_label(spho &ulabel, const G4Track *track)
0314 {
0315     spho *label = STrackInfo::GetRef(track);
0316     assert(label && label->isDefined() && "all photons are expected to be labelled");
0317 
0318     std::array<int, spho::N> a_label;
0319     label->serialize(a_label);
0320 
0321     ulabel.load(a_label);
0322 }
0323 
0324 struct SteppingAction : G4UserSteppingAction
0325 {
0326     SEvt *sev;
0327 
0328     SteppingAction(SEvt *sev) : sev(sev)
0329     {
0330     }
0331 
0332     void UserSteppingAction(const G4Step *step)
0333     {
0334         if (step->GetTrack()->GetDefinition() != G4OpticalPhoton::OpticalPhotonDefinition())
0335             return;
0336 
0337         const G4VProcess *process = step->GetPreStepPoint()->GetProcessDefinedStep();
0338 
0339         if (process == nullptr)
0340             return;
0341 
0342         const G4Track *track = step->GetTrack();
0343         G4VPhysicalVolume *pv = track->GetVolume();
0344         const G4VTouchable *touch = track->GetTouchable();
0345 
0346         spho ulabel = {};
0347         get_label(ulabel, track);
0348 
0349         const G4StepPoint *pre = step->GetPreStepPoint();
0350         const G4StepPoint *post = step->GetPostStepPoint();
0351 
0352         sev->checkPhotonLineage(ulabel);
0353 
0354         sphoton &current_photon = sev->current_ctx.p;
0355 
0356         if (current_photon.flagmask_count() == 1)
0357         {
0358             U4StepPoint::Update(current_photon, pre);
0359             sev->pointPhoton(ulabel);
0360         }
0361 
0362         bool tir;
0363         unsigned flag = U4StepPoint::Flag<G4OpBoundaryProcess>(post, true, tir);
0364         bool is_detect_flag = OpticksPhoton::IsSurfaceDetectFlag(flag);
0365 
0366         int touch_depth = touch->GetHistoryDepth();
0367         if (touch_depth > 1)
0368         {
0369             current_photon.hitcount_iindex =
0370                 is_detect_flag ? U4Touchable::ImmediateReplicaNumber(touch) : U4Touchable::AncestorReplicaNumber(touch);
0371         }
0372         else
0373         {
0374             current_photon.hitcount_iindex = 0;
0375         }
0376 
0377         U4StepPoint::Update(current_photon, post);
0378 
0379         current_photon.set_flag(flag);
0380 
0381         sev->pointPhoton(ulabel);
0382     }
0383 };
0384 
0385 struct TrackingAction : G4UserTrackingAction
0386 {
0387     const G4Track *transient_fSuspend_track = nullptr;
0388     SEvt *sev;
0389 
0390     TrackingAction(SEvt *sev) : sev(sev)
0391     {
0392     }
0393 
0394     void PreUserTrackingAction_Optical_FabricateLabel(const G4Track *track)
0395     {
0396         U4Track::SetFabricatedLabel(track);
0397         spho *label = STrackInfo::GetRef(track);
0398         assert(label);
0399     }
0400 
0401     void PreUserTrackingAction(const G4Track *track) override
0402     {
0403         spho *label = STrackInfo::GetRef(track);
0404 
0405         if (label == nullptr)
0406         {
0407             PreUserTrackingAction_Optical_FabricateLabel(track);
0408             label = STrackInfo::GetRef(track);
0409         }
0410 
0411         assert(label && label->isDefined());
0412 
0413         std::array<int, spho::N> a_label;
0414         label->serialize(a_label);
0415 
0416         spho ulabel = {};
0417         ulabel.load(a_label);
0418 
0419         U4Random::SetSequenceIndex(ulabel.id);
0420 
0421         bool resume_fSuspend = track == transient_fSuspend_track;
0422 
0423         if (ulabel.gen() == 0)
0424         {
0425             if (resume_fSuspend == false)
0426                 sev->beginPhoton(ulabel);
0427             else
0428                 sev->resumePhoton(ulabel);
0429         }
0430         else if (ulabel.gen() > 0)
0431         {
0432             if (resume_fSuspend == false)
0433                 sev->rjoinPhoton(ulabel);
0434             else
0435                 sev->rjoin_resumePhoton(ulabel);
0436         }
0437     }
0438 
0439     void PostUserTrackingAction(const G4Track *track) override
0440     {
0441         G4TrackStatus tstat = track->GetTrackStatus();
0442 
0443         bool is_fStopAndKill = tstat == fStopAndKill;
0444         bool is_fSuspend = tstat == fSuspend;
0445         bool is_fStopAndKill_or_fSuspend = is_fStopAndKill || is_fSuspend;
0446 
0447         assert(is_fStopAndKill_or_fSuspend);
0448 
0449         spho ulabel = {};
0450         get_label(ulabel, track);
0451 
0452         if (is_fStopAndKill)
0453         {
0454             U4Random::SetSequenceIndex(-1);
0455             sev->finalPhoton(ulabel);
0456             transient_fSuspend_track = nullptr;
0457         }
0458         else if (is_fSuspend)
0459         {
0460             transient_fSuspend_track = track;
0461         }
0462     }
0463 };
0464 
0465 struct RunAction : G4UserRunAction
0466 {
0467     EventAction *fEventAction;
0468 
0469     RunAction(EventAction *eventAction) : fEventAction(eventAction)
0470     {
0471     }
0472 
0473     void BeginOfRunAction(const G4Run *) override
0474     {
0475     }
0476 
0477     void EndOfRunAction(const G4Run *) override
0478     {
0479         std::cout << "Opticks: NumHits:  " << fEventAction->GetTotalOpticksHits() << std::endl;
0480         std::cout << "Geant4: NumHits:  " << fEventAction->GetTotalG4Hits() << std::endl;
0481     }
0482 };
0483 
0484 struct G4App
0485 {
0486     G4App(const gphox::Config &cfg, std::filesystem::path gdml_file)
0487         : sev(SEvt::CreateOrReuse_ECPU()), det_cons_(new DetectorConstruction(gdml_file)),
0488           prim_gen_(new PrimaryGenerator(cfg, sev)), event_act_(new EventAction(sev)),
0489           run_act_(new RunAction(event_act_)), stepping_(new SteppingAction(sev)), tracking_(new TrackingAction(sev))
0490     {
0491     }
0492 
0493     SEvt *sev;
0494 
0495     G4VUserDetectorConstruction *det_cons_;
0496     G4VUserPrimaryGeneratorAction *prim_gen_;
0497     EventAction *event_act_;
0498     RunAction *run_act_;
0499     SteppingAction *stepping_;
0500     TrackingAction *tracking_;
0501 };