Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:49:49

0001 #!/usr/bin/env python 
0002 
0003 import numpy as np
0004 import os, re, logging
0005 log = logging.getLogger(__name__)
0006 from collections import OrderedDict as odict 
0007 
0008 
class stag_item(object):
    """
    One stag.h enumeration entry: an integer code, its short name and an
    optional free-text note.
    """
    @classmethod
    def Placeholder(cls):
        """Stand-in item with code -1, name "placeholder" and note "ERROR"."""
        return cls(code=-1, name="placeholder", note="ERROR")

    def __init__(self, code, name, note=""):
        self.code, self.name, self.note = code, name, note

    def __str__(self):
        # same text as the repr, followed by the note and a trailing space
        return "%s : %s " % (repr(self), self.note)

    def __repr__(self):
        return "%(code)2d : %(name)10s" % vars(self)
0023 
0024 
class stag(object):
    """
    Parser for the stag.h tag enumeration plus numpy helpers for tag arrays.

    An instance reads the header at PATH (environment variables expanded).
    It builds one stag_item per ``stag_<name> = <code>,`` enum line and
    attaches notes from ``static constexpr const char* <name>_note = "..." ;``
    lines. The classmethods Unpack, NumStarts and StepSplit work directly on
    numpy arrays and need no instance.
    """
    # Raw strings: "\s", "\w", "\*" in plain literals are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning from Python 3.12).
    # The compiled patterns are unchanged.
    enum_ptn = re.compile(r"^\s*(\w+)\s*=\s*(.*?),*\s*?$")
    note_ptn = re.compile(r"^\s*static constexpr const char\* (\w+)_note = \"(.*)\" ;\s*$")

    PATH = "$OPTICKS_PREFIX/include/sysrap/stag.h"

    # the below NSEQ, BITS, ... param need to correspond to stag.h static constexpr
    NSEQ = 4   ## must match stag.h:NSEQ
    BITS = 4   ## must match stag.h:BITS
    MASK = ( 0x1 << BITS ) - 1    # mask selecting one BITS-wide slot
    SLOTMAX = 64//BITS            # slots packed into each 64-bit word
    SLOTS = SLOTMAX*NSEQ          # slots per row after Unpack


    @classmethod
    def NumStarts(cls, tg):
        """
        :param tg: unpacked tag array of shape (n, SLOTS)
        :return ns: (n,) uint8 array, per row the number of slots equal to tg[0,0]

        tg[0,0], the very first tag of the first row, is used as the step-start
        marker, as in StepSplit.
        """
        ns = np.zeros( (len(tg)), dtype=np.uint8 )
        for i in range(len(tg)):
            starts = np.where( tg[i] == tg[0,0] )[0]
            ns[i] = len(starts)
        pass
        return ns

    @classmethod
    def _StepMarkers(cls, row, start):
        """
        Return (starts, end) for one row of unpacked tags.

        starts holds the indices where *start* appears. end is the index of the
        first zero slot, or len(row) when the row has no zero (tags truncated
        by the collection limit).
        """
        starts = np.where( row == start )[0]
        ends = np.where( row == 0 )[0]
        end = ends[0] if len(ends) > 0 else len(row)
        return starts, end

    @classmethod
    def StepSplit(cls, tg, fl=None):
        """
        Fold each row's flat tag sequence into one sub-row per step.
        (StepFold might be a clearer name.)

        :param tg: unpacked tag array of shape (n, SLOTS)
        :param fl: None or flat array of shape (n, SLOTS)
        :return tgs OR (tgs,fls): step split arrays of shape (n, max_starts, max_slots)

        The very first tag tg[0,0] marks the start of every step. max_starts is
        the largest number of steps in any row. max_slots is the longest single
        step; the last step of a row runs up to its first zero tag. Shorter rows
        and steps are zero padded. Rows without any start marker become all-zero
        output rows.

        In the example below, the (4, 10) padding of ats[0] comes from other rows
        of the same array::

            In [4]: at[0]
            Out[4]: array([ 1,  2,  9, 10,  1,  2,  9, 10,  1,  2, 11, 12,  0,  0,  0,  0], dtype=uint8)

            In [8]: ats[0]
            Out[8]:
            array([[ 1,  2,  9, 10,  0,  0,  0,  0,  0,  0],
                   [ 1,  2,  9, 10,  0,  0,  0,  0,  0,  0],
                   [ 1,  2, 11, 12,  0,  0,  0,  0,  0,  0],
                   [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0]], dtype=uint8)
        """
        if not fl is None:
            assert fl.shape == tg.shape
        pass

        # per-row (starts, end); tg[0,0] is only read when there is at least one row
        markers = [cls._StepMarkers(tg[i], tg[0,0]) for i in range(len(tg))]

        max_starts = 0   ## corresponds to the maximum number of steps of all photons
        max_slots = 0    ## longest single step, in slots
        for starts, end in markers:
            max_starts = max(max_starts, len(starts))
            if len(starts) > 0:
                # step lengths: gaps between successive starts, the last step runs to end.
                # The guard avoids max() of an empty diff for rows without a start marker.
                mkr = np.concatenate( (starts, np.array((end,))) )
                max_slots = max(max_slots, np.diff(mkr).max())
            pass
        pass
        print("max_starts:%s max_slots:%d" % (max_starts, max_slots))

        tgs = np.zeros((len(tg), max_starts, max_slots), dtype=np.uint8)
        fls = np.zeros((len(tg), max_starts, max_slots), dtype=np.float32) if not fl is None else None

        for i, (starts, end) in enumerate(markers):
            for j in range(len(starts)):
                st = starts[j]
                en = starts[j+1] if j+1 < len(starts) else end
                tgs[i, j, 0:en-st] = tg[i, st:en]
                if not fls is None:
                    fls[i, j, 0:en-st] = fl[i, st:en]
                pass
            pass
        pass
        # Parenthesised: "tgs if c else tgs,fls" parses as "(tgs if c else tgs), fls"
        # and so always returned a tuple.
        return tgs if fls is None else (tgs, fls)

    @classmethod
    def Unpack(cls, tag):
        """
        :param tag: (n, NSEQ) array of bitpacked tag enumerations
        :return tg: (n, SLOTS) uint8 array of unpacked tag enumerations

        Slot j of packed word i is read from bits [BITS*j, BITS*(j+1)) and lands
        in output column i*SLOTMAX+j.

        Usage::

            # apply stag.Unpack to both as same stag.h bitpacking is used
            at = stag.Unpack(a.tag) if hasattr(a,"tag") else None
            bt = stag.Unpack(b.tag) if hasattr(b,"tag") else None

        """
        assert tag.shape == (len(tag), cls.NSEQ)

        st = np.zeros( (len(tag), cls.SLOTS), dtype=np.uint8 )
        for i in range(cls.NSEQ):
            for j in range(cls.SLOTMAX):
                st[:,i*cls.SLOTMAX+j] = (tag[:,i] >> (cls.BITS*j)) & cls.MASK
            pass
        pass
        return st

    def __init__(self, path=PATH):
        """
        :param path: header to parse; environment variables are expanded
        """
        path = os.path.expandvars(path)
        with open(path, "r") as fp:   # close the handle instead of leaking it
            lines = fp.read().splitlines()
        pass
        self.path = path
        self.lines = lines
        self.items = []
        self.parse()

    def find_item(self, name):
        """Return the first parsed item whose short name equals *name*, else None."""
        for item in self.items:
            if item.name == name: return item
        pass
        return None

    def parse(self):
        """
        Scan self.lines, collecting enum items and attaching notes.

        Enum names must carry the "stag_" prefix; the stored name drops it.
        A note line must follow the enum entry it refers to (asserted).
        """
        self.code2item = odict()
        self.name2item = odict()
        for line in self.lines:
            enum_match = self.enum_ptn.match(line)
            note_match = self.note_ptn.match(line)
            if enum_match:
                name, val = enum_match.groups()
                pfx = "stag_"
                assert name.startswith(pfx)
                sname = name[len(pfx):]
                code = int(val)
                item = stag_item(code, sname, "")
                self.items.append(item)
                self.code2item[code] = item
                self.name2item[sname] = item
                log.debug("%40s : name:%20s  sname:%10s val:%10s code:%d ", line, name, sname, val, code)
            elif note_match:
                name, note = note_match.groups()
                item = self.find_item(name)
                assert not item is None
                item.note = note
                log.debug(" note %10s : %s ", name, note)
            pass
        pass

    def old_label(self, st):
        """
        Older labelling: one "index : repr(item)" line per entry of *st*.
        Unknown codes show the placeholder item.
        """
        d = self.code2item   # previously self.d, an attribute that is never set
        label_ = lambda _:repr(d.get(_,stag_item.Placeholder()))
        ilabel_ = lambda _:"%2d : %s" % ( _, label_(st[_]))
        return "\n".join(map(ilabel_, range(len(st))))


    def __call__(self, code):
        """Return the item for *code*, or the placeholder item when unknown."""
        return self.code2item.get(code,stag_item.Placeholder())

    def label(self, st, fl=None):
        """
        :param st: array of stag enumeration codes
        :param fl: None or array of flat uniform rands of the same shape as st
        :return: multi-line string, one line per code

        A blank line is inserted before each repeat of st[0] (each new step).
        Output stops after the second zero code, which is still included.
        """
        if not fl is None:
            assert st.shape == fl.shape
        pass

        lines = []
        num_zero = 0

        for i in range(len(st)):
            code = st[i]
            flat = fl[i] if not fl is None else None
            item = self(code)
            it = repr(item)
            assert item.code == code
            if code == st[0] and i > 0:
                lines.append("")
            pass
            label = "%2d : %s " % (i, it) if fl is None else "%2d : %10.4f : %s" % (i, flat, it)
            lines.append(label)
            if code == 0:
                num_zero += 1
                if num_zero == 2: break
            pass
        pass
        return "\n".join(lines)

    def __str__(self):
        return "\n".join(self.lines)

    def __repr__(self):
        return "\n".join(list(map(repr,self.items)))
