File indexing completed on 2025-12-16 09:22:37
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include "Acts/Seeding/detail/FastStrawLineFitter.hpp"

#include "Acts/Definitions/Units.hpp"
#include "Acts/Utilities/Enumerate.hpp"

#include <format>
#include <limits>
0017 namespace Acts::Experimental::detail {
0018
0019 template <CompositeSpacePointContainer StrawCont_t>
0020 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0021 const StrawCont_t& measurements, const std::vector<int>& signs) const {
0022 if (measurements.size() != signs.size()) {
0023 ACTS_WARNING(
0024 __func__ << "() - " << __LINE__
0025 << ": Not all measurements are associated with a drift sign");
0026 return std::nullopt;
0027 }
0028
0029 auto result = fit(fillAuxiliaries(measurements, signs));
0030 if (!result) {
0031 return std::nullopt;
0032 }
0033
0034 calcPostFitChi2(measurements, *result);
0035 return result;
0036 }
0037
0038 template <CompositeSpacePointContainer StrawCont_t>
0039 void FastStrawLineFitter::calcPostFitChi2(const StrawCont_t& measurements,
0040 FitResult& result) const {
0041 const TrigonomHelper angles{result.theta};
0042 result.chi2 = 0.;
0043 for (const auto& strawMeas : measurements) {
0044 result.chi2 += chi2Term(angles, result.y0, *strawMeas);
0045 }
0046 ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0047 << result.chi2 << ", nDoF: " << result.nDoF
0048 << ", redChi2: " << (result.chi2 / result.nDoF));
0049 }
0050
0051 template <CompositeSpacePoint Point_t>
0052 double FastStrawLineFitter::chi2Term(const TrigonomHelper& angle,
0053 const double y0, const Point_t& strawMeas,
0054 std::optional<double> r) const {
0055 if (!strawMeas.isStraw()) {
0056 return 0.;
0057 }
0058 const double cov = strawMeas.covariance()[s_covIdx];
0059 if (cov < std::numeric_limits<double>::epsilon()) {
0060 return 0.;
0061 }
0062 const Vector& pos = strawMeas.localPosition();
0063 const double y = pos.dot(strawMeas.toNextSensor());
0064 const double z = pos.dot(strawMeas.planeNormal());
0065 const double dist = Acts::abs((y - y0) * angle.cosTheta - z * angle.sinTheta);
0066 ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": Distance straw (" << y
0067 << ", " << z
0068 << "), r: " << r.value_or(strawMeas.driftRadius())
0069 << " - track: " << dist);
0070 return Acts::pow(dist - r.value_or(strawMeas.driftRadius()), 2) / cov;
0071 }
0072
0073 template <CompositeSpacePointContainer StripCont_t>
0074 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0075 const StripCont_t& measurements, const ResidualIdx projection) const {
0076 if (projection == ResidualIdx::time) {
0077 ACTS_WARNING(__func__ << "() - " << __LINE__
0078 << ": Only spatial projections, "
0079 << "i.e. nonBending / bending are sensible");
0080 return std::nullopt;
0081 }
0082
0083 FitAuxiliaries auxVars{};
0084 Vector centerOfGravity{Vector::Zero()};
0085
0086 auto select = [&projection](const auto& strip) -> bool {
0087
0088
0089 if (strip->isStraw()) {
0090 return strip->measuresLoc0() && projection == ResidualIdx::nonBending;
0091 }
0092
0093 return (strip->measuresLoc0() && projection == ResidualIdx::nonBending) ||
0094 (strip->measuresLoc1() && projection == ResidualIdx::bending);
0095 };
0096
0097 auxVars.invCovs.resize(measurements.size());
0098 for (const auto& [sIdx, strip] : enumerate(measurements)) {
0099 if (!select(strip)) {
0100 ACTS_VERBOSE(__func__ << "() - " << __LINE__
0101 << ": Skip strip measurement " << toString(*strip));
0102 continue;
0103 }
0104 const auto& invCov =
0105 (auxVars.invCovs[sIdx] =
0106 1. / strip->covariance()[toUnderlying(projection)]);
0107 auxVars.covNorm += invCov;
0108 centerOfGravity += invCov * strip->localPosition();
0109 ++auxVars.nDoF;
0110 }
0111
0112 if (auxVars.nDoF < 3) {
0113 return std::nullopt;
0114 }
0115
0116
0117 auxVars.nDoF -= 2u;
0118 auxVars.covNorm = 1. / auxVars.covNorm;
0119 centerOfGravity *= auxVars.covNorm;
0120
0121 bool centerSet{false};
0122 for (const auto& [sIdx, strip] : enumerate(measurements)) {
0123 if (!select(strip)) {
0124 continue;
0125 }
0126 const Vector pos = strip->localPosition() - centerOfGravity;
0127 const Vector& measDir{
0128 (projection == ResidualIdx::nonBending && strip->measuresLoc1()) ||
0129 strip->isStraw()
0130 ? strip->sensorDirection()
0131 : strip->toNextSensor()};
0132
0133 if (!centerSet) {
0134 auxVars.centerY = centerOfGravity.dot(measDir);
0135 auxVars.centerZ = centerOfGravity.dot(strip->planeNormal());
0136 centerSet = true;
0137 }
0138
0139 const double y = pos.dot(measDir);
0140 const double z = pos.dot(strip->planeNormal());
0141
0142 const auto& invCov = auxVars.invCovs[sIdx];
0143 auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0144 auxVars.T_yz += invCov * z * y;
0145 }
0146 return fit(auxVars);
0147 }
0148
0149 template <CompositeSpacePointContainer StrawCont_t,
0150 CompositeSpacePointFastCalibrator<
0151 Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0152 Calibrator_t>
0153 void FastStrawLineFitter::calcPostFitChi2(const Acts::CalibrationContext& ctx,
0154 const StrawCont_t& measurements,
0155 const Calibrator_t& calibrator,
0156 FitResultT0& result) const {
0157 const TrigonomHelper angles{result.theta};
0158 result.chi2 = 0.;
0159 for (const auto& strawMeas : measurements) {
0160 result.chi2 += chi2Term(angles, result.y0, *strawMeas,
0161 calibrator.driftRadius(ctx, *strawMeas, result.t0));
0162 }
0163 ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0164 << result.chi2 << ", nDoF: " << result.nDoF
0165 << ", redChi2: " << (result.chi2 / result.nDoF));
0166 }
0167
0168 template <CompositeSpacePointContainer StrawCont_t>
0169 FastStrawLineFitter::FitAuxiliaries FastStrawLineFitter::fillAuxiliaries(
0170 const StrawCont_t& measurements, const std::vector<int>& signs) const {
0171 FitAuxiliaries auxVars{};
0172 auxVars.invCovs.resize(signs.size(), -1.);
0173 Vector centerOfGravity{Vector::Zero()};
0174
0175
0176 for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0177 if (!strawMeas->isStraw()) {
0178 ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": The measurement "
0179 << toString(*strawMeas) << " is not a straw");
0180 continue;
0181 }
0182 const double cov = strawMeas->covariance()[s_covIdx];
0183 if (cov < std::numeric_limits<double>::epsilon()) {
0184 ACTS_WARNING(__func__ << "() - " << __LINE__ << ": The covariance ("
0185 << cov << ") of the measurement "
0186 << toString(*strawMeas) << " is invalid.");
0187 continue;
0188 }
0189 ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": Fill "
0190 << toString(*strawMeas) << ".");
0191
0192 auto& invCov = (auxVars.invCovs[sIdx] = 1. / cov);
0193 auxVars.covNorm += invCov;
0194 centerOfGravity += invCov * strawMeas->localPosition();
0195 ++auxVars.nDoF;
0196 }
0197 if (auxVars.nDoF < 3) {
0198 std::stringstream sstr{};
0199 for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0200 sstr << " --- " << (sIdx + 1) << ") " << toString(*strawMeas)
0201 << ", weight: " << auxVars.invCovs[sIdx] << std::endl;
0202 }
0203 ACTS_WARNING(__func__ << "() - " << __LINE__
0204 << ": At least 3 measurements are required to "
0205 "perform the straw line fit\n"
0206 << sstr.str());
0207 auxVars.nDoF = 0u;
0208 return auxVars;
0209 }
0210
0211
0212 auxVars.nDoF -= 2u;
0213 auxVars.covNorm = 1. / auxVars.covNorm;
0214 centerOfGravity *= auxVars.covNorm;
0215
0216
0217 bool centerSet{false};
0218 for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0219 const auto& invCov = auxVars.invCovs[sIdx];
0220
0221 if (invCov < 0.) {
0222 continue;
0223 }
0224 if (!centerSet) {
0225 auxVars.centerY = centerOfGravity.dot(strawMeas->toNextSensor());
0226 auxVars.centerZ = centerOfGravity.dot(strawMeas->planeNormal());
0227 centerSet = true;
0228 }
0229 const Vector pos = strawMeas->localPosition() - centerOfGravity;
0230 const double y = pos.dot(strawMeas->toNextSensor());
0231 const double z = pos.dot(strawMeas->planeNormal());
0232 const double r = strawMeas->driftRadius();
0233
0234 auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0235 auxVars.T_yz += invCov * z * y;
0236 const double sInvCov = -invCov * signs[sIdx];
0237 auxVars.T_rz += sInvCov * z * r;
0238 auxVars.T_ry += sInvCov * y * r;
0239 auxVars.fitY0 += sInvCov * r;
0240 }
0241 auxVars.fitY0 *= auxVars.covNorm;
0242
0243 return auxVars;
0244 }
0245
0246 template <CompositeSpacePointContainer StrawCont_t,
0247 CompositeSpacePointFastCalibrator<
0248 Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0249 Calibrator_t>
0250 std::optional<FastStrawLineFitter::FitResultT0> FastStrawLineFitter::fit(
0251 const Acts::CalibrationContext& ctx, const Calibrator_t& calibrator,
0252 const StrawCont_t& measurements, const std::vector<int>& signs,
0253 std::optional<double> startT0) const {
0254 using namespace Acts::UnitLiterals;
0255 if (measurements.size() != signs.size()) {
0256 ACTS_WARNING(
0257 __func__ << "() - " << __LINE__
0258 << ": Not all measurements are associated with a drift sign");
0259 return std::nullopt;
0260 }
0261
0262 FitResultT0 result{};
0263 result.t0 = startT0.value_or(0.);
0264
0265 FitAuxiliariesWithT0 fitPars{
0266 fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0)};
0267 result.theta = startTheta(fitPars);
0268 result.nDoF = fitPars.nDoF;
0269 ACTS_DEBUG(__func__ << "() - " << __LINE__
0270 << ": Initial fit parameters: " << result);
0271 UpdateStatus iterStatus{UpdateStatus::GoodStep};
0272
0273 while ((iterStatus = updateIteration(fitPars, result)) !=
0274 UpdateStatus::Exceeded) {
0275 if (iterStatus == UpdateStatus::Converged) {
0276 calcPostFitChi2(ctx, measurements, calibrator, result);
0277 return result;
0278 }
0279 fitPars = fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0);
0280 }
0281 if (logger().doPrint(Logging::VERBOSE)) {
0282 ACTS_VERBOSE("Fit failed, printing all measurements:");
0283 for (const auto& meas : measurements) {
0284 ACTS_VERBOSE(toString(*meas)
0285 << ", t0: " << result.t0 / 1._ns
0286 << ", truthR, RecoR: " << meas->driftRadius() << ", "
0287 << calibrator.driftRadius(ctx, *meas, result.t0)
0288 << ", velocity: "
0289 << calibrator.driftVelocity(ctx, *meas, result.t0) * 1._ns
0290 << ", acceleration: "
0291 << calibrator.driftAcceleration(ctx, *meas, result.t0) *
0292 Acts::square(1._ns));
0293 }
0294 ACTS_VERBOSE("Result: " << result);
0295 }
0296 return std::nullopt;
0297 }
0298
0299 template <CompositeSpacePointContainer StrawCont_t,
0300 CompositeSpacePointFastCalibrator<
0301 Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0302 Calibrator_t>
0303 FastStrawLineFitter::FitAuxiliariesWithT0 FastStrawLineFitter::fillAuxiliaries(
0304 const CalibrationContext& ctx, const Calibrator_t& calibrator,
0305 const StrawCont_t& measurements, const std::vector<int>& signs,
0306 const double t0) const {
0307 using namespace Acts::UnitLiterals;
0308 FitAuxiliariesWithT0 auxVars{fillAuxiliaries(measurements, signs)};
0309 if (auxVars.nDoF <= 1) {
0310 auxVars.nDoF = 0;
0311 return auxVars;
0312 }
0313
0314 --auxVars.nDoF;
0315
0316 auxVars.T_rz = 0.;
0317 auxVars.T_ry = 0.;
0318 auxVars.fitY0 = 0.;
0319 for (const auto& [spIdx, strawMeas] : enumerate(measurements)) {
0320 const double& invCov = auxVars.invCovs[spIdx];
0321
0322 if (invCov < 0.) {
0323 continue;
0324 }
0325 const double sInvCov = -invCov * signs[spIdx];
0326 const double r = calibrator.driftRadius(ctx, *strawMeas, t0);
0327 const double v = calibrator.driftVelocity(ctx, *strawMeas, t0);
0328 const double a = calibrator.driftAcceleration(ctx, *strawMeas, t0);
0329 const double y = strawMeas->localPosition().dot(strawMeas->toNextSensor()) -
0330 auxVars.centerY;
0331 const double z = strawMeas->localPosition().dot(strawMeas->planeNormal()) -
0332 auxVars.centerZ;
0333
0334 ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": # " << (spIdx + 1)
0335 << ") " << toString(*strawMeas) << ", t0: "
0336 << t0 / 1._ns << " r: " << r << ", v: " << v * 1._ns
0337 << ", a: " << a * Acts::square(1._ns));
0338 auxVars.fitY0 += sInvCov * r;
0339 auxVars.R_v += sInvCov * v;
0340 auxVars.R_a += sInvCov * a;
0341
0342 auxVars.T_rz += sInvCov * z * r;
0343 auxVars.T_ry += sInvCov * y * r;
0344
0345 auxVars.T_vy += sInvCov * v * y;
0346 auxVars.T_vz += sInvCov * v * z;
0347
0348 auxVars.R_vr += invCov * r * v;
0349 auxVars.R_vv += invCov * v * v;
0350
0351 auxVars.T_ay += sInvCov * a * y;
0352 auxVars.T_az += sInvCov * a * z;
0353
0354 auxVars.R_ar += invCov * a * r;
0355 }
0356 auxVars.fitY0 *= auxVars.covNorm;
0357 ACTS_DEBUG(__func__ << "() - " << __LINE__ << " Fit constants calculated \n"
0358 << auxVars);
0359 return auxVars;
0360 }
0361
0362 }