Warning, /epic-data/onnx/identity_gemm_w1x1_b1.ipynb is written in an unsupported language. File is not indexed.
0001 {
0002 "nbformat": 4,
0003 "nbformat_minor": 0,
0004 "metadata": {
0005 "colab": {
0006 "private_outputs": true,
0007 "provenance": []
0008 },
0009 "kernelspec": {
0010 "name": "python3",
0011 "display_name": "Python 3"
0012 },
0013 "language_info": {
0014 "name": "python"
0015 }
0016 },
0017 "cells": [
0018 {
0019 "cell_type": "code",
0020 "execution_count": null,
0021 "metadata": {
0022 "id": "caYDhywKMSWl"
0023 },
0024 "outputs": [],
0025 "source": [
0026 "from __future__ import print_function\n",
0027 "from itertools import count\n",
0028 "\n",
0029 "import torch\n",
0030 "import torch.nn.functional as F"
0031 ]
0032 },
0033 {
0034 "cell_type": "code",
0035 "source": [
"# Ground-truth weight (1x1) and bias (1 element) used by the target function below.\n",
"W_target = torch.tensor([[1.0]])\n",
"b_target = torch.tensor([0.0])\n",
"\n",
"print(W_target, b_target, sep='\\n')"
0041 ],
0042 "metadata": {
0043 "id": "2b6XbnmCNEYF"
0044 },
0045 "execution_count": null,
0046 "outputs": []
0047 },
0048 {
0049 "cell_type": "code",
0050 "source": [
"def f(x):\n",
"    \"\"\"Evaluate the target function on a batch.\n",
"\n",
"    Matrix-multiplies the 2-D input ``x`` with ``W_target`` and adds the\n",
"    scalar value held in ``b_target``.\n",
"    \"\"\"\n",
"    offset = b_target.item()\n",
"    return torch.mm(x, W_target) + offset\n",
0054 "\n",
0055 "\n",
"def poly_desc(W, b):\n",
"    \"\"\"Render coefficients as a human-readable polynomial string.\n",
"\n",
"    Args:\n",
"        W: iterable of coefficients; the k-th entry (1-based) is printed as\n",
"            the factor of ``x^k``.\n",
"        b: indexable whose first element is printed as the constant term.\n",
"\n",
"    Returns:\n",
"        A string such as ``'y = +1.00 x^1 +0.00'``.\n",
"    \"\"\"\n",
"    terms = ''.join(\n",
"        '{:+.2f} x^{} '.format(coeff, power)\n",
"        for power, coeff in enumerate(W, start=1)\n",
"    )\n",
"    return 'y = ' + terms + '{:+.2f}'.format(b[0])\n",
0063 "\n",
0064 "\n",
"def get_batch(batch_size=32):\n",
"    \"\"\"Draw one random training batch.\n",
"\n",
"    Returns a tuple ``(x, y)`` where ``x`` is a ``(batch_size, 1)`` tensor of\n",
"    standard-normal samples and ``y`` is ``f(x)``.\n",
"    \"\"\"\n",
"    samples = torch.randn(batch_size, 1)\n",
"    targets = f(samples)\n",
"    return samples, targets\n",
0070 "\n",
0071 "\n",
"# Define model\n",
"LEARNING_RATE = 0.1    # step size of the manual gradient-descent update\n",
"LOSS_TOLERANCE = 1e-3  # training stops once the batch loss drops below this\n",
"MAX_BATCHES = 100000   # safety cap so a non-converging run cannot loop forever\n",
"\n",
"# Seed the RNG so the initial weights and the sampled batches are reproducible.\n",
"torch.manual_seed(0)\n",
"torch_model = torch.nn.Linear(W_target.size(0), 1)\n",
"\n",
"for batch_idx in count(1):\n",
"    # Get data\n",
"    batch_x, batch_y = get_batch()\n",
"\n",
"    # Reset gradients\n",
"    torch_model.zero_grad()\n",
"\n",
"    # Forward pass\n",
"    output = F.smooth_l1_loss(torch_model(batch_x), batch_y)\n",
"    loss = output.item()\n",
"\n",
"    # Backward pass\n",
"    output.backward()\n",
"\n",
"    # Apply gradients: plain gradient-descent step. It runs under no_grad so\n",
"    # autograd does not record the in-place update (this replaces the\n",
"    # discouraged `param.data` access).\n",
"    with torch.no_grad():\n",
"        for param in torch_model.parameters():\n",
"            param.add_(param.grad, alpha=-LEARNING_RATE)\n",
"\n",
"    # Stop criterion\n",
"    if loss < LOSS_TOLERANCE:\n",
"        break\n",
"    if batch_idx >= MAX_BATCHES:\n",
"        raise RuntimeError(\n",
"            'Training did not reach loss < {} within {} batches'.format(LOSS_TOLERANCE, MAX_BATCHES)\n",
"        )\n",
"\n",
"print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))\n",
"print('==> Learned function:\\t' + poly_desc(torch_model.weight.view(-1), torch_model.bias))\n",
"print('==> Actual function:\\t' + poly_desc(W_target.view(-1), b_target))"
0100 ],
0101 "metadata": {
0102 "id": "k9XLRY6jMgR0"
0103 },
0104 "execution_count": null,
0105 "outputs": []
0106 },
0107 {
0108 "cell_type": "code",
0109 "source": [
"%pip install onnxscript"
0111 ],
0112 "metadata": {
0113 "id": "Z-lKVqBRR3UR"
0114 },
0115 "execution_count": null,
0116 "outputs": []
0117 },
0118 {
0119 "cell_type": "code",
0120 "source": [
0121 "import onnx\n",
0122 "import onnxscript\n",
0123 "\n",
"# Run the trained model once on a random (1, 1) sample, then export it to ONNX\n",
"# using that same sample as the example input.\n",
"torch_input = torch.randn(1, 1)\n",
"torch_output = torch_model(torch_input)\n",
"\n",
"export_options = dict(\n",
"    export_params=True,\n",
"    input_names=['InclusiveKinematicsElectron.x'],\n",
"    output_names=['InclusiveKinematicsML.x'],\n",
")\n",
"onnx_program = torch.onnx.export(\n",
"    torch_model,\n",
"    torch_input,\n",
"    \"identity_gemm_w1x1_b1.onnx\",\n",
"    **export_options\n",
")"
0134 ],
0135 "metadata": {
0136 "id": "25N8HUU6MmZF"
0137 },
0138 "execution_count": null,
0139 "outputs": []
0140 },
0141 {
0142 "cell_type": "code",
0143 "source": [
0144 "import onnx\n",
"# Read the exported file back from disk and pass it to the ONNX model checker.\n",
"onnx_model = onnx.load(\"identity_gemm_w1x1_b1.onnx\")\n",
"onnx.checker.check_model(onnx_model)"
0147 ],
0148 "metadata": {
0149 "id": "XPvjg0-URwnO"
0150 },
0151 "execution_count": null,
0152 "outputs": []
0153 },
0154 {
0155 "cell_type": "code",
0156 "source": [
"%pip install onnxruntime"
0158 ],
0159 "metadata": {
0160 "id": "JwNIOzCcVVdg"
0161 },
0162 "execution_count": null,
0163 "outputs": []
0164 },
0165 {
0166 "cell_type": "code",
0167 "source": [
0168 "import numpy as np\n",
0169 "import onnxruntime\n",
0170 "\n",
"# Open the exported model with ONNX Runtime, restricted to the CPU provider.\n",
"ort_session = onnxruntime.InferenceSession(\n",
"    \"identity_gemm_w1x1_b1.onnx\",\n",
"    providers=[\"CPUExecutionProvider\"],\n",
")\n",
0172 "\n",
"def to_numpy(tensor):\n",
"    \"\"\"Convert a torch tensor to a NumPy array on the CPU.\n",
"\n",
"    A tensor that requires grad is detached first; any other tensor is\n",
"    converted directly.\n",
"    \"\"\"\n",
"    if tensor.requires_grad:\n",
"        tensor = tensor.detach()\n",
"    return tensor.cpu().numpy()\n",
0175 "\n",
"# compute ONNX Runtime output prediction\n",
"input_name = ort_session.get_inputs()[0].name\n",
"ort_input = {input_name: to_numpy(torch_input)}\n",
"ort_output = ort_session.run(None, ort_input)\n",
"\n",
"# compare ONNX Runtime and PyTorch results\n",
"expected = to_numpy(torch_output)\n",
"np.testing.assert_allclose(expected, ort_output[0], rtol=1e-03, atol=1e-05)\n",
"\n",
"print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")"
0184 ],
0185 "metadata": {
0186 "id": "nJAIj4xtVLGp"
0187 },
0188 "execution_count": null,
0189 "outputs": []
0190 },
0191 {
0192 "cell_type": "code",
0193 "source": [],
0194 "metadata": {
0195 "id": "DT8eI-rwVUCN"
0196 },
0197 "execution_count": null,
0198 "outputs": []
0199 }
0200 ]
0201 }