Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:15:51

0001 # SPDX-License-Identifier: LGPL-3.0-or-later
0002 # Copyright (C) 2023 Chao Peng
0003 '''
0004     A script to visualize the fibers of some grids from BEMC ScFi part
0005     use case:
0006     python scripts/subdetector_tests/draw_bemc_scfi_grids.py -c epic_craterlake.xml
0007 '''
0008 import os
0009 import ROOT
0010 import dd4hep
0011 import DDRec
0012 import argparse
0013 import numpy as np
0014 from pydoc import locate
0015 from matplotlib import pyplot as plt
0016 from matplotlib.patches import Circle
0017 from matplotlib.collections import PatchCollection
0018 import matplotlib.ticker as ticker
0019 from collections import OrderedDict
0020 
0021 
def dict_to_cpp_vec(my_dict, dtype='int'):
    """Convert a Python mapping into a C++ ``std::vector<pair<string, dtype>>``.

    The vector is created through ROOT so that its element type matches what
    the wrapped C++ functions (e.g. the ID descriptor's ``encode``) expect.

    :param my_dict: mapping of field name -> value; every value is cast with
        the Python callable that ``pydoc.locate(dtype)`` resolves to.
    :param dtype: C++ type name of the pair's second member; it must also name
        a Python callable reachable by ``pydoc.locate`` (e.g. ``'int'``).
    :return: the populated ROOT ``std::vector`` of ``(name, value)`` pairs.
    """
    pair_type = 'pair<string, {}>'.format(dtype)
    pairs = ROOT.std.vector(pair_type)()
    to_value = locate(dtype)
    for entry in my_dict.items():
        name, raw = entry
        pairs.push_back((name, to_value(raw)))
    return pairs
0030 
0031 
def get_grid_fibers(det_elem, vol_man, id_conv, id_dict):
    """Collect the positions (and radius) of all fibers placed in one grid.

    :param det_elem: not used; kept for interface compatibility. The world
        transformation is taken from the DetElement that ``id_conv`` finds.
    :param vol_man: dd4hep VolumeManager; its ``idSpec()`` encodes the IDs.
    :param id_conv: DDRec CellIDPositionConverter used to look up positions,
        cell dimensions, DetElements and volume contexts.
    :param id_dict: mapping of volume-ID field name -> index down to the grid
        (e.g. system/module/layer/slice/grid). It is NOT modified; a
        ``'fiber'`` entry is added only to private copies.
    :return: ``(fibers, grid_pos)``. ``fibers`` is an ``(n, 4)`` array with one
        row ``[x, y, z, r]`` per fiber daughter node of the grid. ``grid_pos``
        is the grid position ``[x, y, z]``. ``(None, None)`` is returned if any
        lookup fails (e.g. the volume does not exist).
    """
    id_desc = vol_man.idSpec()
    try:
        # fiber 1 of the grid: used for the fiber radius and to find the
        # nearest (lowest-level) DetElement
        fid = id_desc.encode(dict_to_cpp_vec({**id_dict, 'fiber': 1}))
        # NOTE: assumes tube geometry, with cellDimensions()[0] taken as the
        # diameter. The /10. is a unit conversion; the original note calls it
        # "cm to mm". TODO(review): confirm the direction against DD4hep units.
        fr = id_conv.cellDimensions(fid)[0]/2./10.

        # nominal world transformation of that DetElement; used below to bring
        # fiber centers from the grid's mother frame to world coordinates
        sdet = id_conv.findDetElement(id_conv.position(fid))
        gtrans = sdet.nominal().worldTransformation()

        # the grid itself (fiber 0) is a placed volume, not a DetElement
        gid = id_desc.encode(dict_to_cpp_vec({**id_dict, 'fiber': 0}))
        gnode = id_conv.findContext(gid).volumePlacement()
        grpos = id_conv.position(gid)
        grpos = np.array([grpos.X(), grpos.Y(), grpos.Z()])
    except Exception:
        # best effort: the caller skips grids whose volumes cannot be resolved
        return None, None

    # Walk the grid's daughter TGeoNodes directly. Converting every fiber
    # cellID with id_conv would give the same centers, but it needs an
    # additional lookup per fiber and is much slower.
    origin = np.zeros(3)
    fibers = []
    for i in range(gnode.GetNdaughters()):
        fnode = gnode.GetDaughter(i)
        # NOTE: node naming is defined in the geometry plugin; only nodes with
        # 'fiber' in their name are wanted
        if 'fiber' not in fnode.GetName():
            continue
        fpos = np.zeros(3)  # fiber center in the grid frame
        gpos = np.zeros(3)  # fiber center in the grid's mother frame
        pos = np.zeros(3)   # fiber center in world coordinates
        fnode.LocalToMaster(origin, fpos)
        gnode.LocalToMaster(fpos, gpos)
        gtrans.LocalToMaster(gpos, pos)
        fibers.append(np.hstack([pos, fr]))

    # reshape keeps the (n, 4) layout even when no fiber node was found
    return np.array(fibers).reshape(-1, 4), grpos
