File indexing completed on 2025-01-18 09:12:11
0001
0002 from pathlib import Path
0003 from typing import Optional, Dict, List
0004 import re
0005 import enum
0006 import sys
0007
0008 import uproot
0009 import typer
0010 import hist
0011 import pydantic
0012 import yaml
0013 import matplotlib.pyplot
0014 import awkward
0015
0016
class Model(pydantic.BaseModel):
    """Common base class for the configuration models in this file.

    Any key that is not a declared field is rejected during validation, so
    typos in a config file surface as errors instead of being ignored.
    """

    # The rest of the file relies on the pydantic v2 API (``model_validate``,
    # ``model_copy``), where a nested ``class Config`` is deprecated; the
    # ``model_config`` attribute is the v2 way to express the same setting.
    model_config = pydantic.ConfigDict(extra="forbid")
0020
0021
class HistConfig(Model):
    """Binning and labelling options for a single histogram axis."""

    # Number of bins of the regular axis.
    nbins: int = pydantic.Field(default=100)
    # Lower axis edge; when left as None, main() takes it from the data.
    min: Optional[float] = pydantic.Field(default=None)
    # Upper axis edge; when left as None, main() takes it from the data.
    max: Optional[float] = pydantic.Field(default=None)
    # Optional axis name override; main() falls back to the branch/histogram
    # name when this is not set.
    label: Optional[str] = pydantic.Field(default=None)
0027
0028
class Extra(HistConfig):
    """A histogram filled from a computed expression instead of a branch.

    Inherits the binning options of ``HistConfig`` and adds the two required
    fields below.
    """

    # Python expression that main() passes to eval() for every chunk read
    # from the tree.
    expression: str
    # Key under which the resulting histogram is stored and written out.
    name: str
0032
0033
class Config(Model):
    """Top-level layout of the optional YAML configuration file.

    Every field defaults to an empty container, so ``Config()`` describes
    "histogram every branch with default binning".
    """

    # Maps a regular expression (matched against branch names with
    # re.match) to the binning that should be used for matching branches.
    histograms: Dict[str, HistConfig] = pydantic.Field(default_factory=dict)
    # Additional histograms computed from expressions.
    extra_histograms: List[Extra] = pydantic.Field(default_factory=list)
    # Regular expressions; branches matching any of them are skipped.
    exclude: List[str] = pydantic.Field(default_factory=list)
0038
0039
class Mode(str, enum.Enum):
    """How the output ROOT file is opened.

    Each value is the name of the ``uproot`` function that main() looks up
    with ``getattr`` to open the output file.
    """

    # Corresponds to uproot.recreate.
    recreate = "recreate"
    # Corresponds to uproot.update.
    update = "update"
0043
0044
def _load_config(config_file: Optional[Path]) -> Config:
    """Return the parsed config, or an all-defaults ``Config`` without a file."""
    if config_file is None:
        return Config()
    with config_file.open() as fh:
        # An empty YAML document loads as None; treat it like an empty mapping
        # so an empty config file behaves the same as passing no config.
        return Config.model_validate(yaml.safe_load(fh) or {})


def _find_hist_config(config: Config, col: str, silent: bool) -> HistConfig:
    """Return a private copy of the binning config matching ``col``.

    All patterns are tried in order and the last match wins; a default
    ``HistConfig`` is returned when nothing matches.
    """
    found = None
    for pattern, spec in config.histograms.items():
        if re.match(pattern, col):
            # Copy so that filling in min/max later does not leak into other
            # branches matched by the same pattern.
            found = spec.model_copy()
            if not silent:
                print(
                    "Found HistConfig",
                    pattern,
                    "for",
                    col,
                    ":",
                    found,
                    file=sys.stderr,
                )
    return HistConfig() if found is None else found


def _book_histogram(spec: HistConfig, values, axis_name: str):
    """Create a 1D histogram for ``spec``, completing its range from ``values``.

    Missing ``min``/``max`` are written back onto ``spec`` using the extrema of
    ``values`` (which must be non-empty). A zero-width range is widened by one
    unit on each side so that the regular axis is valid.
    """
    if spec.min is None:
        spec.min = awkward.min(values)
    if spec.max is None:
        spec.max = awkward.max(values)

    if spec.min == spec.max:
        spec.min -= 1
        spec.max += 1

    return hist.Hist(
        hist.axis.Regular(
            spec.nbins, spec.min, spec.max, name=spec.label or axis_name
        )
    )


