Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:12:11

0001 #!/usr/bin/env python3
0002 from pathlib import Path
0003 from typing import Optional, Dict, List
0004 import re
0005 import enum
0006 import sys
0007 
0008 import uproot
0009 import typer
0010 import hist
0011 import pydantic
0012 import yaml
0013 import matplotlib.pyplot
0014 import awkward
0015 
0016 
class Model(pydantic.BaseModel):
    """Common base for every configuration model in this script.

    Unknown keys are rejected, so a misspelled option in the YAML config
    fails validation instead of being silently ignored.
    """

    # pydantic v2 style configuration. The rest of this file already relies
    # on the v2 API (``model_validate`` / ``model_copy``); the v1-style inner
    # ``class Config`` is deprecated there.
    model_config = pydantic.ConfigDict(extra="forbid")
0020 
0021 
class HistConfig(Model):
    """Binning settings for a single one-dimensional histogram.

    Any limit left as ``None`` is derived from the data by ``main`` when the
    histogram is first created.
    """

    # Number of bins of the regular axis.
    nbins: int = 100

    # Lower axis limit; ``None`` means "use the smallest value seen".
    min: Optional[float] = None

    # Upper axis limit; ``None`` means "use the largest value seen".
    max: Optional[float] = None

    # Axis name; ``main`` falls back to the branch name when unset.
    label: Optional[str] = None
0027 
0028 
class Extra(HistConfig):
    """A derived histogram filled from an expression instead of a branch.

    Inherits the binning options of ``HistConfig``.
    """

    # Python expression that ``main`` passes to ``eval`` for every batch.
    expression: str

    # Key under which the histogram is stored; also the axis name when
    # ``label`` is unset.
    name: str
0032 
0033 
class Config(Model):
    """Top-level layout of the optional YAML configuration file."""

    # Regular expression -> binning. ``main`` tests each pattern against the
    # branch names with ``re.match``.
    histograms: Dict[str, HistConfig] = pydantic.Field(default_factory=dict)

    # Additional expression-based histograms.
    extra_histograms: List[Extra] = pydantic.Field(default_factory=list)

    # Regular expressions of branch names that must not be histogrammed.
    exclude: List[str] = pydantic.Field(default_factory=list)
0038 
0039 
class Mode(str, enum.Enum):
    """How the output ROOT file is opened.

    Each value is the name of the ``uproot`` function that ``main`` looks up
    with ``getattr`` to open the file.
    """

    # Start from a fresh output file.
    recreate = "recreate"

    # Open an existing output file and add to it.
    update = "update"
0043 
0044 
def _book_histogram(values, spec, name):
    """Create an empty 1D histogram for ``values`` according to ``spec``.

    ``spec`` supplies ``nbins``, ``min``, ``max`` and ``label`` (a
    ``HistConfig`` or ``Extra``). Limits left as ``None`` are taken from
    ``values``, which therefore must not be empty. ``name`` is the axis name
    used when ``spec.label`` is unset. ``spec`` itself is not modified.
    """
    lower = awkward.min(values) if spec.min is None else spec.min
    upper = awkward.max(values) if spec.max is None else spec.max

    if lower == upper:
        # Widen a degenerate (zero-width) range so the axis has an extent.
        lower -= 1
        upper += 1

    return hist.Hist(
        hist.axis.Regular(spec.nbins, lower, upper, name=spec.label or name)
    )


