File indexing completed on 2025-12-15 09:42:16
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "Acts/Seeding/detail/FastStrawLineFitter.hpp"
0012
0013 #include "Acts/Definitions/Units.hpp"
0014 #include "Acts/Utilities/Enumerate.hpp"
0015
0016 #include <format>
0017 namespace Acts::Experimental::detail {
0018
0019 template <CompositeSpacePointContainer StrawCont_t>
0020 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0021 const StrawCont_t& measurements, const std::vector<int>& signs) const {
0022 if (measurements.size() != signs.size()) {
0023 ACTS_WARNING(
0024 __func__ << "() - " << __LINE__
0025 << ": Not all measurements are associated with a drift sign");
0026 return std::nullopt;
0027 }
0028
0029 auto result = fit(fillAuxiliaries(measurements, signs));
0030 if (!result) {
0031 return std::nullopt;
0032 }
0033
0034 calcPostFitChi2(measurements, *result);
0035 return result;
0036 }
0037
0038 template <CompositeSpacePointContainer StrawCont_t>
0039 void FastStrawLineFitter::calcPostFitChi2(const StrawCont_t& measurements,
0040 FitResult& result) const {
0041 const TrigonomHelper angles{result.theta};
0042 result.chi2 = 0.;
0043 for (const auto& strawMeas : measurements) {
0044 result.chi2 += chi2Term(angles, result.y0, *strawMeas);
0045 }
0046 ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0047 << result.chi2 << ", nDoF: " << result.nDoF
0048 << ", redChi2: " << (result.chi2 / result.nDoF));
0049 }
0050
0051 template <CompositeSpacePoint Point_t>
0052 double FastStrawLineFitter::chi2Term(const TrigonomHelper& angle,
0053 const double y0, const Point_t& strawMeas,
0054 std::optional<double> r) const {
0055 if (!strawMeas.isStraw()) {
0056 return 0.;
0057 }
0058 const double cov = strawMeas.covariance()[s_covIdx];
0059 if (cov < std::numeric_limits<double>::epsilon()) {
0060 return 0.;
0061 }
0062 const Vector& pos = strawMeas.localPosition();
0063 const double y = pos.dot(strawMeas.toNextSensor());
0064 const double z = pos.dot(strawMeas.planeNormal());
0065 const double dist = Acts::abs((y - y0) * angle.cosTheta - z * angle.sinTheta);
0066 ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": Distance straw (" << y
0067 << ", " << z
0068 << "), r: " << r.value_or(strawMeas.driftRadius())
0069 << " - track: " << dist);
0070 return Acts::pow(dist - r.value_or(strawMeas.driftRadius()), 2) / cov;
0071 }
0072
0073 template <CompositeSpacePointContainer StripCont_t>
0074 std::optional<FastStrawLineFitter::FitResult> FastStrawLineFitter::fit(
0075 const StripCont_t& measurements, const ResidualIdx projection) const {
0076 if (projection == ResidualIdx::time) {
0077 ACTS_WARNING(__func__ << "() - " << __LINE__
0078 << ": Only spatial projections, "
0079 << "i.e. nonBending / bending are sensible");
0080 return std::nullopt;
0081 }
0082
0083 FitAuxiliaries auxVars{};
0084 Vector centerOfGravity{Vector::Zero()};
0085
0086 auto select = [&projection](const auto& strip) -> bool {
0087
0088
0089 if (strip->isStraw()) {
0090 return strip->measuresLoc0() && projection == ResidualIdx::nonBending;
0091 }
0092
0093 return (strip->measuresLoc0() && projection == ResidualIdx::nonBending) ||
0094 (strip->measuresLoc1() && projection == ResidualIdx::bending);
0095 };
0096
0097 auxVars.invCovs.resize(measurements.size());
0098 for (const auto& [sIdx, strip] : enumerate(measurements)) {
0099 if (!select(strip)) {
0100 ACTS_VERBOSE(__func__ << "() - " << __LINE__
0101 << ": Skip strip measurement @"
0102 << toString(strip->localPosition()));
0103 continue;
0104 }
0105 const auto& invCov =
0106 (auxVars.invCovs[sIdx] =
0107 1. / strip->covariance()[toUnderlying(projection)]);
0108 auxVars.covNorm += invCov;
0109 centerOfGravity += invCov * strip->localPosition();
0110 ++auxVars.nDoF;
0111 }
0112
0113 if (auxVars.nDoF < 3) {
0114 return std::nullopt;
0115 }
0116
0117
0118 auxVars.nDoF -= 2u;
0119 auxVars.covNorm = 1. / auxVars.covNorm;
0120 centerOfGravity *= auxVars.covNorm;
0121
0122 bool centerSet{false};
0123 for (const auto& [sIdx, strip] : enumerate(measurements)) {
0124 if (!select(strip)) {
0125 continue;
0126 }
0127 const Vector pos = strip->localPosition() - centerOfGravity;
0128 const Vector& measDir{
0129 (projection == ResidualIdx::nonBending && strip->measuresLoc1()) ||
0130 strip->isStraw()
0131 ? strip->sensorDirection()
0132 : strip->toNextSensor()};
0133
0134 if (!centerSet) {
0135 auxVars.centerY = centerOfGravity.dot(measDir);
0136 auxVars.centerZ = centerOfGravity.dot(strip->planeNormal());
0137 centerSet = true;
0138 }
0139
0140 const double y = pos.dot(measDir);
0141 const double z = pos.dot(strip->planeNormal());
0142
0143 const auto& invCov = auxVars.invCovs[sIdx];
0144 auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0145 auxVars.T_yz += invCov * z * y;
0146 }
0147 return fit(auxVars);
0148 }
0149
0150 template <CompositeSpacePointContainer StrawCont_t,
0151 CompositeSpacePointFastCalibrator<
0152 Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0153 Calibrator_t>
0154 void FastStrawLineFitter::calcPostFitChi2(const Acts::CalibrationContext& ctx,
0155 const StrawCont_t& measurements,
0156 const Calibrator_t& calibrator,
0157 FitResultT0& result) const {
0158 const TrigonomHelper angles{result.theta};
0159 result.chi2 = 0.;
0160 for (const auto& strawMeas : measurements) {
0161 result.chi2 += chi2Term(angles, result.y0, *strawMeas,
0162 calibrator.driftRadius(ctx, *strawMeas, result.t0));
0163 }
0164 ACTS_DEBUG(__func__ << "() - " << __LINE__ << ": Overall chi2: "
0165 << result.chi2 << ", nDoF: " << result.nDoF
0166 << ", redChi2: " << (result.chi2 / result.nDoF));
0167 }
0168
0169 template <CompositeSpacePointContainer StrawCont_t>
0170 FastStrawLineFitter::FitAuxiliaries FastStrawLineFitter::fillAuxiliaries(
0171 const StrawCont_t& measurements, const std::vector<int>& signs) const {
0172 FitAuxiliaries auxVars{};
0173 auxVars.invCovs.resize(signs.size(), -1.);
0174 Vector centerOfGravity{Vector::Zero()};
0175
0176
0177 for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0178 if (!strawMeas->isStraw()) {
0179 ACTS_DEBUG(__func__ << "() - " << __LINE__
0180 << ": The measurement is not a straw");
0181 continue;
0182 }
0183 const double cov = strawMeas->covariance()[s_covIdx];
0184 if (cov < std::numeric_limits<double>::epsilon()) {
0185 ACTS_WARNING(__func__ << "() - " << __LINE__ << ": The covariance ("
0186 << cov << ") of the measurement is invalid.");
0187 continue;
0188 }
0189 auto& invCov = (auxVars.invCovs[sIdx] = 1. / cov);
0190 auxVars.covNorm += invCov;
0191 centerOfGravity += invCov * strawMeas->localPosition();
0192 ++auxVars.nDoF;
0193 }
0194 if (auxVars.nDoF < 3) {
0195 std::stringstream sstr{};
0196 for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0197 sstr << " --- " << (sIdx + 1) << ") "
0198 << toString(strawMeas->localPosition())
0199 << ", r: " << strawMeas->driftRadius()
0200 << ", weight: " << auxVars.invCovs[sIdx] << std::endl;
0201 }
0202 ACTS_WARNING(__func__ << "() - " << __LINE__
0203 << ": At least 3 measurements are required to "
0204 "perform the straw line fit\n"
0205 << sstr.str());
0206 auxVars.nDoF = 0u;
0207 return auxVars;
0208 }
0209
0210
0211 auxVars.nDoF -= 2u;
0212 auxVars.covNorm = 1. / auxVars.covNorm;
0213 centerOfGravity *= auxVars.covNorm;
0214
0215
0216 bool centerSet{false};
0217 for (const auto& [sIdx, strawMeas] : enumerate(measurements)) {
0218 const auto& invCov = auxVars.invCovs[sIdx];
0219
0220 if (invCov < 0.) {
0221 continue;
0222 }
0223 if (!centerSet) {
0224 auxVars.centerY = centerOfGravity.dot(strawMeas->toNextSensor());
0225 auxVars.centerZ = centerOfGravity.dot(strawMeas->planeNormal());
0226 centerSet = true;
0227 }
0228 const Vector pos = strawMeas->localPosition() - centerOfGravity;
0229 const double y = pos.dot(strawMeas->toNextSensor());
0230 const double z = pos.dot(strawMeas->planeNormal());
0231 const double r = strawMeas->driftRadius();
0232
0233 auxVars.T_zzyy += invCov * (Acts::square(z) - Acts::square(y));
0234 auxVars.T_yz += invCov * z * y;
0235 const double sInvCov = -invCov * signs[sIdx];
0236 auxVars.T_rz += sInvCov * z * r;
0237 auxVars.T_ry += sInvCov * y * r;
0238 auxVars.fitY0 += sInvCov * r;
0239 }
0240 auxVars.fitY0 *= auxVars.covNorm;
0241
0242 return auxVars;
0243 }
0244
0245 template <CompositeSpacePointContainer StrawCont_t,
0246 CompositeSpacePointFastCalibrator<
0247 Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0248 Calibrator_t>
0249 std::optional<FastStrawLineFitter::FitResultT0> FastStrawLineFitter::fit(
0250 const Acts::CalibrationContext& ctx, const Calibrator_t& calibrator,
0251 const StrawCont_t& measurements, const std::vector<int>& signs,
0252 std::optional<double> startT0) const {
0253 using namespace Acts::UnitLiterals;
0254 if (measurements.size() != signs.size()) {
0255 ACTS_WARNING(
0256 __func__ << "() - " << __LINE__
0257 << ": Not all measurements are associated with a drift sign");
0258 return std::nullopt;
0259 }
0260
0261 FitResultT0 result{};
0262 result.t0 = startT0.value_or(0.);
0263
0264 FitAuxiliariesWithT0 fitPars{
0265 fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0)};
0266 result.theta = startTheta(fitPars);
0267 result.nDoF = fitPars.nDoF;
0268 ACTS_DEBUG(__func__ << "() - " << __LINE__
0269 << ": Initial fit parameters: " << result);
0270 UpdateStatus iterStatus{UpdateStatus::GoodStep};
0271
0272 while ((iterStatus = updateIteration(fitPars, result)) !=
0273 UpdateStatus::Exceeded) {
0274 if (iterStatus == UpdateStatus::Converged) {
0275 calcPostFitChi2(ctx, measurements, calibrator, result);
0276 return result;
0277 }
0278 fitPars = fillAuxiliaries(ctx, calibrator, measurements, signs, result.t0);
0279 }
0280 if (logger().doPrint(Logging::VERBOSE)) {
0281 ACTS_VERBOSE("Fit failed, printing all measurements:");
0282 for (const auto& meas : measurements) {
0283 ACTS_VERBOSE(
0284 "Pos: " << Acts::toString(meas->localPosition()) << ", t,t0: "
0285 << meas->time() / 1._ns << ", " << result.t0 / 1._ns
0286 << ", truthR, RecoR: " << meas->driftRadius() << ", "
0287 << calibrator.driftRadius(ctx, *meas, result.t0) << ", v: "
0288 << calibrator.driftVelocity(ctx, *meas, result.t0) * 1._ns
0289 << ", a: "
0290 << calibrator.driftAcceleration(ctx, *meas, result.t0) *
0291 1._ns * 1._ns);
0292 }
0293 ACTS_VERBOSE("Result: " << result);
0294 }
0295 return std::nullopt;
0296 }
0297
0298 template <CompositeSpacePointContainer StrawCont_t,
0299 CompositeSpacePointFastCalibrator<
0300 Acts::RemovePointer_t<typename StrawCont_t::value_type>>
0301 Calibrator_t>
0302 FastStrawLineFitter::FitAuxiliariesWithT0 FastStrawLineFitter::fillAuxiliaries(
0303 const CalibrationContext& ctx, const Calibrator_t& calibrator,
0304 const StrawCont_t& measurements, const std::vector<int>& signs,
0305 const double t0) const {
0306 using namespace Acts::UnitLiterals;
0307 FitAuxiliariesWithT0 auxVars{fillAuxiliaries(measurements, signs)};
0308 if (auxVars.nDoF <= 1) {
0309 auxVars.nDoF = 0;
0310 return auxVars;
0311 }
0312
0313 --auxVars.nDoF;
0314
0315 auxVars.T_rz = 0.;
0316 auxVars.T_ry = 0.;
0317 auxVars.fitY0 = 0.;
0318 for (const auto& [spIdx, strawMeas] : enumerate(measurements)) {
0319 const double& invCov = auxVars.invCovs[spIdx];
0320
0321 if (invCov < 0.) {
0322 continue;
0323 }
0324 const double sInvCov = -invCov * signs[spIdx];
0325 const double r = calibrator.driftRadius(ctx, *strawMeas, t0);
0326 const double v = calibrator.driftVelocity(ctx, *strawMeas, t0);
0327 const double a = calibrator.driftAcceleration(ctx, *strawMeas, t0);
0328 const double y = strawMeas->localPosition().dot(strawMeas->toNextSensor()) -
0329 auxVars.centerY;
0330 const double z = strawMeas->localPosition().dot(strawMeas->planeNormal()) -
0331 auxVars.centerZ;
0332
0333 ACTS_VERBOSE(__func__ << "() - " << __LINE__ << ": # " << (spIdx + 1)
0334 << ") t,t0: " << strawMeas->time() / 1._ns << ", "
0335 << t0 / 1._ns << " r: " << r << ", v: " << v * 1._ns
0336 << ", a: " << a * 1._ns * 1._ns);
0337 auxVars.fitY0 += sInvCov * r;
0338 auxVars.R_v += sInvCov * v;
0339 auxVars.R_a += sInvCov * a;
0340
0341 auxVars.T_rz += sInvCov * z * r;
0342 auxVars.T_ry += sInvCov * y * r;
0343
0344 auxVars.T_vy += sInvCov * v * y;
0345 auxVars.T_vz += sInvCov * v * z;
0346
0347 auxVars.R_vr += invCov * r * v;
0348 auxVars.R_vv += invCov * v * v;
0349
0350 auxVars.T_ay += sInvCov * a * y;
0351 auxVars.T_az += sInvCov * a * z;
0352
0353 auxVars.R_ar += invCov * a * r;
0354 }
0355 auxVars.fitY0 *= auxVars.covNorm;
0356 ACTS_DEBUG(__func__ << "() - " << __LINE__ << " Fit constants calculated \n"
0357 << auxVars);
0358 return auxVars;
0359 }
0360
0361 }