# NOTE: typer uses the docstring of ``main`` as the command's --help text, so
# it is kept short and user-facing; implementation notes live in comments.
def main(
    infile: Path = typer.Argument(
        ..., exists=True, dir_okay=False, help="The input ROOT file"
    ),
    treename: str = typer.Argument(..., help="The tree to look up branches from"),
    outpath: Path = typer.Argument(
        "outfile", dir_okay=False, help="The output ROOT file"
    ),
    config_file: Optional[Path] = typer.Option(
        None,
        "--config",
        "-c",
        exists=True,
        dir_okay=False,
        help="A config file following the input spec. By default, all branches will be plotted.",
    ),
    mode: Mode = typer.Option(Mode.recreate, help="Mode to open ROOT file in"),
    plots: Optional[Path] = typer.Option(
        None,
        "--plots",
        "-p",
        file_okay=False,
        help="If set, output plots individually to this directory",
    ),
    plot_format: str = typer.Option(
        "pdf", "--plot-format", "-f", help="Format to write plots in if --plots is set"
    ),
    silent: bool = typer.Option(
        False, "--silent", "-s", help="Do not print any output"
    ),
    dump_yml: bool = typer.Option(False, help="Print axis ranges as yml"),
):
    """
    Script to plot all branches in a TTree from a ROOT file, with optional configurable binning and ranges.
    Also allows setting extra expressions to be plotted as well.
    """

    # Parse the configuration before touching the output path, so an invalid
    # config cannot leave behind a freshly truncated output file.
    config = _load_config(config_file)

    histograms = {}

    if not silent:
        print(config.extra_histograms, file=sys.stderr)

    # Both files are context-managed so they are closed even when processing
    # raises part-way through.
    with uproot.open(infile) as rf, getattr(uproot, mode.value)(outpath) as outfile:
        tree = rf[treename]

        for df in tree.iterate(library="ak", how=dict):
            for col in df.keys():
                if any(re.match(ex, col) for ex in config.exclude):
                    continue

                values = awkward.flatten(df[col], axis=None)

                if len(values) == 0:
                    # Diagnostics go to stderr so they never mix with the
                    # axis listing / yml dump written to stdout.
                    if not silent:
                        print(
                            f"WARNING: Branch '{col}' is empty. Skipped.",
                            file=sys.stderr,
                        )
                    continue

                h = histograms.get(col)
                if h is None:
                    # The axis is booked from the first non-empty chunk and
                    # is not extended for later chunks.
                    h = _book_histogram(
                        _find_hist_config(config, col, silent), values, col
                    )
                    histograms[col] = h
                h.fill(values)

            for extra in config.extra_histograms:
                # SECURITY: eval() executes arbitrary Python taken from the
                # config file; only use config files you trust. The expression
                # is evaluated here on purpose so that it can refer to the
                # local names of this function (e.g. the current chunk `df`)
                # and to the module-level imports.
                calc = eval(extra.expression)
                values = awkward.flatten(calc, axis=None)

                if len(values) == 0:
                    # Without this guard the min/max of an empty array would
                    # be None and booking the axis would fail.
                    if not silent:
                        print(
                            f"WARNING: Expression '{extra.name}' is empty. Skipped.",
                            file=sys.stderr,
                        )
                    continue

                h = histograms.get(extra.name)
                if h is None:
                    h = _book_histogram(extra, values, extra.name)
                    histograms[extra.name] = h
                h.fill(values)

        if plots is not None:
            plots.mkdir(parents=True, exist_ok=True)

        for k, h in histograms.items():
            if not silent:
                if dump_yml:
                    ax = h.axes[0]
                    # One mapping entry per histogram, nested like an entry
                    # of the `histograms` section of the config file.
                    s = "\n{k}:\n  nbins: {b}\n  min: {min}\n  max: {max}\n".format(
                        k=k, b=len(ax.edges) - 1, min=ax.edges[0], max=ax.edges[-1]
                    )
                    print(s)
                else:
                    print(k, h.axes[0])
            outfile[k] = h

            if plots is not None:
                fig, ax = matplotlib.pyplot.subplots()

                h.plot(ax=ax, flow=None)

                fig.tight_layout()
                fig.savefig(str(plots / f"{k}.{plot_format}"))
                # Close exactly the figure created above rather than whatever
                # pyplot currently considers the active figure.
                matplotlib.pyplot.close(fig)
0202
0203
# Script entry point: let typer build the command line interface from the
# signature of main() and run it. Nothing happens on a plain import.
if __name__ == "__main__":
    typer.run(main)