File indexing completed on 2025-02-23 09:22:36
0001 """
0002 ** convert **
defines the conversion of a trained VAE decoder to an ONNX file
0004 """
0005
0006 import argparse
0007 import sys
0008
0009 import tf2onnx
0010 import numpy as np
0011 from onnxruntime import InferenceSession
0012
0013 from core.constants import GLOBAL_CHECKPOINT_DIR, CONV_DIR, ORIGINAL_DIM
0014 from core.model import VAEHandler
0015 """
epoch: epoch of the saved checkpoint model to convert (the best checkpoint is used when omitted)
study-name: name of the study the model was trained for
0018 """
0019
0020
def parse_args(argv):
    """Parse the command-line options of the ONNX conversion script.

    Args:
        argv: list of argument strings excluding the program name
            (e.g. ``sys.argv[1:]``). Passing ``None`` makes argparse fall
            back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``epoch`` (int or None) and
        ``study_name`` (str).
    """
    parser = argparse.ArgumentParser(
        description="Convert a trained VAE decoder checkpoint to ONNX.")
    parser.add_argument(
        "--epoch", type=int, default=None,
        help="epoch of the saved checkpoint model; omit to use the best checkpoint")
    parser.add_argument(
        "--study-name", type=str, default="default_study_name",
        help="name of the study the model was trained for")
    # Parse the list we were given; the original parse_args() silently
    # ignored `argv` and always read sys.argv instead.
    return parser.parse_args(argv)
0027
0028
0029
def _random_feed(session):
    """Build a random batch-of-one float32 feed for every input of *session*.

    Input names and feature widths are read from the exported graph rather
    than hard-coded: Keras auto-names ``Input`` layers (``input_6``,
    ``input_9``, ...) from a global counter, so those names change whenever
    the order in which models/layers are built changes.

    Raises:
        RuntimeError: if an input's last dimension is not a static integer.
    """
    feed = {}
    for node in session.get_inputs():
        width = node.shape[-1]
        if not isinstance(width, int):
            raise RuntimeError(
                f"ONNX input {node.name!r} has no static feature dimension: "
                f"shape={node.shape}")
        feed[node.name] = np.random.randn(1, width).astype(np.float32)
    return feed


def main(argv):
    """Export the decoder of a trained VAE checkpoint to ONNX and sanity-check it.

    The weights under ``GLOBAL_CHECKPOINT_DIR/<study-name>/<weights_dir>`` are
    loaded. ``weights_dir`` is ``VAE_epoch_NNN`` when ``--epoch`` is given,
    otherwise ``VAE_best``. The decoder is written to
    ``CONV_DIR/<study-name>/Generator_<weights_dir>.onnx``, then run once on
    random inputs with ONNX Runtime.

    Args:
        argv: command-line arguments excluding the program name, forwarded
            to ``parse_args``.

    Raises:
        RuntimeError: if the exported model's output width differs from
            ``ORIGINAL_DIM``, or an ONNX input has no static feature width.
    """
    args = parse_args(argv)
    epoch = args.epoch
    study_name = args.study_name

    # Restore a specific epoch if requested, otherwise the best checkpoint.
    vae = VAEHandler()
    weights_dir = f"VAE_epoch_{epoch:03}" if epoch is not None else "VAE_best"
    # expect_partial() suppresses warnings about checkpoint values that are
    # not restored into this model.
    vae.model.load_weights(
        f"{GLOBAL_CHECKPOINT_DIR}/{study_name}/{weights_dir}/model_weights"
    ).expect_partial()

    # Only the decoder (the generator) is exported.
    output_path = f"{CONV_DIR}/{study_name}/Generator_{weights_dir}.onnx"
    tf2onnx.convert.from_keras(vae.model.decoder, output_path=output_path)

    # Smoke test: run the exported graph once and check the output width.
    # An explicit raise is used instead of `assert`, which is stripped under -O.
    session = InferenceSession(output_path)
    result = session.run(None, _random_feed(session))
    output_width = result[0].shape[1]
    if output_width != ORIGINAL_DIM:
        raise RuntimeError(
            f"exported decoder produces {output_width} features, "
            f"expected ORIGINAL_DIM={ORIGINAL_DIM}")
0069
0070
if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(): site builtins are not
    # available under `python -S` or in some embedded/frozen interpreters.
    sys.exit(main(sys.argv[1:]))