0219 
0220 
0221 
0222 
def test_label():
    """
    Print the parsed stag.h enumeration, then the label of the first ten
    slots of a sample tag row. Needs the header at stag.PATH.
    """
    tag = stag()
    print(repr(tag))

    # three identical rows: two steps (1,2,9,10) and (1,2,11,12), then zero padding
    row = [1, 2, 9, 10, 1, 2, 11, 12] + [0] * 16
    st = np.array([row] * 3, dtype=np.uint8)

    print(tag.label(st[0, :10]))
0233 
0234 
def test_StepSplit():
    """
    Check that stag.StepSplit folds each photon's tags into one row per step,
    with and without the optional flat array.

    StepSplit sizes its output from the data as (photons, max_starts, max_slots).
    Here every photon has three steps starting with tag 1 and the longest step
    is four slots, so the expected shape is (3, 3, 4) with no extra padding.
    """
    from numpy import array, uint8
    at = array(
      [[ 1,  2,  9, 10,  1,  2,  9, 10,  1,  2, 11, 12,  0,  0,  0,  0],
       [ 1,  2,  9, 10,  1,  2,  9, 10,  1,  2, 11, 12,  0,  0,  0,  0],
       [ 1,  2,  9, 10,  1,  2,  9, 10,  1,  2, 11, 12,  0,  0,  0,  0]], dtype=uint8)

    x_ats = array(
      [[[ 1,  2,  9, 10],
        [ 1,  2,  9, 10],
        [ 1,  2, 11, 12]],

       [[ 1,  2,  9, 10],
        [ 1,  2,  9, 10],
        [ 1,  2, 11, 12]],

       [[ 1,  2,  9, 10],
        [ 1,  2,  9, 10],
        [ 1,  2, 11, 12]]], dtype=uint8)

    ats = stag.StepSplit(at)
    assert ats.shape == x_ats.shape
    assert np.all( ats == x_ats )

    # with a flat array, values are folded alongside the tags
    fl = np.arange(at.size, dtype=np.float32).reshape(at.shape)
    ats2, fls = stag.StepSplit(at, fl)
    assert np.all( ats2 == x_ats )
    assert fls.shape == x_ats.shape
    assert np.all( fls[:, 2, :] == fl[:, 8:12] )
0260 
0261 
0262 
0263 
0264 
0265 
0266 
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # ad-hoc checks, enable by uncommenting:
    #   test_label()        # reads $OPTICKS_PREFIX/include/sysrap/stag.h
    #   test_StepSplit()
0274 
0275