File indexing completed on 2025-02-22 10:31:21
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
0010 #include <cmath>
0011
0012 #include "corecel/Macros.hh"
0013 #include "corecel/Types.hh"
0014 #include "corecel/math/Algorithms.hh"
0015 #include "corecel/math/NumericLimits.hh"
0016 #include "corecel/math/SoftEqual.hh"
0017
0018 #include "FieldDriverOptions.hh"
0019 #include "Types.hh"
0020
0021 #include "detail/FieldUtils.hh"
0022
0023 namespace celeritas
0024 {
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058 template<class StepperT>
0059 class FieldDriver
0060 {
0061 public:
0062
0063 inline CELER_FUNCTION
0064 FieldDriver(FieldDriverOptions const& options, StepperT&& perform_step);
0065
0066
0067 inline CELER_FUNCTION DriverResult advance(real_type step,
0068 OdeState const& state);
0069
0070
0071
0072 inline CELER_FUNCTION DriverResult accurate_advance(
0073 real_type step, OdeState const& state, real_type hinitial) const;
0074
0075
0076
0077 CELER_FUNCTION short int max_substeps() const
0078 {
0079 return options_.max_substeps;
0080 }
0081
0082 CELER_FUNCTION real_type minimum_step() const
0083 {
0084 return options_.minimum_step;
0085 }
0086
0087
0088 CELER_FUNCTION real_type delta_intersection() const
0089 {
0090 return options_.delta_intersection;
0091 }
0092
0093 private:
0094
0095
0096
0097 FieldDriverOptions const& options_;
0098
0099
0100 StepperT apply_step_;
0101
0102
0103 real_type max_chord_{numeric_limits<real_type>::infinity()};
0104
0105
0106
0107
0108 struct ChordSearch
0109 {
0110 DriverResult end;
0111 real_type err_sq;
0112 };
0113
0114 struct Integration
0115 {
0116 DriverResult end;
0117 real_type proposed_step;
0118 };
0119
0120
0121
0122
0123 inline CELER_FUNCTION ChordSearch
0124 find_next_chord(real_type step, OdeState const& state) const;
0125
0126
0127 inline CELER_FUNCTION Integration
0128 integrate_step(real_type step, OdeState const& state) const;
0129
0130
0131 inline CELER_FUNCTION Integration one_good_step(real_type step,
0132 OdeState const& state) const;
0133
0134
0135 inline CELER_FUNCTION real_type new_step_scale(real_type error_sq) const;
0136
0137
0138
0139 static CELER_CONSTEXPR_FUNCTION real_type half() { return 0.5; }
0140 };
0141
0142
0143
0144
0145 template<class StepperT>
0146 CELER_FUNCTION
0147 FieldDriver(FieldDriverOptions const&, StepperT&&) -> FieldDriver<StepperT>;
0148
0149
0150
0151
0152
0153
0154
0155 template<class StepperT>
0156 CELER_FUNCTION
0157 FieldDriver<StepperT>::FieldDriver(FieldDriverOptions const& options,
0158 StepperT&& stepper)
0159 : options_(options), apply_step_(::celeritas::forward<StepperT>(stepper))
0160 {
0161 CELER_EXPECT(options_);
0162 }
0163
0164
0165
0166
0167
0168
0169
0170
0171
0172
0173
0174
0175
0176
0177
0178
0179
0180 template<class StepperT>
0181 CELER_FUNCTION DriverResult
0182 FieldDriver<StepperT>::advance(real_type step, OdeState const& state)
0183 {
0184 if (step <= options_.minimum_step)
0185 {
0186
0187 DriverResult result;
0188 result.state = apply_step_(step, state).end_state;
0189 result.step = step;
0190 return result;
0191 }
0192
0193
0194
0195 ChordSearch output
0196 = this->find_next_chord(celeritas::min(step, max_chord_), state);
0197 CELER_ASSERT(output.end.step <= step);
0198 if (output.end.step < step)
0199 {
0200
0201
0202 max_chord_ = output.end.step * (1 / options_.min_chord_shrink);
0203 }
0204
0205 if (output.err_sq > 1)
0206 {
0207
0208
0209 real_type next_step = step * this->new_step_scale(output.err_sq);
0210 output.end = this->accurate_advance(output.end.step, state, next_step);
0211 }
0212
0213 CELER_ENSURE(output.end.step > 0 && output.end.step <= step);
0214 return output.end;
0215 }
0216
0217
0218
0219
0220
0221 template<class StepperT>
0222 CELER_FUNCTION auto FieldDriver<StepperT>::find_next_chord(
0223 real_type step, OdeState const& state) const -> ChordSearch
0224 {
0225
0226 ChordSearch output;
0227
0228 bool succeeded = false;
0229 auto remaining_steps = options_.max_nsteps;
0230 FieldStepperResult result;
0231
0232 do
0233 {
0234
0235 result = apply_step_(step, state);
0236
0237
0238
0239 real_type dchord = detail::distance_chord(
0240 state.pos, result.mid_state.pos, result.end_state.pos);
0241
0242 if (dchord > options_.delta_chord + options_.dchord_tol)
0243 {
0244
0245 real_type scale_step = max(std::sqrt(options_.delta_chord / dchord),
0246 options_.min_chord_shrink);
0247 step *= scale_step;
0248 }
0249 else
0250 {
0251 succeeded = true;
0252 }
0253 } while (!succeeded && --remaining_steps > 0);
0254
0255
0256 output.end.step = step;
0257 output.end.state = result.end_state;
0258 output.err_sq = detail::rel_err_sq(result.err_state, step, state.mom)
0259 / ipow<2>(options_.epsilon_rel_max);
0260
0261 return output;
0262 }
0263
0264
0265
0266
0267
0268
0269
0270
0271
0272 template<class StepperT>
0273 CELER_FUNCTION DriverResult FieldDriver<StepperT>::accurate_advance(
0274 real_type step, OdeState const& state, real_type hinitial) const
0275 {
0276 CELER_ASSERT(step > 0);
0277
0278
0279 real_type end_curve_length = step;
0280
0281
0282
0283
0284
0285 real_type h
0286 = ((hinitial > options_.initial_step_tol * step) && (hinitial < step))
0287 ? hinitial
0288 : step;
0289 real_type h_threshold = options_.epsilon_step * step;
0290
0291
0292 Integration output;
0293 output.end.state = state;
0294
0295
0296 bool succeeded = false;
0297 real_type curve_length = 0;
0298 auto remaining_steps = options_.max_nsteps;
0299
0300 do
0301 {
0302 CELER_ASSERT(h > 0);
0303 output = this->integrate_step(h, output.end.state);
0304
0305 curve_length += output.end.step;
0306
0307 if (h < h_threshold || curve_length >= end_curve_length)
0308 {
0309 succeeded = true;
0310 }
0311 else
0312 {
0313 h = celeritas::min(
0314 celeritas::max(output.proposed_step, options_.minimum_step),
0315 end_curve_length - curve_length);
0316 }
0317 } while (!succeeded && --remaining_steps > 0);
0318
0319
0320
0321 CELER_ENSURE(curve_length > 0
0322 && (curve_length <= step || soft_equal(curve_length, step)));
0323 output.end.step = min(curve_length, step);
0324 return output.end;
0325 }
0326
0327
0328
0329
0330
0331
0332
0333 template<class StepperT>
0334 CELER_FUNCTION auto
0335 FieldDriver<StepperT>::integrate_step(real_type step,
0336 OdeState const& state) const -> Integration
0337 {
0338 CELER_EXPECT(step > 0);
0339
0340
0341 Integration output;
0342
0343 if (step > options_.minimum_step)
0344 {
0345 output = this->one_good_step(step, state);
0346 }
0347 else
0348 {
0349
0350 FieldStepperResult result = apply_step_(step, state);
0351
0352
0353 output.end.state = result.end_state;
0354 output.end.step = step;
0355
0356
0357 real_type err_sq = detail::rel_err_sq(result.err_state, step, state.mom)
0358 / ipow<2>(options_.epsilon_rel_max);
0359 output.proposed_step = step * this->new_step_scale(err_sq);
0360 }
0361
0362 return output;
0363 }
0364
0365
0366
0367
0368
0369
0370 template<class StepperT>
0371 CELER_FUNCTION auto
0372 FieldDriver<StepperT>::one_good_step(real_type step,
0373 OdeState const& state) const -> Integration
0374 {
0375
0376 Integration output;
0377
0378
0379 bool succeeded = false;
0380 size_type remaining_steps = options_.max_nsteps;
0381 real_type err_sq;
0382 FieldStepperResult result;
0383
0384 do
0385 {
0386 result = apply_step_(step, state);
0387
0388 err_sq = detail::rel_err_sq(result.err_state, step, state.mom)
0389 / ipow<2>(options_.epsilon_rel_max);
0390
0391 if (err_sq > 1)
0392 {
0393
0394 step *= max(this->new_step_scale(err_sq),
0395 options_.max_stepping_decrease);
0396 }
0397 else
0398 {
0399
0400 succeeded = true;
0401 }
0402 } while (!succeeded && --remaining_steps > 0);
0403
0404
0405 output.end.state = result.end_state;
0406 output.end.step = step;
0407 output.proposed_step
0408 = step
0409 * min(this->new_step_scale(err_sq), options_.max_stepping_increase);
0410
0411 return output;
0412 }
0413
0414
0415
0416
0417
0418 template<class StepperT>
0419 CELER_FUNCTION real_type
0420 FieldDriver<StepperT>::new_step_scale(real_type err_sq) const
0421 {
0422 CELER_ASSERT(err_sq >= 0);
0423 return options_.safety
0424 * fastpow(err_sq,
0425 half() * (err_sq > 1 ? options_.pshrink : options_.pgrow));
0426 }
0427
0428
0429 }