File indexing completed on 2025-08-28 08:13:23
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include <boost/test/unit_test.hpp>
0010
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/Seeding/detail/FastStrawLineFitter.hpp"
0013 #include "Acts/Surfaces/detail/LineHelper.hpp"
0014 #include "Acts/Surfaces/detail/PlanarHelper.hpp"
0015 #include "Acts/Utilities/StringHelpers.hpp"
0016
0017 #include <format>
0018 #include <random>
0019
0020 #include "TFile.h"
0021 #include "TTree.h"
0022
0023 using namespace Acts;
0024 using namespace Acts::Experimental;
0025 using namespace Acts::Experimental::detail;
0026 using namespace Acts::UnitLiterals;
/// Pseudo-random number engine driving the generation of tracks and hits
using RandomEngine = std::mt19937;

/// Number of random tracks generated and fitted by the test
constexpr std::size_t nTrials{1};
0030
0031 namespace Acts::Test {
0032
0033 constexpr bool debugMode = true;
0034
0035 ACTS_LOCAL_LOGGER(getDefaultLogger("FastStrawLineFitTests",
0036 Logging::Level::INFO));
0037
0038 class StrawTestPoint;
0039 using TestStrawCont_t = std::vector<std::unique_ptr<StrawTestPoint>>;
0040 using Line_t = CompSpacePointAuxiliaries::Line_t;
0041 using ResidualIdx = FastStrawLineFitter::ResidualIdx;
0042
/// @brief Streams the elements of a vector in the form "[e0, e1, ..., eN]".
///        An empty vector is printed as "[]".
/// @param ostr: Stream to write into
/// @param v: Vector to print. The element type must be streamable itself.
/// @return The passed stream to allow chaining
template <typename T>
std::ostream& operator<<(std::ostream& ostr, const std::vector<T>& v) {
  ostr << "[";
  bool firstElement = true;
  for (const auto& element : v) {
    // The separator goes in front of every element except the first one
    if (!firstElement) {
      ostr << ", ";
    }
    ostr << element;
    firstElement = false;
  }
  ostr << "]";
  return ostr;
}
0055
0056 class StrawTestPoint {
0057 public:
0058 StrawTestPoint(const Vector3& pos, const double driftR,
0059 const double driftRUncert)
0060 : m_pos{pos}, m_driftR{Acts::abs(driftR)} {
0061 m_cov[toUnderlying(ResidualIdx::bending)] = Acts::square(driftRUncert);
0062 }
0063
0064 const Vector3& localPosition() const { return m_pos; }
0065
0066 const Vector3& sensorDirection() const { return m_wireDir; }
0067
0068 const Vector3& toNextSensor() const { return m_toNext; }
0069
0070 const Vector3& planeNormal() const { return m_planeNorm; }
0071
0072 double driftRadius() const { return m_driftR; }
0073
0074 const std::array<double, 3>& covariance() const { return m_cov; }
0075
0076 double driftUncert() const {
0077 return std::sqrt(m_cov[toUnderlying(ResidualIdx::bending)]);
0078 }
0079
0080 double time() const { return m_drifT; }
0081
0082 bool isStraw() const { return true; }
0083
0084 bool hasTime() const { return false; }
0085
0086 bool measuresLoc0() const { return false; }
0087
0088 bool measuresLoc1() const { return false; }
0089 void setRadius(const double r, const double uncertR) {
0090 m_driftR = Acts::abs(r);
0091 m_cov[toUnderlying(ResidualIdx::bending)] = Acts::square(uncertR);
0092 }
0093 void setTimeRecord(const double t) { m_drifT = t; }
0094
0095 private:
0096 Vector3 m_pos{Vector3::Zero()};
0097 Vector3 m_wireDir{Vector3::UnitX()};
0098 Vector3 m_toNext{Vector3::UnitY()};
0099 Vector3 m_planeNorm{Vector3::UnitZ()};
0100 double m_driftR{0.};
0101 std::array<double, 3> m_cov{Acts::filledArray<double, 3>(0.)};
0102 double m_drifT{0.};
0103 };
0104 static_assert(CompositeSpacePoint<StrawTestPoint>);
0105
0106 class StrawTestCalibrator {
0107 public:
0108
0109
0110 static constexpr double CoeffRtoT = 1.;
0111 static constexpr double CoeffTtoR = 1. / CoeffRtoT;
0112
0113 static double calcDriftUncert(const double driftR) {
0114 return 0.1_mm + 0.15_mm * Acts::pow(1._mm + Acts::abs(driftR), -2);
0115 }
0116 static double driftTime(const double r) {
0117 return CoeffRtoT * Acts::square(r);
0118 }
0119 static double driftRadius(const double t) {
0120 return std::sqrt(Acts::abs(t) * CoeffTtoR);
0121 }
0122
0123 static double driftRadius(const Acts::CalibrationContext& ,
0124 const StrawTestPoint& straw, const double t0) {
0125 return driftRadius(straw.time() - t0);
0126 }
0127 static double driftVelocity(const Acts::CalibrationContext& ,
0128 const StrawTestPoint& straw, const double t0) {
0129 return CoeffTtoR / (2. * driftRadius(straw.time() - t0));
0130 }
0131 static double driftAcceleration(const Acts::CalibrationContext& ,
0132 const StrawTestPoint& straw,
0133 const double t0) {
0134 return -Acts::square(CoeffTtoR) /
0135 (4. * Acts::pow(driftRadius(straw.time() - t0), 3));
0136 }
0137 };
0138
0139
0140
0141
0142
0143 Line_t generateLine(RandomEngine& engine) {
0144 using ParIndex = Line_t::ParIndex;
0145 Line_t::ParamVector linePars{};
0146 linePars[toUnderlying(ParIndex::x0)] = 0.;
0147 linePars[toUnderlying(ParIndex::phi)] = 90._degree;
0148 linePars[toUnderlying(ParIndex::y0)] =
0149 std::uniform_real_distribution{-5000., 5000.}(engine);
0150 linePars[toUnderlying(ParIndex::theta)] =
0151 std::uniform_real_distribution{0.1_degree, 179.9_degree}(engine);
0152 Line_t line{};
0153 line.updateParameters(linePars);
0154 if (Acts::abs(linePars[toUnderlying(ParIndex::theta)] - 90._degree) <
0155 0.2_degree) {
0156 return generateLine(engine);
0157 }
0158 ACTS_DEBUG("Generated parameters theta: "
0159 << (linePars[toUnderlying(ParIndex::theta)] / 1._degree)
0160 << ", y0: " << linePars[toUnderlying(ParIndex::y0)] << " - "
0161 << toString(line.position()) << " + "
0162 << toString(line.direction()));
0163 return line;
0164 }
0165
0166
0167
0168
0169
0170
0171
0172
0173
0174
0175 TestStrawCont_t generateStrawCircles(const Line_t& trajLine,
0176 RandomEngine& engine, bool smearRadius) {
0177 const Vector3 posStaggering{0., std::cos(60._degree), std::sin(60._degree)};
0178 const Vector3 negStaggering{0., -std::cos(60._degree), std::sin(60._degree)};
0179
0180 constexpr std::size_t nLayersPerMl = 8;
0181
0182 constexpr std::size_t nTubeLayers = nLayersPerMl * 2;
0183
0184 constexpr double tubeRadius = 15._mm;
0185
0186 constexpr double tubeLayerDist = 1.2_m;
0187
0188 std::array<Vector3, nTubeLayers> tubePositions{
0189 filledArray<Vector3, nTubeLayers>(Vector3{0., tubeRadius, tubeRadius})};
0190
0191 for (std::size_t l = 1; l < nTubeLayers; ++l) {
0192 const Vector3& layStag{l % 2 == 1 ? posStaggering : negStaggering};
0193 tubePositions[l] = tubePositions[l - 1] + 2. * tubeRadius * layStag;
0194
0195 if (l == nLayersPerMl) {
0196 tubePositions[l] += tubeLayerDist * Vector3::UnitZ();
0197 }
0198 }
0199
0200 ACTS_DEBUG("##############################################");
0201
0202 for (std::size_t l = 0; l < nTubeLayers; ++l) {
0203 ACTS_DEBUG(" *** " << (l + 1) << " - " << toString(tubePositions[l]));
0204 }
0205 ACTS_DEBUG("##############################################");
0206
0207 TestStrawCont_t circles{};
0208
0209
0210 for (const auto& stag : tubePositions) {
0211 auto planeExtpLow = Acts::PlanarHelper::intersectPlane(
0212 trajLine.position(), trajLine.direction(), Vector3::UnitZ(),
0213 stag.z() - tubeRadius);
0214 auto planeExtpHigh = Acts::PlanarHelper::intersectPlane(
0215 trajLine.position(), trajLine.direction(), Vector3::UnitZ(),
0216 stag.z() + tubeRadius);
0217
0218 ACTS_DEBUG("Extrapolated to plane " << toString(planeExtpLow.position())
0219 << " "
0220 << toString(planeExtpHigh.position()));
0221
0222 const auto dToFirstLow = static_cast<int>(std::ceil(
0223 (planeExtpLow.position().y() - stag.y()) / (2. * tubeRadius)));
0224 const auto dToFirstHigh = static_cast<int>(std::ceil(
0225 (planeExtpHigh.position().y() - stag.y()) / (2. * tubeRadius)));
0226
0227 const int dT = dToFirstHigh > dToFirstLow ? 1 : -1;
0228
0229
0230
0231 for (int tN = dToFirstLow - dT; tN != dToFirstHigh + 2 * dT; tN += dT) {
0232 const Vector3 tube = stag + 2. * tN * tubeRadius * Vector3::UnitY();
0233 const double rad = Acts::detail::LineHelper::signedDistance(
0234 tube, Vector3::UnitX(), trajLine.position(), trajLine.direction());
0235 ACTS_DEBUG("Tube position: " << toString(tube) << ", radius: " << rad);
0236
0237 if (std::abs(rad) > tubeRadius) {
0238 continue;
0239 }
0240 std::normal_distribution<> dist{
0241 rad, StrawTestCalibrator::calcDriftUncert(rad)};
0242 const double smearedR = smearRadius ? std::abs(dist(engine)) : rad;
0243 if (smearedR > tubeRadius) {
0244 continue;
0245 }
0246 circles.emplace_back(std::make_unique<StrawTestPoint>(
0247 tube, smearedR, StrawTestCalibrator::calcDriftUncert(smearedR)));
0248 }
0249 }
0250 ACTS_DEBUG("Track hit in total " << circles.size() << " tubes ");
0251 return circles;
0252 }
0253
0254
0255
0256 double calcChi2(const TestStrawCont_t& measurements, const Line_t& track) {
0257 double chi2{0.};
0258 for (const auto& meas : measurements) {
0259 const double dist = Acts::detail::LineHelper::signedDistance(
0260 meas->localPosition(), meas->sensorDirection(), track.position(),
0261 track.direction());
0262 ACTS_DEBUG("Distance straw: " << toString(meas->localPosition())
0263 << ", r: " << meas->driftRadius()
0264 << " - to track: " << Acts::abs(dist));
0265
0266 chi2 += Acts::square((Acts::abs(dist) - meas->driftRadius()) /
0267 meas->driftUncert());
0268 }
0269 return chi2;
0270 }
0271
0272 BOOST_AUTO_TEST_SUITE(FastStrawLineFitTests)
0273
0274 BOOST_AUTO_TEST_CASE(SimpleLineFit) {
0275 RandomEngine engine{1419};
0276
0277 std::unique_ptr<TFile> outFile{};
0278 std::unique_ptr<TTree> outTree{};
0279 double trueY0{0.};
0280 double trueTheta{0.};
0281 double fitY0{0.};
0282 double fitTheta{0.};
0283 double fitdY0{0.};
0284 double fitdTheta{0.};
0285 double chi2{0.};
0286 std::size_t nDoF{0u};
0287 std::size_t nIter{0u};
0288 if (debugMode) {
0289 outFile.reset(TFile::Open("FastStrawLineFitTest.root", "RECREATE"));
0290 BOOST_CHECK_EQUAL(outFile->IsZombie(), false);
0291 outTree = std::make_unique<TTree>("FastFitTree", "FastFitTree");
0292 outTree->Branch("trueY0", &trueY0);
0293 outTree->Branch("trueTheta", &trueTheta);
0294 outTree->Branch("fitY0", &fitY0);
0295 outTree->Branch("fitTheta", &fitTheta);
0296 outTree->Branch("errY0", &fitdY0);
0297 outTree->Branch("errTheta", &fitdTheta);
0298 outTree->Branch("chi2", &chi2);
0299 outTree->Branch("nDoF", &nDoF);
0300 outTree->Branch("nIter", &nIter);
0301 }
0302
0303 FastStrawLineFitter::Config cfg{};
0304 FastStrawLineFitter fastFitter{cfg};
0305 for (std::size_t n = 0; n < nTrials; ++n) {
0306 auto track = generateLine(engine);
0307 auto strawPoints = generateStrawCircles(track, engine, true);
0308 if (strawPoints.size() < 3) {
0309 ACTS_WARNING(__func__ << "() - " << __LINE__ << ": -- event: " << n
0310 << ", track " << toString(track.position()) << " + "
0311 << toString(track.direction())
0312 << " did not lead to any valid measurement ");
0313 continue;
0314 }
0315 const std::vector<std::int32_t> trueDriftSigns =
0316 CompSpacePointAuxiliaries::strawSigns(track, strawPoints);
0317
0318 BOOST_CHECK_LE(calcChi2(generateStrawCircles(track, engine, false), track),
0319 1.e-12);
0320 ACTS_DEBUG("True drift signs: " << trueDriftSigns << ", chi2: " << chi2);
0321
0322 auto fitResult = fastFitter.fit(strawPoints, trueDriftSigns);
0323 if (!fitResult) {
0324 continue;
0325 }
0326 auto trackPars = track.parameters();
0327
0328 trueY0 = trackPars[toUnderlying(Line_t::ParIndex::y0)];
0329 trueTheta = trackPars[toUnderlying(Line_t::ParIndex::theta)];
0330
0331 trackPars[toUnderlying(Line_t::ParIndex::theta)] = (*fitResult).theta;
0332 trackPars[toUnderlying(Line_t::ParIndex::y0)] = (*fitResult).y0;
0333 trackPars[toUnderlying(Line_t::ParIndex::phi)] = 90._degree;
0334 track.updateParameters(trackPars);
0335 ACTS_DEBUG("Updated parameters: "
0336 << (trackPars[toUnderlying(Line_t::ParIndex::theta)] / 1._degree)
0337 << ", y0: " << trackPars[toUnderlying(Line_t::ParIndex::y0)]
0338 << " -- " << toString(track.position()) << " + "
0339 << toString(track.direction()));
0340
0341 const double testChi2 = calcChi2(strawPoints, track);
0342 ACTS_DEBUG("testChi2: " << testChi2 << ", fit:" << (*fitResult).chi2);
0343
0344 BOOST_CHECK_LE(Acts::abs(testChi2 - (*fitResult).chi2), 1.e-9);
0345 if (debugMode) {
0346 fitTheta = (*fitResult).theta;
0347 fitY0 = (*fitResult).y0;
0348 fitdTheta = (*fitResult).dTheta;
0349 fitdY0 = (*fitResult).dY0;
0350 nDoF = (*fitResult).nDoF;
0351 chi2 = (*fitResult).chi2;
0352 nIter = (*fitResult).nIter;
0353 outTree->Fill();
0354 }
0355 }
0356 if (debugMode) {
0357 outFile->WriteObject(outTree.get(), outTree->GetName());
0358 outTree.reset();
0359 }
0360 }
0361
0362 BOOST_AUTO_TEST_SUITE_END()
0363 }