File indexing completed on 2025-11-04 09:22:00
0001 
0002 
0003 
0004 
0005 
0006 
0007 
0008 
0009 #pragma once
0010 
0011 #include "Acts/Definitions/Algebra.hpp"
0012 #include "Acts/Surfaces/PerigeeSurface.hpp"
0013 #include "Acts/Vertexing/LinearizedTrack.hpp"
0014 #include "Acts/Vertexing/TrackAtVertex.hpp"
0015 #include "Acts/Vertexing/Vertex.hpp"
0016 
0017 namespace Acts::KalmanVertexUpdater::detail {
0018 
0019 
0020 
0021 
0022 template <unsigned int nDimVertex>
0023 struct Cache {
0024   using VertexVector = ActsVector<nDimVertex>;
0025   using VertexMatrix = ActsSquareMatrix<nDimVertex>;
0026   
0027   VertexVector newVertexPos = VertexVector::Zero();
0028   
0029   VertexMatrix newVertexCov = VertexMatrix::Zero();
0030   
0031   VertexMatrix newVertexWeight = VertexMatrix::Zero();
0032   
0033   VertexMatrix oldVertexWeight = VertexMatrix::Zero();
0034   
0035   SquareMatrix3 wMat = SquareMatrix3::Zero();
0036 };
0037 
0038 
0039 
0040 
0041 
0042 
0043 
0044 
0045 
0046 
0047 
0048 
0049 
0050 
0051 
0052 template <unsigned int nDimVertex>
0053 void calculateUpdate(const Vertex& vtx, const Acts::LinearizedTrack& linTrack,
0054                      const double trackWeight, const int sign,
0055                      Cache<nDimVertex>& cache) {
0056   constexpr unsigned int nBoundParams = nDimVertex + 2;
0057   using ParameterVector = ActsVector<nBoundParams>;
0058   using ParameterMatrix = ActsSquareMatrix<nBoundParams>;
0059   
0060   
0061   
0062   const ActsMatrix<nBoundParams, nDimVertex> posJac =
0063       linTrack.positionJacobian.block<nBoundParams, nDimVertex>(0, 0);
0064   
0065   const ActsMatrix<nBoundParams, 3> momJac =
0066       linTrack.momentumJacobian.block<nBoundParams, 3>(0, 0);
0067   
0068   const ParameterVector trkParams =
0069       linTrack.parametersAtPCA.head<nBoundParams>();
0070   
0071   const ParameterVector constTerm = linTrack.constantTerm.head<nBoundParams>();
0072   
0073   
0074   
0075   
0076   
0077   
0078   
0079   
0080   
0081   const ParameterMatrix trkParamWeight =
0082       linTrack.covarianceAtPCA.block<nBoundParams, nBoundParams>(0, 0)
0083           .inverse();
0084 
0085   
0086   const ActsVector<nDimVertex> oldVtxPos =
0087       vtx.fullPosition().template head<nDimVertex>();
0088   
0089   cache.oldVertexWeight =
0090       (vtx.fullCovariance().template block<nDimVertex, nDimVertex>(0, 0))
0091           .inverse();
0092 
0093   
0094   cache.wMat = (momJac.transpose() * (trkParamWeight * momJac)).inverse();
0095 
0096   
0097   ParameterMatrix gBMat = trkParamWeight - trkParamWeight * momJac *
0098                                                cache.wMat * momJac.transpose() *
0099                                                trkParamWeight;
0100 
0101   
0102   cache.newVertexWeight = cache.oldVertexWeight + sign * trackWeight *
0103                                                       posJac.transpose() *
0104                                                       gBMat * posJac;
0105   
0106   cache.newVertexCov = cache.newVertexWeight.inverse();
0107 
0108   
0109   cache.newVertexPos =
0110       cache.newVertexCov * (cache.oldVertexWeight * oldVtxPos +
0111                             sign * trackWeight * posJac.transpose() * gBMat *
0112                                 (trkParams - constTerm));
0113 }
0114 
0115 template <unsigned int nDimVertex>
0116 double vertexPositionChi2Update(const Vector4& oldVtxPos,
0117                                 const Cache<nDimVertex>& cache) {
0118   ActsVector<nDimVertex> posDiff =
0119       cache.newVertexPos - oldVtxPos.template head<nDimVertex>();
0120 
0121   
0122   return posDiff.transpose() * (cache.oldVertexWeight * posDiff);
0123 }
0124 
0125 template <unsigned int nDimVertex>
0126 double trackParametersChi2(const LinearizedTrack& linTrack,
0127                            const Cache<nDimVertex>& cache) {
0128   constexpr unsigned int nBoundParams = nDimVertex + 2;
0129   using ParameterVector = ActsVector<nBoundParams>;
0130   using ParameterMatrix = ActsSquareMatrix<nBoundParams>;
0131   
0132   const ActsMatrix<nBoundParams, nDimVertex> posJac =
0133       linTrack.positionJacobian.block<nBoundParams, nDimVertex>(0, 0);
0134   
0135   const ActsMatrix<nBoundParams, 3> momJac =
0136       linTrack.momentumJacobian.block<nBoundParams, 3>(0, 0);
0137   
0138   const ParameterVector trkParams =
0139       linTrack.parametersAtPCA.head<nBoundParams>();
0140   
0141   const ParameterVector constTerm = linTrack.constantTerm.head<nBoundParams>();
0142   
0143   
0144   
0145   const ParameterMatrix trkParamWeight =
0146       linTrack.covarianceAtPCA.block<nBoundParams, nBoundParams>(0, 0)
0147           .inverse();
0148 
0149   
0150   const ParameterVector posJacVtxPos = posJac * cache.newVertexPos;
0151 
0152   
0153   Vector3 newTrkMom = cache.wMat * momJac.transpose() * trkParamWeight *
0154                       (trkParams - constTerm - posJacVtxPos);
0155 
0156   
0157   ParameterVector linearizedTrackParameters =
0158       constTerm + posJacVtxPos + momJac * newTrkMom;
0159 
0160   
0161   ParameterVector paramDiff = trkParams - linearizedTrackParameters;
0162 
0163   
0164   return paramDiff.transpose() * (trkParamWeight * paramDiff);
0165 }
0166 
0167 
0168 
0169 
0170 
0171 
0172 
0173 
0174 
0175 
0176 
0177 template <unsigned int nDimVertex>
0178 Acts::BoundMatrix calculateTrackCovariance(
0179     const SquareMatrix3& wMat, const ActsMatrix<nDimVertex, 3>& crossCovVP,
0180     const ActsSquareMatrix<nDimVertex>& vtxCov,
0181     const BoundVector& newTrkParams) {
0182   
0183   ActsSquareMatrix<3> momCov =
0184       wMat + crossCovVP.transpose() * vtxCov.inverse() * crossCovVP;
0185 
0186   
0187   
0188   
0189   
0190   constexpr unsigned int nFreeParams = nDimVertex + 3;
0191   ActsSquareMatrix<nFreeParams> freeTrkCov(
0192       ActsSquareMatrix<nFreeParams>::Zero());
0193 
0194   freeTrkCov.template block<3, 3>(0, 0) = vtxCov.template block<3, 3>(0, 0);
0195   freeTrkCov.template block<3, 3>(0, 3) = crossCovVP.template block<3, 3>(0, 0);
0196   freeTrkCov.template block<3, 3>(3, 0) =
0197       crossCovVP.template block<3, 3>(0, 0).transpose();
0198   freeTrkCov.template block<3, 3>(3, 3) = momCov;
0199 
0200   
0201   if constexpr (nFreeParams == 7) {
0202     freeTrkCov.template block<3, 1>(0, 6) = vtxCov.template block<3, 1>(0, 3);
0203     freeTrkCov.template block<3, 1>(3, 6) =
0204         crossCovVP.template block<1, 3>(3, 0).transpose();
0205     freeTrkCov.template block<1, 3>(6, 0) = vtxCov.template block<1, 3>(3, 0);
0206     freeTrkCov.template block<1, 3>(6, 3) =
0207         crossCovVP.template block<1, 3>(3, 0);
0208     freeTrkCov(6, 6) = vtxCov(3, 3);
0209   }
0210 
0211   
0212   constexpr unsigned int nBoundParams = nDimVertex + 2;
0213   ActsMatrix<nBoundParams, nFreeParams> freeToBoundJac(
0214       ActsMatrix<nBoundParams, nFreeParams>::Zero());
0215 
0216   
0217   
0218   freeToBoundJac(0, 0) = -std::sin(newTrkParams[2]);
0219   freeToBoundJac(0, 1) = std::cos(newTrkParams[2]);
0220 
0221   double tanTheta = std::tan(newTrkParams[3]);
0222 
0223   
0224   freeToBoundJac(1, 0) = -freeToBoundJac(0, 1) / tanTheta;
0225   freeToBoundJac(1, 1) = freeToBoundJac(0, 0) / tanTheta;
0226 
0227   
0228   constexpr unsigned int nDimIdentity = nFreeParams - 2;
0229   freeToBoundJac.template block<nDimIdentity, nDimIdentity>(1, 2) =
0230       ActsMatrix<nDimIdentity, nDimIdentity>::Identity();
0231 
0232   
0233   BoundMatrix boundTrackCov(BoundMatrix::Identity());
0234   boundTrackCov.block<nBoundParams, nBoundParams>(0, 0) =
0235       (freeToBoundJac * (freeTrkCov * freeToBoundJac.transpose()));
0236 
0237   return boundTrackCov;
0238 }
0239 
0240 template <unsigned int nDimVertex>
0241 void updateVertexWithTrackImpl(Vertex& vtx, TrackAtVertex& trk, int sign) {
0242   if constexpr (nDimVertex != 3 && nDimVertex != 4) {
0243     throw std::invalid_argument(
0244         "The vertex dimension must either be 3 (when fitting the spatial "
0245         "coordinates) or 4 (when fitting the spatial coordinates + time).");
0246   }
0247 
0248   double trackWeight = trk.trackWeight;
0249 
0250   
0251   Cache<nDimVertex> cache;
0252 
0253   
0254   calculateUpdate(vtx, trk.linearizedState, trackWeight, sign, cache);
0255 
0256   
0257   auto [chi2, ndf] = vtx.fitQuality();
0258 
0259   
0260   double trkChi2 = trackParametersChi2(trk.linearizedState, cache);
0261 
0262   
0263   double vtxPosChi2Update = vertexPositionChi2Update(vtx.fullPosition(), cache);
0264 
0265   
0266   chi2 += sign * (vtxPosChi2Update + trackWeight * trkChi2);
0267 
0268   
0269   ndf += sign * trackWeight * 2.;
0270 
0271   
0272   if constexpr (nDimVertex == 3) {
0273     vtx.fullPosition().head<3>() = cache.newVertexPos.template head<3>();
0274     vtx.fullCovariance().template topLeftCorner<3, 3>() =
0275         cache.newVertexCov.template topLeftCorner<3, 3>();
0276   } else if constexpr (nDimVertex == 4) {
0277     vtx.fullPosition() = cache.newVertexPos;
0278     vtx.fullCovariance() = cache.newVertexCov;
0279   }
0280   vtx.setFitQuality(chi2, ndf);
0281 
0282   if (sign == 1) {
0283     
0284     trk.chi2Track = trkChi2;
0285     trk.ndf = 2 * trackWeight;
0286   }
0287   
0288   else if (sign == -1) {
0289     trk.trackWeight = 0.;
0290   } else {
0291     throw std::invalid_argument(
0292         "Sign for adding/removing track must be +1 (add) or -1 (remove).");
0293   }
0294 }
0295 
0296 template <unsigned int nDimVertex>
0297 void updateTrackWithVertexImpl(TrackAtVertex& track, const Vertex& vtx) {
0298   if constexpr (nDimVertex != 3 && nDimVertex != 4) {
0299     throw std::invalid_argument(
0300         "The vertex dimension must either be 3 (when fitting the spatial "
0301         "coordinates) or 4 (when fitting the spatial coordinates + time).");
0302   }
0303 
0304   using VertexVector = ActsVector<nDimVertex>;
0305   using VertexMatrix = ActsSquareMatrix<nDimVertex>;
0306   constexpr unsigned int nBoundParams = nDimVertex + 2;
0307   using ParameterVector = ActsVector<nBoundParams>;
0308   using ParameterMatrix = ActsSquareMatrix<nBoundParams>;
0309   
0310   if (!track.isLinearized) {
0311     throw std::invalid_argument("TrackAtVertex object must be linearized.");
0312   }
0313 
0314   
0315   
0316   const VertexVector vtxPos = vtx.fullPosition().template head<nDimVertex>();
0317   
0318   const VertexMatrix vtxCov =
0319       vtx.fullCovariance().template block<nDimVertex, nDimVertex>(0, 0);
0320 
0321   
0322   const LinearizedTrack& linTrack = track.linearizedState;
0323   
0324   
0325   
0326   const ActsMatrix<nBoundParams, nDimVertex> posJac =
0327       linTrack.positionJacobian.block<nBoundParams, nDimVertex>(0, 0);
0328   
0329   const ActsMatrix<nBoundParams, 3> momJac =
0330       linTrack.momentumJacobian.block<nBoundParams, 3>(0, 0);
0331   
0332   const ParameterVector trkParams =
0333       linTrack.parametersAtPCA.head<nBoundParams>();
0334   
0335   const ParameterVector constTerm = linTrack.constantTerm.head<nBoundParams>();
0336   
0337   
0338   
0339   const ParameterMatrix trkParamWeight =
0340       linTrack.covarianceAtPCA.block<nBoundParams, nBoundParams>(0, 0)
0341           .inverse();
0342 
0343   
0344   Cache<nDimVertex> cache;
0345 
0346   
0347   
0348   calculateUpdate(vtx, linTrack, track.trackWeight, -1, cache);
0349 
0350   
0351   Vector3 newTrkMomentum = cache.wMat * momJac.transpose() * trkParamWeight *
0352                            (trkParams - constTerm - posJac * vtxPos);
0353 
0354   
0355   
0356   BoundVector newTrkParams(BoundVector::Zero());
0357 
0358   
0359   const auto correctedPhiTheta =
0360       Acts::detail::normalizePhiTheta(newTrkMomentum(0), newTrkMomentum(1));
0361   newTrkParams(BoundIndices::eBoundPhi) = correctedPhiTheta.first;     
0362   newTrkParams(BoundIndices::eBoundTheta) = correctedPhiTheta.second;  
0363   newTrkParams(BoundIndices::eBoundQOverP) = newTrkMomentum(2);        
0364 
0365   
0366   const ActsMatrix<nDimVertex, 3> crossCovVP =
0367       -vtxCov * posJac.transpose() * trkParamWeight * momJac * cache.wMat;
0368 
0369   
0370   VertexVector posDiff =
0371       vtxPos - cache.newVertexPos.template head<nDimVertex>();
0372 
0373   
0374   ParameterVector paramDiff =
0375       trkParams - (constTerm + posJac * vtxPos + momJac * newTrkMomentum);
0376 
0377   
0378   double chi2 =
0379       posDiff.dot(
0380           cache.newVertexWeight.template block<nDimVertex, nDimVertex>(0, 0) *
0381           posDiff) +
0382       paramDiff.dot(trkParamWeight * paramDiff);
0383 
0384   Acts::BoundMatrix newTrackCov = calculateTrackCovariance<nDimVertex>(
0385       cache.wMat, crossCovVP, vtxCov, newTrkParams);
0386 
0387   
0388   std::shared_ptr<PerigeeSurface> perigeeSurface =
0389       Surface::makeShared<PerigeeSurface>(vtxPos.template head<3>());
0390 
0391   BoundTrackParameters refittedPerigee =
0392       BoundTrackParameters(perigeeSurface, newTrkParams, std::move(newTrackCov),
0393                            track.fittedParams.particleHypothesis());
0394 
0395   
0396   track.fittedParams = refittedPerigee;
0397   track.chi2Track = chi2;
0398   track.ndf = 2 * track.trackWeight;
0399 
0400   return;
0401 }
0402 
0403 }