0089 
0090 
if __name__ == '__main__':
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument(
            'compact',
            help='Top-level xml file of the detector description.'
            )
    parser.add_argument(
            '--detector', default='EcalBarrelScFi',
            dest='detector',
            help='Detector name.'
            )
    parser.add_argument(
            '--readout', default='EcalBarrelScFiHits',
            help='Readout class for the detector.'
            )
    parser.add_argument(
            '--grid-path', default='module:1,layer:6,slice:1,grid:3',
            help='Path down to a grid volume to be centered at the plot, with the format \"field1:i1,field2:i2,...\"',
            )
    parser.add_argument(
            '--outdir',
            dest='outdir', default='.',
            help='Output directory.'
            )
    parser.add_argument(
            '--adj-nlayers',
            dest='nlayers', type=int, default=2,
            help='number of adjacent layers to draw (+-n).'
            )
    parser.add_argument(
            '--adj-ngrids',
            dest='ngrids', type=int, default=2,
            help='number of adjacent grids to draw (+-n).'
            )
    parser.add_argument(
            '--window-size',
            dest='wsize', type=float, default=4.,
            help='Plot window size (mm).'
            )
    parser.add_argument(
            '--no-marker', action='store_true',
            help='Switch on to not draw a marker for each grid\'s center',
            )
    parser.add_argument(
            '--no-fiber-edge', action='store_true',
            help='Switch on to not draw fiber edge, might be helpful for a crowded plot.',
            )
    args = parser.parse_args()

    # Parse --grid-path ("field1:i1,field2:i2,...") into (field, index) pairs
    # before the slow geometry load, so a malformed path fails fast and clearly.
    grid_path = []
    for item in args.grid_path.split(','):
        field, _, index = item.partition(':')
        try:
            if not field.strip():
                raise ValueError(item)
            grid_path.append((field.strip(), int(index)))
        except ValueError:
            parser.error('invalid --grid-path entry "{}", expected "field:index"'.format(item))
    missing = [key for key in ('layer', 'grid') if key not in dict(grid_path)]
    if missing:
        parser.error('--grid-path must specify {}'.format(' and '.join(missing)))

    # initialize dd4hep detector
    desc = dd4hep.Detector.getInstance()
    desc.fromXML(args.compact)

    # search the detector
    det = desc.world()
    try:
        det = det.child(args.detector)
    except Exception:
        print('Failed to find detector {} from \"{}\"'.format(args.detector, args.compact))
        print('Available detectors are listed below:')
        for n, _ in desc.world().children():
            print(' --- detector: {}'.format(n))
        sys.exit(-1)

    # build a volume manager so it triggers populating the volume IDs
    vman = dd4hep.VolumeManager(det, desc.readout(args.readout))
    converter = DDRec.CellIDPositionConverter(desc)

    fields = OrderedDict([('system', det.id())] + grid_path)
    layer = fields['layer']
    grid = fields['grid']

    # Offsets ordered 0, 1, ..., n, -1, ..., -n so that the requested grid
    # (offset (0, 0)) is always the first entry.
    layer_offsets = list(range(args.nlayers + 1)) + list(range(-1, -args.nlayers - 1, -1))
    grid_offsets = list(range(args.ngrids + 1)) + list(range(-1, -args.ngrids - 1, -1))
    id_dicts = []
    for dl in layer_offsets:
        for dg in grid_offsets:
            # layer and grid indices start at 1
            if layer + dl < 1 or grid + dg < 1:
                continue
            new_dict = fields.copy()
            new_dict.update({'layer': layer + dl, 'grid': grid + dg})
            id_dicts.append(new_dict)

    # plot fibers in the grid
    fig, ax = plt.subplots(figsize=(12, 12), dpi=160)
    # default color cycle
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    centered = False
    for ids in id_dicts:
        # color index: grid number, shifted by 4 on odd layers
        ic = (ids['grid'] + (ids['layer'] % 2)*4 - 1) % len(colors)
        c = colors[ic]
        fibers, gr_pos = get_grid_fibers(det, vman, converter, ids)
        if fibers is None:
            print('ignored {} because the volume might not exist.'.format(ids))
            continue

        patches = [Circle((fi[0], fi[1]), fi[3]) for fi in fibers]
        ec = c if args.no_fiber_edge else 'k'
        p = PatchCollection(patches, alpha=0.6, facecolors=(c,), edgecolors=(ec,))
        if not args.no_marker:
            ax.plot(gr_pos[0], gr_pos[1], marker='P', mfc=c, mec='k', ms=9, label='grid {}'.format(ids['grid']))
        ax.add_collection(p)
        # Center on the first grid that could be drawn. The requested grid
        # comes first, so it is used whenever it exists. Keying on the loop
        # index would skip centering entirely if that grid failed to load.
        if not centered:
            ax.set_xlim(gr_pos[0] - args.wsize, gr_pos[0] + args.wsize)
            ax.set_ylim(gr_pos[1] - args.wsize, gr_pos[1] + args.wsize)
            centered = True

    if not centered:
        print('no grid volume could be drawn, the output plot will be empty.')

    # ax.legend(fontsize=22)
    ax.tick_params(labelsize=20, direction='in')
    ax.set_xlabel('X (mm)', fontsize=22)
    ax.set_ylabel('Y (mm)', fontsize=22)
    ax.set_title('Centered at {}'.format('/'.join(['{}{}'.format(k, v) for k, v in fields.items()])), fontsize=22)
    os.makedirs(args.outdir, exist_ok=True)
    fig.savefig(os.path.join(args.outdir, 'grid_fibers.png'))