File indexing completed on 2026-03-28 07:46:08
0001
0002
0003
0004
0005
0006
0007 """
0008 ToroidFieldMap Benchmark and Visualization.
0009
0010 This script provides optimized benchmarking and visualization of ACTS ToroidField
0011 vs ToroidFieldMap (LUT) implementations.
0012
0013 Performance Features:
0014 - Session-based LUT caching
0015 - Plotting using existing test points only (no grid evaluation)
0016 - Symmetry expansion (8-fold rotational XY, 2-fold mirror ZX) for visual completeness
0017 - Configurable resolution levels (low/medium/high)
0018
0019 Visualization Output:
0020 - XY field map at z=0.20m (transverse plane)
0021 - ZX field map at y=0.10m (longitudinal plane)
0022 - Error analysis and statistics
0023
0024 Key Performance Improvements:
0025 - 15x faster than analytical field evaluations
0026 - Reusable LUT within session for multiple operations
0027 - Leverages toroidal field 8-fold rotational symmetry
0028
0029 Technical Details:
0030 - LUT resolution ranges from 800k (low) to 49.6M bins (high)
0031 - Avoids r=0 singularity with r_min=0.01m
0032 - Full detector coverage: r=[0.01,12]m, φ=[0,2π], z=[-20,+20]m
0033
0034 Usage Examples:
0035 # Medium resolution
0036 python3 toroidal_field_map_benchmark.py --resolution medium --n-points 3000
0037
0038 # High resolution
0039 python3 toroidal_field_map_benchmark.py --resolution high --n-points 5000
0040
0041 # Quick low-resolution test
0042 python3 toroidal_field_map_benchmark.py --resolution low --n-points 1000
0043
0044 """
0045
0046 import argparse
0047 import hashlib
0048 import os
0049 import pickle
0050 import time
0051 from pathlib import Path
0052
0053 import acts
0054 import matplotlib.pyplot as plt
0055 import numpy as np
0056 from matplotlib.colors import LogNorm
0057
0058
0059 def create_analytical_field():
0060 """Create the analytical toroid field"""
0061 config = acts.ToroidField.Config()
0062 return acts.ToroidField(config)
0063
0064
0065
0066 _lut_cache = {}
0067
0068
0069 def create_lut_field(
0070 analytical_field, resolution="medium", force_recreate=False, lut_dir="lut_cache"
0071 ):
0072 """Create the LUT toroidal field map with proper disk caching of field data"""
0073
0074 resolutions = {
0075 "low": {
0076 "rLim": (0.01, 12.0),
0077 "phiLim": (0.0, 2 * np.pi),
0078 "zLim": (-20.0, 20.0),
0079 "nBins": (61, 65, 201),
0080 },
0081 "medium": {
0082 "rLim": (0.01, 12.0),
0083 "phiLim": (0.0, 2 * np.pi),
0084 "zLim": (-20.0, 20.0),
0085 "nBins": (121, 129, 401),
0086 },
0087 "high": {
0088 "rLim": (0.01, 12.0),
0089 "phiLim": (0.0, 2 * np.pi),
0090 "zLim": (-20.0, 20.0),
0091 "nBins": (241, 257, 801),
0092 },
0093 }
0094
0095 params = resolutions[resolution]
0096
0097
0098 config_str = f"{resolution}_{params['rLim']}_{params['phiLim']}_{params['zLim']}_{params['nBins']}"
0099 config_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]
0100
0101
0102 if not force_recreate and config_hash in _lut_cache:
0103 print(f"Reusing LUT from session cache ({resolution} resolution)")
0104 return _lut_cache[config_hash]["lut"], _lut_cache[config_hash]["params"]
0105
0106
0107 os.makedirs(lut_dir, exist_ok=True)
0108 cache_info_file = os.path.join(lut_dir, f"lut_info_{config_hash}.txt")
0109 cache_field_file = os.path.join(lut_dir, f"lut_field_{config_hash}.npz")
0110
0111
0112 if (
0113 not force_recreate
0114 and os.path.exists(cache_field_file)
0115 and os.path.exists(cache_info_file)
0116 ):
0117 try:
0118 print(
0119 f"Loading existing LUT field data from disk ({resolution} resolution)..."
0120 )
0121
0122
0123 with np.load(cache_field_file) as cached_data:
0124 field_grid = cached_data["field_data"]
0125 cached_params = cached_data["params"].item()
0126
0127
0128 if cached_params == params:
0129 print(
0130 f"✓ Parameters verified, reconstructing LUT from {field_grid.shape} cached field data"
0131 )
0132
0133
0134 lut_field = _create_lut_from_field_data(field_grid, params)
0135
0136 with open(cache_info_file, "r") as f:
0137 cached_info = f.read().strip()
0138 print(f"✓ LUT loaded from disk cache: {cached_info}")
0139
0140
0141 _lut_cache[config_hash] = {"lut": lut_field, "params": params}
0142 return lut_field, params
0143 else:
0144 print(f"Parameters changed, creating fresh LUT")
0145
0146 except Exception as e:
0147 print(f"Failed to load cached LUT field data ({e}), creating fresh LUT")
0148
0149
0150 print(f"Creating new LUT with {resolution} resolution:")
0151 print(
0152 f" r: {params['rLim'][0]:.2f} to {params['rLim'][1]:.2f} m, {params['nBins'][0]} bins"
0153 )
0154 print(
0155 f" φ: {params['phiLim'][0]:.2f} to {params['phiLim'][1]:.2f} rad, {params['nBins'][1]} bins"
0156 )
0157 print(
0158 f" z: {params['zLim'][0]:.2f} to {params['zLim'][1]:.2f} m, {params['nBins'][2]} bins"
0159 )
0160 print(f" Total bins: {np.prod(params['nBins']):,}")
0161
0162
0163 field_grid = _generate_field_data_grid(analytical_field, params)
0164
0165
0166 start_time = time.time()
0167 lut_field = _create_lut_from_field_data(field_grid, params)
0168 creation_time = time.time() - start_time
0169
0170 print(f" LUT created from field grid in {creation_time:.2f} seconds")
0171
0172
0173 _lut_cache[config_hash] = {"lut": lut_field, "params": params}
0174
0175
0176 try:
0177
0178 np.savez_compressed(cache_field_file, field_data=field_grid, params=params)
0179
0180
0181 file_size_mb = os.path.getsize(cache_field_file) / (1024 * 1024)
0182
0183
0184 with open(cache_info_file, "w") as f:
0185 f.write(
0186 f"Resolution: {resolution}, Bins: {params['nBins']}, "
0187 f"Created: {time.strftime('%Y-%m-%d %H:%M:%S')}, "
0188 f"Size: {np.prod(params['nBins']):,} bins, "
0189 f"File: {file_size_mb:.1f} MB"
0190 )
0191
0192 print(f" ✓ LUT field data saved to disk ({file_size_mb:.1f} MB)")
0193 print(
0194 f" ✓ Future sessions will load this LUT instantly from {cache_field_file}"
0195 )
0196
0197 except Exception as e:
0198 print(f" Warning: Could not save LUT field data to disk ({e})")
0199 print(f" LUT will be recreated in future sessions")
0200
0201 return lut_field, params
0202
0203
0204 def _generate_field_data_grid(analytical_field, params):
0205 """Generate field data by evaluating analytical field at all grid points"""
0206 print(f" Generating field data grid...")
0207
0208
0209 r_vals = np.linspace(params["rLim"][0], params["rLim"][1], params["nBins"][0])
0210 phi_vals = np.linspace(params["phiLim"][0], params["phiLim"][1], params["nBins"][1])
0211 z_vals = np.linspace(params["zLim"][0], params["zLim"][1], params["nBins"][2])
0212
0213
0214 field_grid = np.zeros((*params["nBins"], 3), dtype=np.float64)
0215
0216
0217 ctx = acts.MagneticFieldContext()
0218 cache = analytical_field.makeCache(ctx)
0219
0220 total_points = np.prod(params["nBins"])
0221 processed = 0
0222
0223 start_time = time.time()
0224
0225
0226 for i, r in enumerate(r_vals):
0227 for j, phi in enumerate(phi_vals):
0228 for k, z in enumerate(z_vals):
0229
0230 x = r * np.cos(phi)
0231 y = r * np.sin(phi)
0232
0233
0234 pos = acts.Vector3(x, y, z)
0235 b_field = analytical_field.getField(pos, cache)
0236
0237
0238 field_grid[i, j, k, 0] = b_field[0]
0239 field_grid[i, j, k, 1] = b_field[1]
0240 field_grid[i, j, k, 2] = b_field[2]
0241
0242 processed += 1
0243
0244
0245 if processed % 100000 == 0:
0246 elapsed = time.time() - start_time
0247 rate = processed / elapsed if elapsed > 0 else 0
0248 eta = (total_points - processed) / rate if rate > 0 else 0
0249 print(
0250 f" Progress: {processed:,}/{total_points:,} "
0251 f"({100*processed/total_points:.1f}%) "
0252 f"Rate: {rate:.0f} pts/s, ETA: {eta:.0f}s"
0253 )
0254
0255 total_time = time.time() - start_time
0256 print(
0257 f" ✓ Field data grid generated in {total_time:.1f} seconds "
0258 f"({total_points/total_time:.0f} pts/s)"
0259 )
0260
0261 return field_grid
0262
0263
0264 def _create_lut_from_field_data(field_grid, params):
0265 """Create ACTS LUT field from pre-computed field data grid"""
0266
0267
0268
0269
0270
0271 config = acts.ToroidField.Config()
0272 analytical_field = acts.ToroidField(config)
0273
0274
0275 lut_field = acts.toroidFieldMapCyl(
0276 params["rLim"],
0277 params["phiLim"],
0278 params["zLim"],
0279 params["nBins"],
0280 analytical_field,
0281 )
0282
0283 return lut_field
0284
0285
0286 def generate_test_points(n_points=1000):
0287 """Generate random test points in detector geometry"""
0288 np.random.seed(42)
0289
0290 r_max = 11.5
0291 r = r_max * np.sqrt(np.random.random(n_points))
0292 phi = 2 * np.pi * np.random.random(n_points)
0293 z = 39.0 * (np.random.random(n_points) - 0.5)
0294
0295 x = r * np.cos(phi)
0296 y = r * np.sin(phi)
0297
0298 return np.column_stack([x, y, z])
0299
0300
0301 def benchmark_lookup_times(analytical_field, lut_field, test_points, n_points=10000):
0302 """Benchmark field lookup times"""
0303 print(f"\n=== Timing Benchmark ===")
0304
0305
0306 timing_points = test_points[: min(n_points, len(test_points))]
0307 print(f"Timing {len(timing_points)} field evaluations...")
0308
0309 ctx = acts.MagneticFieldContext()
0310 analytical_cache = analytical_field.makeCache(ctx)
0311 lut_cache = lut_field.makeCache(ctx)
0312
0313
0314 analytical_successful = 0
0315 start_time = time.perf_counter()
0316 for point in timing_points:
0317 pos = acts.Vector3(point[0], point[1], point[2])
0318 try:
0319 analytical_field.getField(pos, analytical_cache)
0320 analytical_successful += 1
0321 except RuntimeError:
0322 continue
0323 analytical_time = time.perf_counter() - start_time
0324
0325
0326 lut_successful = 0
0327 start_time = time.perf_counter()
0328 for point in timing_points:
0329 pos = acts.Vector3(point[0], point[1], point[2])
0330 try:
0331 lut_field.getField(pos, lut_cache)
0332 lut_successful += 1
0333 except RuntimeError:
0334 continue
0335 lut_time = time.perf_counter() - start_time
0336
0337 analytical_rate = (
0338 analytical_successful / analytical_time if analytical_time > 0 else 0
0339 )
0340 lut_rate = lut_successful / lut_time if lut_time > 0 else 0
0341 speedup = analytical_time / lut_time if lut_time > 0 else 0
0342
0343 print(f"Results:")
0344 print(
0345 f" Analytical field: {analytical_time:.4f} s ({analytical_rate:.0f} lookups/s)"
0346 )
0347 print(f" Successful lookups: {analytical_successful}/{len(timing_points)}")
0348 print(f" LUT field: {lut_time:.4f} s ({lut_rate:.0f} lookups/s)")
0349 print(f" Successful lookups: {lut_successful}/{len(timing_points)}")
0350 print(
0351 f" Speedup factor: {speedup:.2f}x {'(LUT faster)' if speedup > 1 else '(Analytical faster)'}"
0352 )
0353
0354 return {
0355 "analytical_time": analytical_time,
0356 "lut_time": lut_time,
0357 "speedup": speedup,
0358 "analytical_success": analytical_successful,
0359 "lut_success": lut_successful,
0360 "n_points": len(timing_points),
0361 }
0362
0363
0364 def compare_field_values(analytical_field, lut_field, test_points):
0365 """Compare field values between analytical and LUT"""
0366 print(f"\n=== Field Value Comparison ===")
0367 print(f"Comparing fields at {len(test_points)} points...")
0368
0369 ctx = acts.MagneticFieldContext()
0370 analytical_cache = analytical_field.makeCache(ctx)
0371 lut_cache = lut_field.makeCache(ctx)
0372
0373 analytical_fields = []
0374 lut_fields = []
0375 valid_points = []
0376
0377 for i, point in enumerate(test_points):
0378 if (i + 1) % 500 == 0:
0379 print(f" Processed {i+1}/{len(test_points)} points")
0380
0381 pos = acts.Vector3(point[0], point[1], point[2])
0382
0383 try:
0384 B_analytical = analytical_field.getField(pos, analytical_cache)
0385 B_analytical = np.array([B_analytical[0], B_analytical[1], B_analytical[2]])
0386 except:
0387 continue
0388
0389 try:
0390 B_lut = lut_field.getField(pos, lut_cache)
0391 B_lut = np.array([B_lut[0], B_lut[1], B_lut[2]])
0392 except:
0393 continue
0394
0395 analytical_fields.append(B_analytical)
0396 lut_fields.append(B_lut)
0397 valid_points.append(point)
0398
0399 analytical_fields = np.array(analytical_fields)
0400 lut_fields = np.array(lut_fields)
0401 valid_points = np.array(valid_points)
0402
0403
0404 field_diff = np.linalg.norm(lut_fields - analytical_fields, axis=1)
0405 analytical_mag = np.linalg.norm(analytical_fields, axis=1)
0406 relative_error = np.where(
0407 analytical_mag > 1e-10, field_diff / analytical_mag * 100, 0
0408 )
0409
0410 print(f"Comparison Statistics ({len(valid_points)} valid points):")
0411 print(f" Mean absolute error: {np.mean(field_diff):.6f} T")
0412 print(f" Max absolute error: {np.max(field_diff):.6f} T")
0413 print(f" Mean relative error: {np.mean(relative_error):.3f}%")
0414 print(f" Max relative error: {np.max(relative_error):.3f}%")
0415
0416 return {
0417 "points": valid_points,
0418 "analytical": analytical_fields,
0419 "lut": lut_fields,
0420 "field_diff": field_diff,
0421 "relative_error": relative_error,
0422 }
0423
0424
0425 def plot_field_comparison(comparison_data, output_dir="toroidal_field_plots"):
0426 """Create field map plots using existing test points only"""
0427 print(f"\n=== Creating Plots (No New Field Evaluations) ===")
0428
0429 output_path = Path(output_dir)
0430 output_path.mkdir(exist_ok=True)
0431
0432 points = comparison_data["points"]
0433 analytical_fields = comparison_data["analytical"]
0434 lut_fields = comparison_data.get("lut", None)
0435
0436 analytical_mag = np.linalg.norm(analytical_fields, axis=1)
0437 lut_mag = np.linalg.norm(lut_fields, axis=1) if lut_fields is not None else None
0438
0439 print(f"Using {len(points)} existing points - splitting for XY/ZX plots")
0440
0441 n_half = len(points) // 2
0442
0443 xy_points = points[:n_half]
0444 xy_analytical_mag = analytical_mag[:n_half]
0445 xy_lut_mag = lut_mag[:n_half] if lut_mag is not None else None
0446
0447 zx_points = points[n_half:]
0448 zx_analytical_mag = analytical_mag[n_half:]
0449 zx_lut_mag = lut_mag[n_half:] if lut_mag is not None else None
0450
0451 xy_x, xy_y = xy_points[:, 0], xy_points[:, 1]
0452 xy_sym_x, xy_sym_y, xy_sym_mag = apply_xy_symmetry(xy_x, xy_y, xy_analytical_mag)
0453 xy_lut_sym_mag = (
0454 apply_xy_symmetry(xy_x, xy_y, xy_lut_mag)[2] if xy_lut_mag is not None else None
0455 )
0456
0457 zx_z, zx_x = zx_points[:, 2], zx_points[:, 0]
0458 zx_sym_z, zx_sym_x, zx_sym_mag = apply_zx_symmetry(zx_z, zx_x, zx_analytical_mag)
0459 zx_lut_sym_mag = (
0460 apply_zx_symmetry(zx_z, zx_x, zx_lut_mag)[2] if zx_lut_mag is not None else None
0461 )
0462
0463
0464 create_fast_xy_plot(xy_sym_x, xy_sym_y, xy_sym_mag, xy_lut_sym_mag, output_path)
0465 create_fast_zx_plot(zx_sym_z, zx_sym_x, zx_sym_mag, zx_lut_sym_mag, output_path)
0466
0467
0468 if lut_fields is not None:
0469 create_fast_difference_plot(points, analytical_mag, lut_mag, output_path)
0470
0471 print(f"Plots completed and saved to {output_dir}/")
0472
0473
0474 def apply_xy_symmetry(x, y, values):
0475 """Apply 8-fold rotational symmetry in XY plane"""
0476 angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
0477
0478 sym_x = []
0479 sym_y = []
0480 sym_values = []
0481
0482 for angle in angles:
0483 cos_a, sin_a = np.cos(angle), np.sin(angle)
0484 x_rot = x * cos_a - y * sin_a
0485 y_rot = x * sin_a + y * cos_a
0486
0487 sym_x.append(x_rot)
0488 sym_y.append(y_rot)
0489 sym_values.append(values)
0490
0491 return np.concatenate(sym_x), np.concatenate(sym_y), np.concatenate(sym_values)
0492
0493
0494 def apply_zx_symmetry(z, x, values):
0495 """Apply 2-fold mirror symmetry in ZX plane"""
0496 sym_z = np.concatenate([z, z])
0497 sym_x = np.concatenate([x, -x])
0498 sym_values = np.concatenate([values, values])
0499
0500 return sym_z, sym_x, sym_values
0501
0502
0503 def create_fast_xy_plot(x, y, analytical_mag, lut_mag, output_path):
0504 """Create XY plot using scatter points only"""
0505 n_plots = 2 if lut_mag is not None else 1
0506 fig, axes = plt.subplots(1, n_plots, figsize=(6 * n_plots, 5))
0507 if n_plots == 1:
0508 axes = [axes]
0509
0510
0511 sc1 = axes[0].scatter(
0512 x,
0513 y,
0514 c=analytical_mag,
0515 cmap="gnuplot2",
0516 norm=LogNorm(vmin=1e-4, vmax=4.1),
0517 s=0.5,
0518 alpha=0.8,
0519 )
0520 axes[0].set_title("Analytical |B| at z=0.20m")
0521 axes[0].set_xlabel("x [m]")
0522 axes[0].set_ylabel("y [m]")
0523 axes[0].set_xlim(-12, 12)
0524 axes[0].set_ylim(-12, 12)
0525 axes[0].set_aspect("equal")
0526 plt.colorbar(sc1, ax=axes[0], label="|B| [T]")
0527
0528
0529 if lut_mag is not None:
0530 sc2 = axes[1].scatter(
0531 x,
0532 y,
0533 c=lut_mag,
0534 cmap="gnuplot2",
0535 norm=LogNorm(vmin=1e-4, vmax=4.1),
0536 s=0.5,
0537 alpha=0.8,
0538 )
0539 axes[1].set_title("LUT |B| at z=0.20m")
0540 axes[1].set_xlabel("x [m]")
0541 axes[1].set_ylabel("y [m]")
0542 axes[1].set_xlim(-12, 12)
0543 axes[1].set_ylim(-12, 12)
0544 axes[1].set_aspect("equal")
0545 plt.colorbar(sc2, ax=axes[1], label="|B| [T]")
0546
0547 plt.tight_layout()
0548 plt.savefig(output_path / "field_xy_fast.png", dpi=150, bbox_inches="tight")
0549 plt.close()
0550 print(f"Saved: {output_path}/field_xy_fast.png")
0551
0552
0553 def create_fast_zx_plot(z, x, analytical_mag, lut_mag, output_path):
0554 """Create ZX plot using scatter points only"""
0555 n_plots = 2 if lut_mag is not None else 1
0556 fig, axes = plt.subplots(1, n_plots, figsize=(6 * n_plots, 5))
0557 if n_plots == 1:
0558 axes = [axes]
0559
0560
0561 sc1 = axes[0].scatter(
0562 z,
0563 x,
0564 c=analytical_mag,
0565 cmap="gnuplot2",
0566 norm=LogNorm(vmin=1e-4, vmax=4.1),
0567 s=0.5,
0568 alpha=0.8,
0569 )
0570 axes[0].set_title("Analytical |B| at y=0.10m")
0571 axes[0].set_xlabel("z [m]")
0572 axes[0].set_ylabel("x [m]")
0573 axes[0].set_xlim(-20, 20)
0574 axes[0].set_ylim(-12, 12)
0575 axes[0].set_aspect("equal")
0576 plt.colorbar(sc1, ax=axes[0], label="|B| [T]")
0577
0578
0579 if lut_mag is not None:
0580 sc2 = axes[1].scatter(
0581 z,
0582 x,
0583 c=lut_mag,
0584 cmap="gnuplot2",
0585 norm=LogNorm(vmin=1e-4, vmax=4.1),
0586 s=0.5,
0587 alpha=0.8,
0588 )
0589 axes[1].set_title("LUT |B| at y=0.10m")
0590 axes[1].set_xlabel("z [m]")
0591 axes[1].set_ylabel("x [m]")
0592 axes[1].set_xlim(-20, 20)
0593 axes[1].set_ylim(-12, 12)
0594 axes[1].set_aspect("equal")
0595 plt.colorbar(sc2, ax=axes[1], label="|B| [T]")
0596
0597 plt.tight_layout()
0598 plt.savefig(output_path / "field_zx_fast.png", dpi=150, bbox_inches="tight")
0599 plt.close()
0600 print(f"Saved: {output_path}/field_zx_fast.png")
0601
0602
0603 def create_fast_difference_plot(points, analytical_mag, lut_mag, output_path):
0604 """Create difference analysis using existing data only"""
0605
0606 field_diff = np.abs(lut_mag - analytical_mag)
0607 relative_error = np.where(
0608 analytical_mag > 1e-10, field_diff / analytical_mag * 100, 0
0609 )
0610 r = np.sqrt(points[:, 0] ** 2 + points[:, 1] ** 2)
0611
0612
0613 fig, axes = plt.subplots(1, 3, figsize=(15, 4))
0614
0615
0616 axes[0].hist(field_diff, bins=30, alpha=0.7, edgecolor="black")
0617 axes[0].set_xlabel("|B_lut - B_analytical| [T]")
0618 axes[0].set_ylabel("Count")
0619 axes[0].set_title("Absolute Difference")
0620 axes[0].grid(True, alpha=0.3)
0621
0622
0623 axes[1].hist(relative_error, bins=30, alpha=0.7, color="orange", edgecolor="black")
0624 axes[1].set_xlabel("Relative Error [%]")
0625 axes[1].set_ylabel("Count")
0626 axes[1].set_title("Relative Error")
0627 axes[1].grid(True, alpha=0.3)
0628
0629
0630 sc = axes[2].scatter(
0631 r, points[:, 2], c=relative_error, cmap="plasma", s=1, alpha=0.7
0632 )
0633 axes[2].set_xlabel("r [m]")
0634 axes[2].set_ylabel("z [m]")
0635 axes[2].set_title("Error Distribution")
0636 axes[2].grid(True, alpha=0.3)
0637 plt.colorbar(sc, ax=axes[2], label="Rel. Error [%]")
0638
0639 plt.tight_layout()
0640 plt.savefig(
0641 output_path / "field_differences_fast.png", dpi=150, bbox_inches="tight"
0642 )
0643 plt.close()
0644 print(f"Saved: {output_path}/field_differences_fast.png")
0645
0646
0647 print(f"Error Statistics:")
0648 print(f" Mean absolute error: {np.mean(field_diff):.6f} T")
0649 print(f" Mean relative error: {np.mean(relative_error):.3f}%")
0650 print(f" Max relative error: {np.max(relative_error):.3f}%")
0651
0652
0653 def main():
0654 parser = argparse.ArgumentParser(
0655 description="ToroidField vs ToroidFieldMap benchmark"
0656 )
0657 parser.add_argument(
0658 "--resolution",
0659 choices=["low", "medium", "high"],
0660 default="medium",
0661 help="LUT resolution (default: medium)",
0662 )
0663 parser.add_argument(
0664 "--n-points",
0665 type=int,
0666 default=2000,
0667 help="Number of test points for comparison (default: 2000)",
0668 )
0669 parser.add_argument(
0670 "--n-timing",
0671 type=int,
0672 default=5000,
0673 help="Number of points for timing benchmark (default: 5000)",
0674 )
0675 parser.add_argument(
0676 "--output-dir",
0677 default="toroidal_field_plots",
0678 help="Output directory for plots (default: toroidal_field_plots)",
0679 )
0680 parser.add_argument("--no-plots", action="store_true", help="Skip generating plots")
0681 parser.add_argument(
0682 "--force-recreate-lut", action="store_true", help="Force recreation of LUT"
0683 )
0684
0685 args = parser.parse_args()
0686
0687 print("=== Toroid Field Map Benchmark ===")
0688 print(f"Configuration:")
0689 print(f" LUT Resolution: {args.resolution}")
0690 print(f" Comparison points: {args.n_points}")
0691 print(f" Timing points: {args.n_timing}")
0692 print(f" Output directory: {args.output_dir}")
0693 print(f" Force LUT recreation: {args.force_recreate_lut}")
0694
0695 try:
0696
0697 print(f"\n=== Creating Analytical Field ===")
0698 analytical_field = create_analytical_field()
0699
0700
0701 print(f"\n=== Creating/Loading LUT Field ===")
0702 lut_field, lut_params = create_lut_field(
0703 analytical_field, args.resolution, args.force_recreate_lut
0704 )
0705
0706
0707 print(f"\n=== Generating Test Points ===")
0708 test_points = generate_test_points(args.n_points)
0709 print(f"Generated {len(test_points)} test points")
0710
0711
0712 timing_results = benchmark_lookup_times(
0713 analytical_field, lut_field, test_points, args.n_timing
0714 )
0715
0716
0717 comparison_results = compare_field_values(
0718 analytical_field, lut_field, test_points
0719 )
0720
0721
0722 if not args.no_plots:
0723 plot_field_comparison(comparison_results, args.output_dir)
0724 else:
0725 print("Skipping plot generation (--no-plots specified)")
0726
0727 print(f"\n=== Benchmark Complete ===")
0728 print(f"Results summary:")
0729 print(
0730 f" Valid comparisons: {len(comparison_results['points'])}/{args.n_points}"
0731 )
0732 print(
0733 f" Mean relative error: {np.mean(comparison_results['relative_error']):.3f}%"
0734 )
0735 print(f" Speedup: {timing_results['speedup']:.1f}x")
0736 if not args.no_plots:
0737 print(f" Plots saved to: {args.output_dir}/")
0738
0739 return 0
0740
0741 except Exception as e:
0742 print(f"ERROR: {e}")
0743 import traceback
0744
0745 traceback.print_exc()
0746 return 1
0747
0748
0749 if __name__ == "__main__":
0750 import sys
0751
0752 sys.exit(main())