Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-23 09:22:36

0001 """
0002 ** convert **
0003 defines the function that converts a trained model to an ONNX file
0004 """
0005 
0006 import argparse
0007 import sys
0008 
0009 import tf2onnx
0010 import numpy as np
0011 from onnxruntime import InferenceSession
0012 
0013 from core.constants import GLOBAL_CHECKPOINT_DIR, CONV_DIR, ORIGINAL_DIM
0014 from core.model import VAEHandler
0015 """
0016     epoch: epoch of the saved checkpoint model (omit to use the best checkpoint)
0017     study-name: name of the study the model was trained for
0018 """
0019 
0020 
def parse_args(argv=None):
    """Parse the command-line options of the ONNX conversion script.

    Args:
        argv: list of argument strings to parse, without the program name.
            ``None`` makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes:
            epoch (int | None): epoch of the checkpoint to load, or ``None``
                when ``--epoch`` is omitted.
            study_name (str): name of the study the model was trained for.
    """
    p = argparse.ArgumentParser(
        description="Convert a saved VAE decoder checkpoint to ONNX.")
    p.add_argument("--epoch", type=int, default=None,
                   help="epoch of the saved checkpoint (default: best checkpoint)")
    p.add_argument("--study-name", type=str, default="default_study_name",
                   help="study the model was trained for")
    # Parse the list we were given; calling parse_args() with no argument
    # would silently ignore `argv` and re-read sys.argv instead.
    return p.parse_args(argv)
0027 
0028 
0029 # main function
def _random_feed(sess):
    """Build a random float32 input feed covering every input of an ONNX session.

    Input names and shapes are read from ``sess.get_inputs()`` instead of being
    hard-coded, because Keras auto-generates input names (``input_6``, ...)
    that are not predictable. Non-integer dimensions (symbolic or unknown,
    e.g. the batch axis) are set to 1.

    NOTE(review): assumes every model input is a float tensor, as the previous
    hard-coded float32 feed did.
    """
    feed = {}
    for node in sess.get_inputs():
        dims = [d if isinstance(d, int) else 1 for d in node.shape]
        feed[node.name] = np.asarray(np.random.randn(*dims), dtype=np.float32)
    return feed


# main function
def main(argv):
    """Convert a saved VAE decoder checkpoint to ONNX and sanity-check it.

    Steps:
      1. Load the VAE weights of ``--study-name`` from ``GLOBAL_CHECKPOINT_DIR``
         (the ``VAE_epoch_<NNN>`` checkpoint for ``--epoch``, else ``VAE_best``).
      2. Export the decoder to ``CONV_DIR/<study>/Generator_<weights>.onnx``,
         creating the output directory if needed.
      3. Run the exported model once on random inputs with onnxruntime and
         check that the output width equals ``ORIGINAL_DIM``.

    Args:
        argv: command-line arguments without the program name.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        RuntimeError: if the ONNX model's output width differs from
            ``ORIGINAL_DIM``.
    """
    import os

    # 1. Set up the model to convert.
    args = parse_args(argv)
    epoch = args.epoch
    study_name = args.study_name

    vae = VAEHandler()
    weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
    # expect_partial(): only part of the checkpoint is expected to be consumed
    # here, so TensorFlow should not warn about the unrestored values.
    vae.model.load_weights(
        f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights"
    ).expect_partial()

    # 2. Convert the decoder (the generator) to ONNX and save it. The output
    # directory may not exist yet for a new study.
    output_dir = f"{CONV_DIR}/{study_name}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/Generator_{weights_dir}.onnx"
    tf2onnx.convert.from_keras(vae.model.decoder, output_path=output_path)

    # 3. Sanity-check the converted model on one random sample. The CPU
    # provider is requested explicitly because onnxruntime >= 1.9 refuses to
    # build a session without `providers` when several providers are enabled.
    sess = InferenceSession(output_path, providers=["CPUExecutionProvider"])
    result = sess.run(None, _random_feed(sess))
    out_dim = result[0].shape[1]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if out_dim != ORIGINAL_DIM:
        raise RuntimeError(
            f"ONNX model output has width {out_dim}, expected {ORIGINAL_DIM}")
    return 0
0069 
0070 
if __name__ == "__main__":
    # sys.exit() rather than the site-provided exit(), which does not exist
    # when Python runs without the site module (e.g. `python -S`).
    sys.exit(main(sys.argv[1:]))