Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-05-18 07:43:47

0001 # SPDX-License-Identifier: LGPL-3.0-or-later
0002 # Copyright (C) 2023 Connor Pecar
0003 
0004 import tensorflow
0005 from tensorflow import keras
0006 import energyflow
0007 from energyflow.archs import PFN
0008 import numpy as np
0009 
# Saved Keras model to load. This is a relative path, so it is presumably
# resolved against the process's current working directory rather than this
# file's directory -- NOTE(review): confirm where the model is deployed.
modelname = "pfn_testEpic_000-2_vecQele_nHFS2_500_bs10k_bestValLoss"
# Loaded once, at import time, and reused by eflowPredict() on every call.
model = keras.models.load_model(modelname)
0012 
def eflowPredict(feat, globalfeat):
    """Evaluate the module-level Keras ``model`` on a single event.

    The two inputs are given a leading batch axis of size 1 and passed to the
    model as ``[feat, globalfeat]``.

    Args:
        feat: Per-object input features with 7 values per object. Accepted as
            row-wise data whose leading axis is the object count N (e.g. shape
            ``(N, 7)``), or as a flat 1-D sequence of ``7 * N`` values. N may
            be zero.
        globalfeat: The event-level input features (10 values in total).

    Returns:
        numpy.ndarray of shape ``(4,)`` with the model output for this event.
        What the four values mean is fixed by the trained model and is not
        visible from this file.

    Raises:
        ValueError: if ``feat`` or ``globalfeat`` cannot be reshaped to the
            sizes above, or if the model output does not hold exactly 4 values.
    """
    feat = np.asarray(feat)
    if feat.ndim == 1:
        # Flat input of 7*N values: infer N from the total size. The reshape
        # still fails unless the size is a multiple of 7.
        feat = np.reshape(feat, (1, -1, 7))
    else:
        # Row-wise input: the leading axis is the object count. Using
        # len(feat) instead of -1 keeps rejecting e.g. transposed (7, N) arrays
        # (N != 7) rather than silently reordering their values.
        feat = np.reshape(feat, (1, len(feat), 7))
    globalfeat = np.reshape(np.asarray(globalfeat), (1, 10))
    pred = model([feat, globalfeat])
    # Convert the model's array-like output to NumPy and drop the batch axis.
    # `(4,)` is the intended shape tuple; the original `(4)` was the int 4.
    return np.reshape(np.asarray(pred), (4,))