Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:22:37

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include "Acts/Seeding/detail/FastStrawLineFitter.hpp"
0012 
0013 #include "Acts/Definitions/Units.hpp"
0014 #include "Acts/Utilities/Enumerate.hpp"
0015 
0016 #include <format>
0017 namespace Acts::Experimental::detail {
0018 
0019 template <CompositeSpacePointContainer StrawCont_t>
0020 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0021     const StrawCont_t& measurements, const std::vector<int>& signs) const {
0022   if (measurements.size() != signs.size()) {
0023     ACTS_WARNING(
0024         __func__ << "() - " << __LINE__
0025                  << ": Not all measurements are associated with a drift sign");
0026     return std::nullopt;
0027   }
0028 
0029   auto result = fit(fillAuxiliaries(measurements, signs));
0030   if (!result) {
0031     return std::nullopt;
0032   }
0033   // Calculate the chi2
0034   calcPostFitChi2(measurements, *result);
0035   return result;
0036 }
0037 
0038 template <CompositeSpacePointContainer StrawCont_t>
0039 void FastStrawLineFitter::calcPostFitChi2(const StrawCont_t& measurements,
0040                                           FitResult& result) const {
0041   const TrigonomHelper angles{result.theta};
0042   result.chi2 = 0.;
0043   for (const auto& strawMeas : measurements) {
0044     result.chi2 += chi2Term(angles, result.y0, *strawMeas);
0045   }
0046   ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0047                       << result.chi2 << ", nDoF: " << result.nDoF
0048                       << ", redChi2: " << (result.chi2 / result.nDoF));
0049 }
0050 
0051 template <CompositeSpacePoint Point_t>
0052 double FastStrawLineFitter::chi2Term(const TrigonomHelper& angle,
0053                                      const double y0, const Point_t& strawMeas,
0054                                      std::optional<double> r) const {
0055   if (!strawMeas.isStraw()) {
0056     return 0.;
0057   }
0058   const double cov = strawMeas.covariance()[s_covIdx];
0059   if (cov < std::numeric_limits<double>::epsilon()) {
0060     return 0.;
0061   }
0062   const Vector& pos = strawMeas.localPosition();
0063   const double y = pos.dot(strawMeas.toNextSensor());
0064   const double z = pos.dot(strawMeas.planeNormal());
0065   const double dist = Acts::abs((y - y0) * angle.cosTheta - z * angle.sinTheta);
0066   ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": Distance straw (" << y
0067                         << ", " << z
0068                         << "), r: " << r.value_or(strawMeas.driftRadius())
0069                         << " - track: " << dist);
0070   return Acts::pow(dist - r.value_or(strawMeas.driftRadius()), 2) / cov;
0071 }
0072 
0073 template <CompositeSpacePointContainer StripCont_t>
0074 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0075     const StripCont_t& measurements, const ResidualIdx projection) const {
0076   if (projection == ResidualIdx::time) {
0077     ACTS_WARNING(__func__ << "() - " << __LINE__
0078                           << ": Only spatial projections, "
0079                           << "i.e. nonBending / bending are sensible");
0080     return std::nullopt;
0081   }
0082 
0083   FitAuxiliaries auxVars{};
0084   Vector centerOfGravity{Vector::Zero()};
0085   // @brief Selector function to check that the current strip provides a constraint in the projetor direction
0086   auto select = [&projection](const auto& strip) -> bool {
0087     // Skip straw measurements that are not twins &
0088     // don't measure non-bending coordinate
0089     if (strip->isStraw()) {
0090       return strip->measuresLoc0() && projection == ResidualIdx::nonBending;
0091     }
0092     // Check that the strip is actually measuring the projection
0093     return (strip->measuresLoc0() && projection == ResidualIdx::nonBending) ||
0094            (strip->measuresLoc1() && projection == ResidualIdx::bending);
0095   };
0096 
0097   auxVars.invCovs.resize(measurements.size());
0098   for (const auto& [sIdx, strip] : enumerate(measurements)) {
0099     if (!select(strip)) {
0100       ACTS_VERBOSE(__func__ << "() - " << __LINE__
0101                             << ": Skip strip measurement " << toString(*strip));
0102       continue;
0103     }
0104     const auto& invCov =
0105         (auxVars.invCovs[sIdx] =
0106              1. / strip->covariance()[toUnderlying(projection)]);
0107     auxVars.covNorm += invCov;
0108     centerOfGravity += invCov * strip->localPosition();
0109     ++auxVars.nDoF;
0110   }
0111   // To little information provided
0112   if (auxVars.nDoF < 3) {
0113     return std::nullopt;
0114   }
0115   // Reduce the number of degrees of freedom by 2 to account
0116   // for the two free parameters to fit
0117   auxVars.nDoF -= 2u;
0118   auxVars.covNorm = 1. / auxVars.covNorm;
0119   centerOfGravity *= auxVars.covNorm;
0120 
0121   bool centerSet{false};
0122   for (const auto& [sIdx, strip] : enumerate(measurements)) {
0123     if (!select(strip)) {
0124       continue;
0125     }
0126     const Vector pos = strip->localPosition() - centerOfGravity;
0127     const Vector& measDir{
0128         (projection == ResidualIdx::nonBending && strip->measuresLoc1()) ||
0129                 strip->isStraw()
0130             ? strip->sensorDirection()
0131             : strip->toNextSensor()};
0132 
0133     if (!centerSet) {
0134       auxVars.centerY = centerOfGravity.dot(measDir);
0135       auxVars.centerZ = centerOfGravity.dot(strip->planeNormal());
0136       centerSet = true;
0137     }
0138 
0139     const double y = pos.dot(measDir);
0140     const double z = pos.dot(strip->planeNormal());
0141 
0142     const auto& invCov = auxVars.invCovs[sIdx];
0143     auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0144     auxVars.T_yz += invCov * z * y;
0145   }
0146   return fit(auxVars);
0147 }
0148 
0149 template <CompositeSpacePointContainer StrawCont_t,
0150           CompositeSpacePointFastCalibrator<
0151               Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0152               Calibrator_t>
0153 void FastStrawLineFitter::calcPostFitChi2(const Acts::CalibrationContext& ctx,
0154                                           const StrawCont_t& measurements,
0155                                           const Calibrator_t& calibrator,
0156                                           FitResultT0& result) const {
0157   const TrigonomHelper angles{result.theta};
0158   result.chi2 = 0.;
0159   for (const auto& strawMeas : measurements) {
0160     result.chi2 += chi2Term(angles, result.y0, *strawMeas,
0161                             calibrator.driftRadius(ctx, *strawMeas, result.t0));
0162   }
0163   ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0164                       << result.chi2 << ", nDoF: " << result.nDoF
0165                       << ", redChi2: " << (result.chi2 / result.nDoF));
0166 }
0167 
0168 template <CompositeSpacePointContainer StrawCont_t>
0169 FastStrawLineFitter::FitAuxiliaries FastStrawLineFitter::fillAuxiliaries(
0170     const StrawCont_t& measurements, const std::vector<int>& signs) const {
0171   FitAuxiliaries auxVars{};
0172   auxVars.invCovs.resize(signs.size(), -1.);
0173   Vector centerOfGravity{Vector::Zero()};
0174 
0175   // Calculate first the center of gravity
0176   for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0177     if (!strawMeas->isStraw()) {
0178       ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": The measurement "
0179                           << toString(*strawMeas) << " is not a straw");
0180       continue;
0181     }
0182     const double cov = strawMeas->covariance()[s_covIdx];
0183     if (cov < std::numeric_limits<double>::epsilon()) {
0184       ACTS_WARNING(__func__ << "() - " << __LINE__ << ": The covariance ("
0185                             << cov << ") of the measurement "
0186                             << toString(*strawMeas) << " is invalid.");
0187       continue;
0188     }
0189     ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": Fill "
0190                           << toString(*strawMeas) << ".");
0191 
0192     auto& invCov = (auxVars.invCovs[sIdx] = 1. / cov);
0193     auxVars.covNorm += invCov;
0194     centerOfGravity += invCov * strawMeas->localPosition();
0195     ++auxVars.nDoF;
0196   }
0197   if (auxVars.nDoF < 3) {
0198     std::stringstream sstr{};
0199     for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0200       sstr << " --- " << (sIdx + 1) << ") " << toString(*strawMeas)
0201            << ", weight: " << auxVars.invCovs[sIdx] << std::endl;
0202     }
0203     ACTS_WARNING(__func__ << "() - " << __LINE__
0204                           << ": At least 3 measurements are required to "
0205                              "perform the straw line fit\n"
0206                           << sstr.str());
0207     auxVars.nDoF = 0u;
0208     return auxVars;
0209   }
0210   // Reduce the number of degrees of freedom by 2 to account
0211   // for the two free parameters to fit
0212   auxVars.nDoF -= 2u;
0213   auxVars.covNorm = 1. / auxVars.covNorm;
0214   centerOfGravity *= auxVars.covNorm;
0215 
0216   // Now calculate the fit constants
0217   bool centerSet{false};
0218   for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0219     const auto& invCov = auxVars.invCovs[sIdx];
0220     // Invalid measurements were marked
0221     if (invCov < 0.) {
0222       continue;
0223     }
0224     if (!centerSet) {
0225       auxVars.centerY = centerOfGravity.dot(strawMeas->toNextSensor());
0226       auxVars.centerZ = centerOfGravity.dot(strawMeas->planeNormal());
0227       centerSet = true;
0228     }
0229     const Vector pos = strawMeas->localPosition() - centerOfGravity;
0230     const double y = pos.dot(strawMeas->toNextSensor());
0231     const double z = pos.dot(strawMeas->planeNormal());
0232     const double r = strawMeas->driftRadius();
0233 
0234     auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0235     auxVars.T_yz += invCov * z * y;
0236     const double sInvCov = -invCov * signs[sIdx];
0237     auxVars.T_rz += sInvCov * z * r;
0238     auxVars.T_ry += sInvCov * y * r;
0239     auxVars.fitY0 += sInvCov * r;
0240   }
0241   auxVars.fitY0 *= auxVars.covNorm;
0242 
0243   return auxVars;
0244 }
0245 
0246 template <CompositeSpacePointContainer StrawCont_t,
0247           CompositeSpacePointFastCalibrator<
0248               Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0249               Calibrator_t>
0250 std::optional<FastStrawLineFitter::FitResultT0> FastStrawLineFitter::fit(
0251     const Acts::CalibrationContext& ctx, const Calibrator_t& calibrator,
0252     const StrawCont_t& measurements, const std::vector<int>& signs,
0253     std::optional<double> startT0) const {
0254   using namespace Acts::UnitLiterals;
0255   if (measurements.size() != signs.size()) {
0256     ACTS_WARNING(
0257         __func__ << "() - " << __LINE__
0258                  << ": Not all measurements are associated with a drift sign");
0259     return std::nullopt;
0260   }
0261 
0262   FitResultT0 result{};
0263   result.t0 = startT0.value_or(0.);
0264 
0265   FitAuxiliariesWithT0 fitPars{
0266       fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0)};
0267   result.theta = startTheta(fitPars);
0268   result.nDoF = fitPars.nDoF;
0269   ACTS_DEBUG(__func__ << "() - " << __LINE__
0270                       << ": Initial fit parameters: " << result);
0271   UpdateStatus iterStatus{UpdateStatus::GoodStep};
0272 
0273   while ((iterStatus = updateIteration(fitPars, result)) !=
0274          UpdateStatus::Exceeded) {
0275     if (iterStatus == UpdateStatus::Converged) {
0276       calcPostFitChi2(ctx, measurements, calibrator, result);
0277       return result;
0278     }
0279     fitPars = fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0);
0280   }
0281   if (logger().doPrint(Logging::VERBOSE)) {
0282     ACTS_VERBOSE("Fit failed, printing all measurements:");
0283     for (const auto& meas : measurements) {
0284       ACTS_VERBOSE(toString(*meas)
0285                    << ", t0: " << result.t0 / 1._ns
0286                    << ", truthR, RecoR: " << meas->driftRadius() << ", "
0287                    << calibrator.driftRadius(ctx, *meas, result.t0)
0288                    << ", velocity: "
0289                    << calibrator.driftVelocity(ctx, *meas, result.t0) * 1._ns
0290                    << ", acceleration: "
0291                    << calibrator.driftAcceleration(ctx, *meas, result.t0) *
0292                           Acts::square(1._ns));
0293     }
0294     ACTS_VERBOSE("Result: " << result);
0295   }
0296   return std::nullopt;
0297 }
0298 
0299 template <CompositeSpacePointContainer StrawCont_t,
0300           CompositeSpacePointFastCalibrator<
0301               Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0302               Calibrator_t>
0303 FastStrawLineFitter::FitAuxiliariesWithT0 FastStrawLineFitter::fillAuxiliaries(
0304     const CalibrationContext& ctx, const Calibrator_t& calibrator,
0305     const StrawCont_t& measurements, const std::vector<int>& signs,
0306     const double t0) const {
0307   using namespace Acts::UnitLiterals;
0308   FitAuxiliariesWithT0 auxVars{fillAuxiliaries(measurements, signs)};
0309   if (auxVars.nDoF <= 1) {
0310     auxVars.nDoF = 0;
0311     return auxVars;
0312   }
0313   // Account for the time offset as extra degree of freedom
0314   --auxVars.nDoF;
0315   // Fill the new extra variables
0316   auxVars.T_rz = 0.;
0317   auxVars.T_ry = 0.;
0318   auxVars.fitY0 = 0.;
0319   for (const auto& [spIdx, strawMeas] : enumerate(measurements)) {
0320     const double& invCov = auxVars.invCovs[spIdx];
0321     // Invalid (non)-straw measurements
0322     if (invCov < 0.) {
0323       continue;
0324     }
0325     const double sInvCov = -invCov * signs[spIdx];
0326     const double r = calibrator.driftRadius(ctx, *strawMeas, t0);
0327     const double v = calibrator.driftVelocity(ctx, *strawMeas, t0);
0328     const double a = calibrator.driftAcceleration(ctx, *strawMeas, t0);
0329     const double y = strawMeas->localPosition().dot(strawMeas->toNextSensor()) -
0330                      auxVars.centerY;
0331     const double z = strawMeas->localPosition().dot(strawMeas->planeNormal()) -
0332                      auxVars.centerZ;
0333 
0334     ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": # " << (spIdx + 1)
0335                           << ") " << toString(*strawMeas) << ", t0: "
0336                           << t0 / 1._ns << " r: " << r << ", v: " << v * 1._ns
0337                           << ", a: " << a * Acts::square(1._ns));
0338     auxVars.fitY0 += sInvCov * r;
0339     auxVars.R_v += sInvCov * v;
0340     auxVars.R_a += sInvCov * a;
0341 
0342     auxVars.T_rz += sInvCov * z * r;
0343     auxVars.T_ry += sInvCov * y * r;
0344 
0345     auxVars.T_vy += sInvCov * v * y;
0346     auxVars.T_vz += sInvCov * v * z;
0347 
0348     auxVars.R_vr += invCov * r * v;
0349     auxVars.R_vv += invCov * v * v;
0350 
0351     auxVars.T_ay += sInvCov * a * y;
0352     auxVars.T_az += sInvCov * a * z;
0353 
0354     auxVars.R_ar += invCov * a * r;
0355   }
0356   auxVars.fitY0 *= auxVars.covNorm;
0357   ACTS_DEBUG(__func__ << "() - " << __LINE__ << " Fit constants calculated \n"
0358                       << auxVars);
0359   return auxVars;
0360 }
0361 
0362 }  // namespace Acts::Experimental::detail