Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-15 09:42:16

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/Seeding/detail/FastStrawLineFitter.hpp"
0012 
0013 #include "Acts/Definitions/Units.hpp"
0014 #include "Acts/Utilities/Enumerate.hpp"
0015 
0016 #include <format>
0017 namespace Acts::Experimental::detail {
0018 
0019 template <CompositeSpacePointContainer StrawCont_t>
0020 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0021     const StrawCont_t& measurements, const std::vector<int>& signs) const {
0022   if (measurements.size() != signs.size()) {
0023     ACTS_WARNING(
0024         __func__ << "() - " << __LINE__
0025                  << ": Not all measurements are associated with a drift sign");
0026     return std::nullopt;
0027   }
0028 
0029   auto result = fit(fillAuxiliaries(measurements, signs));
0030   if (!result) {
0031     return std::nullopt;
0032   }
0033   // Calculate the chi2
0034   calcPostFitChi2(measurements, *result);
0035   return result;
0036 }
0037 
0038 template <CompositeSpacePointContainer StrawCont_t>
0039 void FastStrawLineFitter::calcPostFitChi2(const StrawCont_t& measurements,
0040                                           FitResult& result) const {
0041   const TrigonomHelper angles{result.theta};
0042   result.chi2 = 0.;
0043   for (const auto& strawMeas : measurements) {
0044     result.chi2 += chi2Term(angles, result.y0, *strawMeas);
0045   }
0046   ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0047                       << result.chi2 << ", nDoF: " << result.nDoF
0048                       << ", redChi2: " << (result.chi2 / result.nDoF));
0049 }
0050 
0051 template <CompositeSpacePoint Point_t>
0052 double FastStrawLineFitter::chi2Term(const TrigonomHelper& angle,
0053                                      const double y0, const Point_t& strawMeas,
0054                                      std::optional<double> r) const {
0055   if (!strawMeas.isStraw()) {
0056     return 0.;
0057   }
0058   const double cov = strawMeas.covariance()[s_covIdx];
0059   if (cov < std::numeric_limits<double>::epsilon()) {
0060     return 0.;
0061   }
0062   const Vector& pos = strawMeas.localPosition();
0063   const double y = pos.dot(strawMeas.toNextSensor());
0064   const double z = pos.dot(strawMeas.planeNormal());
0065   const double dist = Acts::abs((y - y0) * angle.cosTheta - z * angle.sinTheta);
0066   ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": Distance straw (" << y
0067                         << ", " << z
0068                         << "), r: " << r.value_or(strawMeas.driftRadius())
0069                         << " - track: " << dist);
0070   return Acts::pow(dist - r.value_or(strawMeas.driftRadius()), 2) / cov;
0071 }
0072 
0073 template <CompositeSpacePointContainer StripCont_t>
0074 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0075     const StripCont_t& measurements, const ResidualIdx projection) const {
0076   if (projection == ResidualIdx::time) {
0077     ACTS_WARNING(__func__ << "() - " << __LINE__
0078                           << ": Only spatial projections, "
0079                           << "i.e. nonBending / bending are sensible");
0080     return std::nullopt;
0081   }
0082 
0083   FitAuxiliaries auxVars{};
0084   Vector centerOfGravity{Vector::Zero()};
0085   // @brief Selector function to check that the current strip provides a constraint in the projetor direction
0086   auto select = [&projection](const auto& strip) -> bool {
0087     // Skip straw measurements that are not twins &
0088     // don't measure non-bending coordinate
0089     if (strip->isStraw()) {
0090       return strip->measuresLoc0() && projection == ResidualIdx::nonBending;
0091     }
0092     // Check that the strip is actually measuring the projection
0093     return (strip->measuresLoc0() && projection == ResidualIdx::nonBending) ||
0094            (strip->measuresLoc1() && projection == ResidualIdx::bending);
0095   };
0096 
0097   auxVars.invCovs.resize(measurements.size());
0098   for (const auto& [sIdx, strip] : enumerate(measurements)) {
0099     if (!select(strip)) {
0100       ACTS_VERBOSE(__func__ << "() - " << __LINE__
0101                             << ": Skip strip measurement @"
0102                             << toString(strip->localPosition()));
0103       continue;
0104     }
0105     const auto& invCov =
0106         (auxVars.invCovs[sIdx] =
0107              1. / strip->covariance()[toUnderlying(projection)]);
0108     auxVars.covNorm += invCov;
0109     centerOfGravity += invCov * strip->localPosition();
0110     ++auxVars.nDoF;
0111   }
0112   // To little information provided
0113   if (auxVars.nDoF < 3) {
0114     return std::nullopt;
0115   }
0116   // Reduce the number of degrees of freedom by 2 to account
0117   // for the two free parameters to fit
0118   auxVars.nDoF -= 2u;
0119   auxVars.covNorm = 1. / auxVars.covNorm;
0120   centerOfGravity *= auxVars.covNorm;
0121 
0122   bool centerSet{false};
0123   for (const auto& [sIdx, strip] : enumerate(measurements)) {
0124     if (!select(strip)) {
0125       continue;
0126     }
0127     const Vector pos = strip->localPosition() - centerOfGravity;
0128     const Vector& measDir{
0129         (projection == ResidualIdx::nonBending && strip->measuresLoc1()) ||
0130                 strip->isStraw()
0131             ? strip->sensorDirection()
0132             : strip->toNextSensor()};
0133 
0134     if (!centerSet) {
0135       auxVars.centerY = centerOfGravity.dot(measDir);
0136       auxVars.centerZ = centerOfGravity.dot(strip->planeNormal());
0137       centerSet = true;
0138     }
0139 
0140     const double y = pos.dot(measDir);
0141     const double z = pos.dot(strip->planeNormal());
0142 
0143     const auto& invCov = auxVars.invCovs[sIdx];
0144     auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0145     auxVars.T_yz += invCov * z * y;
0146   }
0147   return fit(auxVars);
0148 }
0149 
0150 template <CompositeSpacePointContainer StrawCont_t,
0151           CompositeSpacePointFastCalibrator<
0152               Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0153               Calibrator_t>
0154 void FastStrawLineFitter::calcPostFitChi2(const Acts::CalibrationContext& ctx,
0155                                           const StrawCont_t& measurements,
0156                                           const Calibrator_t& calibrator,
0157                                           FitResultT0& result) const {
0158   const TrigonomHelper angles{result.theta};
0159   result.chi2 = 0.;
0160   for (const auto& strawMeas : measurements) {
0161     result.chi2 += chi2Term(angles, result.y0, *strawMeas,
0162                             calibrator.driftRadius(ctx, *strawMeas, result.t0));
0163   }
0164   ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0165                       << result.chi2 << ", nDoF: " << result.nDoF
0166                       << ", redChi2: " << (result.chi2 / result.nDoF));
0167 }
0168 
0169 template <CompositeSpacePointContainer StrawCont_t>
0170 FastStrawLineFitter::FitAuxiliaries FastStrawLineFitter::fillAuxiliaries(
0171     const StrawCont_t& measurements, const std::vector<int>& signs) const {
0172   FitAuxiliaries auxVars{};
0173   auxVars.invCovs.resize(signs.size(), -1.);
0174   Vector centerOfGravity{Vector::Zero()};
0175 
0176   // Calculate first the center of gravity
0177   for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0178     if (!strawMeas->isStraw()) {
0179       ACTS_DEBUG(__func__ << "() - " << __LINE__
0180                           << ": The measurement is not a straw");
0181       continue;
0182     }
0183     const double cov = strawMeas->covariance()[s_covIdx];
0184     if (cov < std::numeric_limits<double>::epsilon()) {
0185       ACTS_WARNING(__func__ << "() - " << __LINE__ << ": The covariance ("
0186                             << cov << ") of the measurement is invalid.");
0187       continue;
0188     }
0189     auto& invCov = (auxVars.invCovs[sIdx] = 1. / cov);
0190     auxVars.covNorm += invCov;
0191     centerOfGravity += invCov * strawMeas->localPosition();
0192     ++auxVars.nDoF;
0193   }
0194   if (auxVars.nDoF < 3) {
0195     std::stringstream sstr{};
0196     for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0197       sstr << " --- " << (sIdx + 1) << ") "
0198            << toString(strawMeas->localPosition())
0199            << ", r: " << strawMeas->driftRadius()
0200            << ", weight: " << auxVars.invCovs[sIdx] << std::endl;
0201     }
0202     ACTS_WARNING(__func__ << "() - " << __LINE__
0203                           << ": At least 3 measurements are required to "
0204                              "perform the straw line fit\n"
0205                           << sstr.str());
0206     auxVars.nDoF = 0u;
0207     return auxVars;
0208   }
0209   // Reduce the number of degrees of freedom by 2 to account
0210   // for the two free parameters to fit
0211   auxVars.nDoF -= 2u;
0212   auxVars.covNorm = 1. / auxVars.covNorm;
0213   centerOfGravity *= auxVars.covNorm;
0214 
0215   // Now calculate the fit constants
0216   bool centerSet{false};
0217   for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0218     const auto& invCov = auxVars.invCovs[sIdx];
0219     // Invalid measurements were marked
0220     if (invCov < 0.) {
0221       continue;
0222     }
0223     if (!centerSet) {
0224       auxVars.centerY = centerOfGravity.dot(strawMeas->toNextSensor());
0225       auxVars.centerZ = centerOfGravity.dot(strawMeas->planeNormal());
0226       centerSet = true;
0227     }
0228     const Vector pos = strawMeas->localPosition() - centerOfGravity;
0229     const double y = pos.dot(strawMeas->toNextSensor());
0230     const double z = pos.dot(strawMeas->planeNormal());
0231     const double r = strawMeas->driftRadius();
0232 
0233     auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0234     auxVars.T_yz += invCov * z * y;
0235     const double sInvCov = -invCov * signs[sIdx];
0236     auxVars.T_rz += sInvCov * z * r;
0237     auxVars.T_ry += sInvCov * y * r;
0238     auxVars.fitY0 += sInvCov * r;
0239   }
0240   auxVars.fitY0 *= auxVars.covNorm;
0241 
0242   return auxVars;
0243 }
0244 
0245 template <CompositeSpacePointContainer StrawCont_t,
0246           CompositeSpacePointFastCalibrator<
0247               Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0248               Calibrator_t>
0249 std::optional<FastStrawLineFitter::FitResultT0> FastStrawLineFitter::fit(
0250     const Acts::CalibrationContext& ctx, const Calibrator_t& calibrator,
0251     const StrawCont_t& measurements, const std::vector<int>& signs,
0252     std::optional<double> startT0) const {
0253   using namespace Acts::UnitLiterals;
0254   if (measurements.size() != signs.size()) {
0255     ACTS_WARNING(
0256         __func__ << "() - " << __LINE__
0257                  << ": Not all measurements are associated with a drift sign");
0258     return std::nullopt;
0259   }
0260 
0261   FitResultT0 result{};
0262   result.t0 = startT0.value_or(0.);
0263 
0264   FitAuxiliariesWithT0 fitPars{
0265       fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0)};
0266   result.theta = startTheta(fitPars);
0267   result.nDoF = fitPars.nDoF;
0268   ACTS_DEBUG(__func__ << "() - " << __LINE__
0269                       << ": Initial fit parameters: " << result);
0270   UpdateStatus iterStatus{UpdateStatus::GoodStep};
0271 
0272   while ((iterStatus = updateIteration(fitPars, result)) !=
0273          UpdateStatus::Exceeded) {
0274     if (iterStatus == UpdateStatus::Converged) {
0275       calcPostFitChi2(ctx, measurements, calibrator, result);
0276       return result;
0277     }
0278     fitPars = fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0);
0279   }
0280   if (logger().doPrint(Logging::VERBOSE)) {
0281     ACTS_VERBOSE("Fit failed, printing all measurements:");
0282     for (const auto& meas : measurements) {
0283       ACTS_VERBOSE(
0284           "Pos: " << Acts::toString(meas->localPosition()) << ", t,t0: "
0285                   << meas->time() / 1._ns << ", " << result.t0 / 1._ns
0286                   << ", truthR, RecoR: " << meas->driftRadius() << ", "
0287                   << calibrator.driftRadius(ctx, *meas, result.t0) << ", v: "
0288                   << calibrator.driftVelocity(ctx, *meas, result.t0) * 1._ns
0289                   << ", a: "
0290                   << calibrator.driftAcceleration(ctx, *meas, result.t0) *
0291                          1._ns * 1._ns);
0292     }
0293     ACTS_VERBOSE("Result: " << result);
0294   }
0295   return std::nullopt;
0296 }
0297 
0298 template <CompositeSpacePointContainer StrawCont_t,
0299           CompositeSpacePointFastCalibrator<
0300               Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0301               Calibrator_t>
0302 FastStrawLineFitter::FitAuxiliariesWithT0 FastStrawLineFitter::fillAuxiliaries(
0303     const CalibrationContext& ctx, const Calibrator_t& calibrator,
0304     const StrawCont_t& measurements, const std::vector<int>& signs,
0305     const double t0) const {
0306   using namespace Acts::UnitLiterals;
0307   FitAuxiliariesWithT0 auxVars{fillAuxiliaries(measurements, signs)};
0308   if (auxVars.nDoF <= 1) {
0309     auxVars.nDoF = 0;
0310     return auxVars;
0311   }
0312   // Account for the time offset as extra degree of freedom
0313   --auxVars.nDoF;
0314   // Fill the new extra variables
0315   auxVars.T_rz = 0.;
0316   auxVars.T_ry = 0.;
0317   auxVars.fitY0 = 0.;
0318   for (const auto& [spIdx, strawMeas] : enumerate(measurements)) {
0319     const double& invCov = auxVars.invCovs[spIdx];
0320     // Invalid (non)-straw measurements
0321     if (invCov < 0.) {
0322       continue;
0323     }
0324     const double sInvCov = -invCov * signs[spIdx];
0325     const double r = calibrator.driftRadius(ctx, *strawMeas, t0);
0326     const double v = calibrator.driftVelocity(ctx, *strawMeas, t0);
0327     const double a = calibrator.driftAcceleration(ctx, *strawMeas, t0);
0328     const double y = strawMeas->localPosition().dot(strawMeas->toNextSensor()) -
0329                      auxVars.centerY;
0330     const double z = strawMeas->localPosition().dot(strawMeas->planeNormal()) -
0331                      auxVars.centerZ;
0332 
0333     ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": # " << (spIdx + 1)
0334                           << ") t,t0: " << strawMeas->time() / 1._ns << ", "
0335                           << t0 / 1._ns << " r: " << r << ", v: " << v * 1._ns
0336                           << ", a: " << a * 1._ns * 1._ns);
0337     auxVars.fitY0 += sInvCov * r;
0338     auxVars.R_v += sInvCov * v;
0339     auxVars.R_a += sInvCov * a;
0340 
0341     auxVars.T_rz += sInvCov * z * r;
0342     auxVars.T_ry += sInvCov * y * r;
0343 
0344     auxVars.T_vy += sInvCov * v * y;
0345     auxVars.T_vz += sInvCov * v * z;
0346 
0347     auxVars.R_vr += invCov * r * v;
0348     auxVars.R_vv += invCov * v * v;
0349 
0350     auxVars.T_ay += sInvCov * a * y;
0351     auxVars.T_az += sInvCov * a * z;
0352 
0353     auxVars.R_ar += invCov * a * r;
0354   }
0355   auxVars.fitY0 *= auxVars.covNorm;
0356   ACTS_DEBUG(__func__ << "() - " << __LINE__ << " Fit constants calculated \n"
0357                       << auxVars);
0358   return auxVars;
0359 }
0360 
0361 }  // namespace Acts::Experimental::detail