def main(
    infile: Path = typer.Argument(
        ..., exists=True, dir_okay=False, help="The input ROOT file"
    ),
    treename: str = typer.Argument(..., help="The tree to look up branches from"),
    outpath: Path = typer.Argument(
        "outfile", dir_okay=False, help="The output ROOT file"
    ),
    config_file: Optional[Path] = typer.Option(
        None,
        "--config",
        "-c",
        exists=True,
        dir_okay=False,
        help="A config file following the input spec. By default, all branches will be plotted.",
    ),
    mode: Mode = typer.Option(Mode.recreate, help="Mode to open ROOT file in"),
    plots: Optional[Path] = typer.Option(
        None,
        "--plots",
        "-p",
        file_okay=False,
        help="If set, output plots individually to this directory",
    ),
    plot_format: str = typer.Option(
        "pdf", "--plot-format", "-f", help="Format to write plots in if --plots is set"
    ),
    silent: bool = typer.Option(
        False, "--silent", "-s", help="Do not print any output"
    ),
    dump_yml: bool = typer.Option(False, help="Print axis ranges as yml"),
):
    """
    Script to plot all branches in a TTree from a ROOT file, with optional configurable binning and ranges.
    Also allows setting extra expressions to be plotted as well.
    """
    # NOTE: the docstring above is rendered by typer as the command's help
    # text, so implementation notes are kept in comments.
    #
    # Flow: load config -> fill one histogram per branch (and per extra
    # expression) batch by batch -> write histograms (and optional plots).

    # Load the configuration before touching any ROOT file so that an invalid
    # config cannot leave a truncated output file behind.
    if config_file is None:
        config = Config()
    else:
        with config_file.open() as fh:
            # An empty YAML document loads as None; treat it as "no options".
            config = Config.model_validate(yaml.safe_load(fh) or {})

    if not silent:
        print(config.extra_histograms, file=sys.stderr)

    # Compile the exclusion patterns once instead of once per column/batch.
    exclude_patterns = [re.compile(ex) for ex in config.exclude]

    histograms = {}

    with uproot.open(infile) as rf:
        tree = rf[treename]

        for df in tree.iterate(library="ak", how=dict):
            for col in df.keys():
                if any(pattern.match(col) for pattern in exclude_patterns):
                    continue

                values = awkward.flatten(df[col], axis=None)

                if len(values) == 0:
                    # Diagnostics go to stderr so that stdout stays clean
                    # (e.g. for --dump-yml).
                    print(
                        f"WARNING: Branch '{col}' is empty. Skipped.",
                        file=sys.stderr,
                    )
                    continue

                h = histograms.get(col)
                if h is None:
                    # First non-empty batch for this branch: pick its binning.
                    # Every pattern is tried, so the last matching entry wins.
                    found = None
                    for ex, data in config.histograms.items():
                        if re.match(ex, col):
                            found = data
                            if not silent:
                                print(
                                    "Found HistConfig",
                                    ex,
                                    "for",
                                    col,
                                    ":",
                                    found,
                                    file=sys.stderr,
                                )

                    if found is None:
                        found = HistConfig()

                    h = _book_histogram(values, found, col)
                    histograms[col] = h

                h.fill(values)

            # Extra histograms are evaluated once per batch and accumulate
            # across batches: the histogram is only created the first time.
            for extra in config.extra_histograms:
                # SECURITY: the expression comes straight from the config file
                # and is executed as Python code; only use trusted configs.
                # The eval stays inline so expressions can refer to ``df``.
                calc = eval(extra.expression)
                values = awkward.flatten(calc, axis=None)

                if len(values) == 0:
                    # Nothing to fill, and no data to derive a range from.
                    print(
                        f"WARNING: Extra histogram '{extra.name}' is empty. Skipped.",
                        file=sys.stderr,
                    )
                    continue

                h = histograms.get(extra.name)
                if h is None:
                    h = _book_histogram(values, extra, extra.name)
                    histograms[extra.name] = h

                h.fill(values)

    if plots is not None:
        plots.mkdir(parents=True, exist_ok=True)

    # ``mode`` names the uproot function (recreate/update) used to open the
    # output; the context manager makes sure the file is closed.
    with getattr(uproot, mode.value)(outpath) as outfile:
        for k, h in histograms.items():
            if not silent:
                axis = h.axes[0]
                if dump_yml:
                    # Same text as before, including the trailing padding.
                    print(
                        f"\n{k}:\n"
                        f"  nbins: {len(axis.edges) - 1}\n"
                        f"  min: {axis.edges[0]}\n"
                        f"  max: {axis.edges[-1]}\n"
                        "                "
                    )
                else:
                    print(k, axis)

            outfile[k] = h

            if plots is not None:
                fig, ax = matplotlib.pyplot.subplots()

                h.plot(ax=ax, flow=None)

                fig.tight_layout()
                fig.savefig(str(plots / f"{k}.{plot_format}"))
                # Close this specific figure so figures do not pile up.
                matplotlib.pyplot.close(fig)
0202 
0203 
# Script entry point: typer builds the command-line interface from the
# signature of ``main`` and invokes it.
if __name__ == "__main__":
    typer.run(main)