Preamble
AI applications are getting closer and closer to users. As a result, there is growing demand to run AI models in many kinds of environments, such as edge devices, web browsers, mobile apps, and Arduino boards. Exporting models to formats that can run on these platforms is therefore an essential task. In this article we will learn how to export a model from PyTorch to ONNX, a format that is popular in the AI community, and test the model in a web browser with ONNX.js. Enough talk, let's get started.
Deploying the model on the client
To keep the terminology consistent, in this part "client-side deployment" means only deploying the AI model directly on edge devices, web browsers, mobile apps, and so on, as opposed to deploying the model on a backend (server).
Why deploy the model on the client
First, let's compare the pros and cons of the two types of deployment:
- Server-side deployment: The AI models are deployed on a centralized server, and clients communicate with them through APIs. The advantage of this approach is that the models are handled centrally on the server, so they depend less on the configuration and environment of each client. Centralized management also makes deploying, versioning, and maintaining models easier, and, in particular, it protects model privacy: you don't have to worry about someone stealing your model or copying your model architecture. However, concentrating all processing on the server can overload the system, and the cost of operating and scaling AI models on a centralized server is high. Another downside is that data privacy is hard to guarantee, since users have to send their data to the server for processing.
- Client-side deployment: This is the second approach, which moves both the model and the AI processing down to the client. Its advantage is that processing is distributed, so no large server is needed to run the model. Applications can run the model entirely on the client, offline, without contacting a server. It also preserves user data privacy. However, this approach has its own drawbacks: updating and versioning the model is harder; it is only suitable for small models because of the limited compute of client hardware; the model must be made to run across different platforms and deployment environments; and, very importantly, model privacy cannot be guaranteed once the entire model and its processing have been shipped to the client. To an attacker, that is like being handed the keys to the vault. The pros and cons of the two approaches are summarized in the following figure.
When should the model be deployed on the client
So the question is: in which cases should we deploy the model on the client? From my experience working on AI projects, client-side deployment is worth choosing in the following cases:
- The application has to run offline: in this case there is no alternative to shipping the model to the client.
- The model is light enough and still accurate: this is very important for the user experience, because models that are too heavy lead to very long load and inference times on client devices with limited hardware.
- The model does not need frequent updates: this depends on the problem. If it does not require continually updating the model, as in online learning, moving the model to the client is worth considering.
- User data security is a priority: if your application should not send user data to a centralized server for processing, bringing the AI model to the client is an advantage.
- You have a way to protect the model: once the model is on the client, an important point to consider is how to protect its intellectual property against attacks.
Those are the points to consider before deciding to deploy an AI model on the client. If you are ready, let's dive into the techniques and tools for doing it. In this article, I use ONNX (Open Neural Network Exchange), a well-known framework for model conversion, and build an offline demo web application with ONNX.js.
What is ONNX
ONNX can be thought of as an intermediate format for representing AI models trained in many different frameworks, such as PyTorch, TensorFlow, and Caffe. A model in ONNX format can then be run on many platforms: web, desktop, FPGA, ARM, mobile, and so on. That makes it really useful if you want to develop cross-platform AI models.
That's the basic idea; we only need to understand what it is for. I won't go deep into using the framework itself; you can find the details in its official documentation. Let's get to the code.
Show me your code
Building the model in PyTorch
Very simply, we will build a handwritten digit recognition application with MNIST. This is a fairly simple example, so I won't explain too much; just follow along with the code.
- First, import the libraries
```python
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
```
- Next, declare the necessary hyperparameters
```python
n_epochs = 30
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1

torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
```
- Next, declare the data transforms and load the corresponding data
```python
# Scale pixels to [0, 1], then normalize with the MNIST mean and std
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Train dataloader
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# Test dataloader
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size_test, shuffle=True)
```
- Declare a simple network
```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # 20 channels * 4 * 4 after two conv + pool stages
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities over the 10 digit classes
```
- Declare the loss function and the optimizer
```python
net = Net()
# Net already returns log-probabilities; CrossEntropyLoss re-applies log_softmax,
# which leaves them unchanged, so this behaves like nn.NLLLoss here.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
```
- Build the test function
```python
def test(net, test_loader):
    net.eval()  # disable dropout during evaluation
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)  # index of the highest score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
```
- Train the model
```python
for epoch in range(200):
    net.train()  # re-enable dropout after test() switched to eval mode
    running_loss = 0.0
    n_batches = 0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: mean loss over the batches since the last report
        running_loss += loss.item()
        n_batches += 1
        if i % 200 == 0:  # print every 200 mini-batches
            print('[Epoch %d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / n_batches))
            running_loss = 0.0
            n_batches = 0
    test(net, test_loader)

print('Finished Training')
```
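A note on the loss used in the training loop: `Net.forward` already ends with `log_softmax`, while `nn.CrossEntropyLoss` applies `log_softmax` internally once more. This still trains correctly because `log_softmax` is idempotent: log-probabilities already have a log-sum-exp of 0, so applying it again changes nothing, and the loss equals `nn.NLLLoss` on the model output. The stdlib sketch below checks this numerically; `log_softmax`, `nll_loss`, and `cross_entropy_loss` are our own re-implementations of the math for illustration, not PyTorch APIs, and the logits are made-up values.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax: x_i - log(sum_j exp(x_j))."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def nll_loss(log_probs, target):
    """Per-sample negative log-likelihood, the quantity nn.NLLLoss averages."""
    return -log_probs[target]


def cross_entropy_loss(logits, target):
    """Per-sample cross-entropy, the quantity nn.CrossEntropyLoss averages:
    the NLL of log_softmax(logits)."""
    return nll_loss(log_softmax(logits), target)


# Hypothetical raw scores for one sample with 4 classes (illustrative values).
logits = [2.0, -1.0, 0.5, 3.0]
log_probs = log_softmax(logits)      # what Net.forward returns
reapplied = log_softmax(log_probs)   # what CrossEntropyLoss does to it internally

print("max difference after re-applying:", max(abs(a - b) for a, b in zip(log_probs, reapplied)))
print("CrossEntropyLoss vs NLLLoss:", cross_entropy_loss(log_probs, 3), nll_loss(log_probs, 3))
```

In practice the more idiomatic pairings are `nn.NLLLoss()` for a model that ends in `log_softmax`, or `nn.CrossEntropyLoss()` for a model that returns raw logits; either gives the same loss values as the combination used here.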
Then we run the training as usual. Sip a cup of coffee and wait for the results…
```
[Epoch 200,     1] loss: 0.000
[Epoch 200,   201] loss: 0.007
[Epoch 200,   401] loss: 0.008
[Epoch 200,   601] loss: 0.008
[Epoch 200,   801] loss: 0.007
Accuracy of the network on the 10000 test images: 97 %
```
OK, let's save this model, which reaches 97% accuracy, for the next steps.
```python
# Save only the learned weights; the Net class is needed again to reload them
torch.save(net.state_dict(), "pytorch_model.pt")
```
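Since client-side deployment only makes sense for a light model, it is worth checking how big the model we just saved is. Both the magic number 320 in `x.view(-1, 320)` and the parameter count follow from simple arithmetic. The stdlib sketch below traces the feature-map size through `Net` and counts its parameters; it assumes stride-1, unpadded convolutions and non-overlapping 2x2 pooling (as in `Net`) and 4-byte float32 weights, and the helper names are illustrative.

```python
# Trace the feature-map size and parameter count of Net by hand.
def conv_out(size: int, kernel: int) -> int:
    """Output side length of a stride-1, unpadded convolution."""
    return size - kernel + 1


def pool_out(size: int, window: int) -> int:
    """Output side length of non-overlapping max pooling (floor division)."""
    return size // window


def conv_params(c_in: int, c_out: int, kernel: int) -> int:
    return c_out * c_in * kernel * kernel + c_out  # weights + biases


def linear_params(n_in: int, n_out: int) -> int:
    return n_out * n_in + n_out  # weights + biases


side = 28
side = pool_out(conv_out(side, 5), 2)  # conv1 + max_pool2d: 28 -> 24 -> 12
side = pool_out(conv_out(side, 5), 2)  # conv2 + max_pool2d: 12 -> 8 -> 4
flat_features = 20 * side * side       # 20 channels * 4 * 4

total_params = (
    conv_params(1, 10, 5)       # conv1
    + conv_params(10, 20, 5)    # conv2 (Dropout2d has no parameters)
    + linear_params(320, 50)    # fc1
    + linear_params(50, 10)     # fc2
)
size_bytes = total_params * 4  # float32 weights

print(flat_features, total_params, size_bytes)  # 320 21840 87360
```

So the network has 21,840 parameters, roughly 85 KiB of raw float32 weights, which is comfortably small enough to download and run in a browser.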
Exporting the model to ONNX
Exporting the model from PyTorch to ONNX is actually very simple. We just need to run a few commands:
```python
model = Net()
# Load pretrained weights
model.load_state_dict(torch.load('pytorch_model.pt', map_location='cpu'))
# Put the model in inference mode (dropout off) before exporting
model.eval()

# Set dummy input: one 1x28x28 grayscale image in (N, C, H, W) layout
dummy_input = torch.zeros(1, 1, 28, 28)

# Export to ONNX
torch.onnx.export(model, dummy_input, 'onnx_model.onnx', verbose=True)
```
Test run in the browser
Import the ONNX.js library
Create an index.html file in the current directory and paste in the following code, adapted from the test code in the ONNX.js GitHub repository.
```html
<html>
  <head>
  </head>
  <body>
    <!-- Load ONNX.js -->
    <script src="https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js"></script>
    <!-- Code that consumes ONNX.js -->
    <script>
      // create a session
      const myOnnxSession = new onnx.InferenceSession();
      // load the ONNX model file
      myOnnxSession.loadModel("./onnx_model.onnx").then(() => {
        // generate model input: one 1x28x28 image, here all zeros
        const inferenceInputs = new onnx.Tensor(new Float32Array(28 * 28), 'float32', [1, 1, 28, 28]);
        // execute the model
        return myOnnxSession.run([inferenceInputs]).then((output) => {
          // consume the output: take the first (and only) output tensor
          const outputTensor = output.values().next().value;
          console.log(`model output tensor: ${outputTensor.data}.`);
        });
      }).catch((err) => console.error(err));
    </script>
  </body>
</html>
```
Let's start a small local server to test the result. Open a terminal in the current directory and run:
```
python -m http.server
>>> Serving HTTP on 0.0.0.0 port 8000 (http://0.0.0.0:8000/) ...
```
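The one-liner above is all the demo needs. If you would rather start the server from a Python script (for example, to automate a check), the sketch below does roughly the same thing with only the standard library. The helper name `serve_directory` and the use of a background thread are illustrative assumptions, not part of the original setup. Serving over HTTP instead of opening index.html directly matters because browsers generally refuse to fetch the model file from a file:// page.

```python
import functools
import http.server
import threading


def serve_directory(directory: str, port: int = 8000) -> http.server.ThreadingHTTPServer:
    """Serve `directory` over HTTP from a background thread (hypothetical helper).

    Roughly what `python -m http.server` does, but callable from a script.
    Pass port=0 to let the OS pick a free port.
    """
    # SimpleHTTPRequestHandler accepts a `directory` argument (Python 3.7+).
    handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)
    server = http.server.ThreadingHTTPServer(("0.0.0.0", port), handler)
    # Daemon thread so the server does not keep the interpreter alive on exit.
    worker = threading.Thread(target=server.serve_forever, daemon=True)
    worker.start()
    return server  # call server.shutdown() and server.server_close() when done
```

Binding to "0.0.0.0" mirrors the log line above and exposes the page to your local network; use "127.0.0.1" instead if it should stay on your machine.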
Check the result
After opening the address above, we see a blank white page. To check whether the model loaded successfully, open the browser console.
The console shows an error in red. Let's investigate where it comes from.
Error investigation
The error above indicates that the LogSoftmax operator is not supported by ONNX.js. The ONNX.js GitHub repository has an Operators Support section that lists exactly which operators are currently supported. Looking up LogSoftmax there confirms that the ONNX.js versions available at the time of writing do not support it.
So we have to go back to the original PyTorch model, adjust it, and replace the unsupported operator with a supported one.
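To catch this kind of problem without a round trip through the browser, you can list the operators used by an exported graph and compare them with the support table. A minimal sketch: loading assumes the `onnx` Python package is installed (`onnx.load` and `node.op_type` are part of its API), while the set `KNOWN_UNSUPPORTED` and both helper names are hypothetical. The real list should be taken from the ONNX.js Operators Support page, not from this snippet.

```python
from collections import Counter
from typing import Iterable, List, Set

# Hypothetical starting point: the one operator the browser error above reported.
# Extend it from the ONNX.js "Operators Support" table rather than from memory.
KNOWN_UNSUPPORTED: Set[str] = {"LogSoftmax"}


def op_types_in_model(path: str) -> List[str]:
    """List the op_type of every node in an exported ONNX graph (needs the `onnx` package)."""
    import onnx  # imported here so the rest of the sketch runs without onnx installed

    model = onnx.load(path)
    return [node.op_type for node in model.graph.node]


def find_unsupported(op_types: Iterable[str], unsupported: Set[str] = KNOWN_UNSUPPORTED) -> Counter:
    """Count how many times each known-unsupported operator occurs."""
    return Counter(op for op in op_types if op in unsupported)
```

Usage would look like `find_unsupported(op_types_in_model("model.onnx"))`, where "model.onnx" is a placeholder for your exported file; an empty Counter means none of the listed operators appear.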
Adjust the model
In the PyTorch model, the log_softmax operator is used at the end of the forward function:
1 2 | <span class="token keyword">return</span> F <span class="token punctuation">.</span> log_softmax <span class="token punctuation">(</span> x <span class="token punctuation">)</span> |
We replace it with softmax. For convenience, we declare a new network class so the change is easy to make and track. The new network is:
```python
import torch.nn as nn
import torch.nn.functional as F

# Inference Net
class InferenceNet(nn.Module):
    def __init__(self):
        super(InferenceNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # softmax instead of log_softmax, because ONNX.js does not support LogSoftmax;
        # dim=1 is the class dimension (passing it explicitly avoids PyTorch's implicit-dim warning)
        return F.softmax(x, dim=1)
```
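)</span> <span class="token punctuation">,</span>">
Why is this swap safe? log_softmax(x) equals log(softmax(x)), and log is strictly increasing, so both functions rank the ten classes in the same order and the predicted digit (the argmax) does not change. Only the scale differs: softmax returns probabilities that sum to 1, while log_softmax returns values that are all ≤ 0. The pure-Python check below illustrates this on made-up logits; it does not use PyTorch, and the numbers do not come from the trained model.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating to avoid overflow.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits: List[float]) -> List[float]:
    # log_softmax(x)_i = x_i - logsumexp(x), computed stably.
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def argmax(values: List[float]) -> int:
    return max(range(len(values)), key=values.__getitem__)


# Made-up logits for the 10 digit classes (illustration only).
logits = [0.3, -1.2, 4.1, 0.0, 2.2, -0.5, 1.7, 3.9, -2.0, 0.8]
probs = softmax(logits)
log_probs = log_softmax(logits)

print(argmax(probs), argmax(log_probs))  # 2 2
print(round(sum(probs), 6))              # 1.0
print(all(abs(math.exp(lp) - p) < 1e-12 for lp, p in zip(log_probs, probs)))  # True
```

Because the ranking is identical, any accuracy metric based on the argmax gives the same result with either function, which is what the re-test below is expected to show.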
Swapping log_softmax for softmax does not change the model's weights (neither function has learnable parameters), so we can load the old state_dict as-is:
```python
model = InferenceNet()
model.load_state_dict(torch.load('pytorch_model.pt'))
```
Then evaluate the model on the test set again:
```
test(model, test_loader)
>>> Accuracy of the network on the 10000 test images: 97 %
```
The accuracy has not changed, so there is no need to fine-tune the network again. Next, re-export the model to ONNX, replace the old model file used by the web page with the new one, and reload the page. Opening the console now shows the model's output:
This confirms that the model now loads and runs successfully in ONNX.js.
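The console line prints outputTensor.data, which for this network should be ten softmax scores. A quick sanity check on that dump is that it holds exactly 10 non-negative values summing to about 1 (raw LogSoftmax outputs would all be ≤ 0 and fail the check), with the predicted digit being the argmax. A minimal sketch: both helper names are hypothetical, the expected line format is inferred from the console.log template above, and the 1e-3 tolerance is an arbitrary allowance for float32 rounding.

```python
from typing import List, Tuple


def parse_console_output(line: str) -> List[float]:
    """Parse 'model output tensor: 0.1,0.2,....' into floats (hypothetical helper)."""
    _, _, numbers = line.partition(":")
    numbers = numbers.strip().rstrip(".")  # drop the trailing period from the log message
    return [float(tok) for tok in numbers.split(",")]


def check_softmax_output(values: List[float], tol: float = 1e-3) -> Tuple[int, float]:
    """Return (predicted digit, sum of scores); raise if the scores do not look like softmax."""
    if len(values) != 10:
        raise ValueError(f"expected 10 class scores, got {len(values)}")
    total = sum(values)
    if abs(total - 1.0) > tol or min(values) < 0:
        raise ValueError(f"outputs do not look like probabilities (sum={total:.4f})")
    predicted = max(range(10), key=values.__getitem__)
    return predicted, total
```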
Build the demo interface
We won't go into much detail in this part: just copy the code and run it. Add the following to index.html:
```html
<div class="card elevation">
  <h3>MNIST ONNX.js</h3>
  <canvas class="canvas elevation" id="canvas" width="280" height="280"></canvas>
  <div class="button" id="clear-button">CLEAR</div>
  <div class="predictions">
    <div class="prediction-col" id="prediction-0">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">0</div>
    </div>
    <div class="prediction-col" id="prediction-1">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">1</div>
    </div>
    <div class="prediction-col" id="prediction-2">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">2</div>
    </div>
    <div class="prediction-col" id="prediction-3">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">3</div>
    </div>
    <div class="prediction-col" id="prediction-4">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">4</div>
    </div>
    <div class="prediction-col" id="prediction-5">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">5</div>
    </div>
    <div class="prediction-col" id="prediction-6">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">6</div>
    </div>
    <div class="prediction-col" id="prediction-7">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">7</div>
    </div>
    <div class="prediction-col" id="prediction-8">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">8</div>
    </div>
    <div class="prediction-col" id="prediction-9">
      <div class="prediction-bar-container">
        <div class="prediction-bar"></div>
      </div>
      <div class="prediction-number">9</div>
    </div>
  </div>
</div>
```
Don't forget the CSS styling either. For convenience, you can put it in a separate CSS file and link it from the page's head:
```html
<head>
  <link rel="stylesheet" href="style.css" />
</head>
```
You can find the contents of style.css in the article's source code.
Reloading the page, we get the following interface.
Next, let's walk through the code of the main handlers.
Main processing code
```html
<script>
  const CANVAS_SIZE = 280;
  const CANVAS_SCALE = 0.5;

  const canvas = document.getElementById("canvas");
  const ctx = canvas.getContext("2d");
  const clearButton = document.getElementById("clear-button");

  // State used by the canvas drawing handlers (not covered in this article)
  let isMouseDown = false;
  let hasIntroText = true;
  let lastX = 0;
  let lastY = 0;

  // Load our model.
  const sess = new onnx.InferenceSession();
  const loadingModelPromise = sess.loadModel("./onnx_model.onnx");

  async function updatePredictions() {
    // Get the predictions for the canvas data.
    const imgData = ctx.getImageData(0, 0, CANVAS_SIZE, CANVAS_SIZE);
    // imgData.data is a flat RGBA array of CANVAS_SIZE * CANVAS_SIZE * 4 values
    const input = new onnx.Tensor(new Float32Array(imgData.data), "float32");

    const outputMap = await sess.run([input]);
    const outputTensor = outputMap.values().next().value;
    const predictions = outputTensor.data;
    const maxPrediction = Math.max(...predictions);

    // Update each probability bar and highlight the most likely digit
    for (let i = 0; i < predictions.length; i++) {
      const element = document.getElementById(`prediction-${i}`);
      element.children[0].children[0].style.height = `${predictions[i] * 100}%`;
      element.className =
        predictions[i] === maxPrediction
          ? "prediction-col top-prediction"
          : "prediction-col";
    }
  }

  // Only enable drawing once the model has finished loading.
  // canvasMouseDown, canvasMouseMove, bodyMouseUp, bodyMouseOut and
  // clearCanvas are the drawing handlers, defined elsewhere in the page.
  loadingModelPromise.then(() => {
    canvas.addEventListener("mousedown", canvasMouseDown);
    canvas.addEventListener("mousemove", canvasMouseMove);
    document.body.addEventListener("mouseup", bodyMouseUp);
    document.body.addEventListener("mouseout", bodyMouseOut);
    clearButton.addEventListener("mousedown", clearCanvas);

    ctx.clearRect(0, 0, CANVAS_SIZE, CANVAS_SIZE);
    ctx.fillText("Draw a number here!", CANVAS_SIZE / 2, CANVAS_SIZE / 2);
  });
</script>
```
We won't discuss the canvas drawing handlers here. Instead, we will focus on the updatePredictions()
function, which is called every time a new stroke is drawn on the canvas. It takes the image currently on the canvas, feeds it into the model for prediction, and then updates the result bars in the view. Let's try running the model.
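Before running it, it may help to see exactly what the loop at the end of updatePredictions() computes. The Python function below is only an illustrative mirror of that JavaScript loop. The name `bar_updates` and the sample probabilities are hypothetical. Each softmax probability becomes a bar height in percent, and every column whose value equals the maximum gets the `top-prediction` class.

```python
def bar_updates(predictions):
    """Hypothetical Python mirror of the loop in updatePredictions().

    Returns one (height_percent, class_name) pair per digit. The JavaScript
    version writes height_percent into the bar's CSS height as `${value}%`.
    """
    max_prediction = max(predictions)
    updates = []
    for p in predictions:
        class_name = ("prediction-col top-prediction"
                      if p == max_prediction else "prediction-col")
        updates.append((p * 100, class_name))
    return updates


# Made-up softmax output for digits 0-9, for illustration only.
probs = [0.01, 0.02, 0.05, 0.02, 0.01, 0.01, 0.01, 0.85, 0.01, 0.01]
for digit, (height, class_name) in enumerate(bar_updates(probs)):
    print(digit, f"{height:.0f}%", class_name)
```

The comparison uses exact equality in both versions, so a tie would highlight several columns at once.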
An error occurs because the canvas data does not match the input of the model we exported. We therefore need to adjust the model in the PyTorch code and export it again.
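The mismatch is easy to quantify. `getImageData()` returns the canvas pixels row by row, with four values (R, G, B, A) per pixel. The flat array therefore has 280 × 280 × 4 entries, while a single-channel 28 × 28 image has only 784. The sketch below assumes that the original export expected a `(1, 1, 28, 28)` tensor, as the description of a single 28 × 28 channel suggests. It computes both sizes and shows where a given pixel channel lands in the flat array. The helper names are hypothetical.

```python
CANVAS_SIZE = 280  # canvas width and height in pixels, as in the page script
CHANNELS = 4       # getImageData() stores R, G, B, A for every pixel


def canvas_array_length(size=CANVAS_SIZE):
    """Number of values in the flat array that updatePredictions() sends."""
    return size * size * CHANNELS


def flat_index(x, y, channel, width=CANVAS_SIZE):
    """Position of one pixel channel in the row-major RGBA array."""
    return (y * width + x) * CHANNELS + channel


# Assumed input of the originally exported model: one 28x28 grayscale image.
model_input_length = 1 * 1 * 28 * 28

print(canvas_array_length(), model_input_length)  # prints: 313600 784
print(flat_index(0, 1, 3))  # alpha of the first pixel in row 2; prints: 1123
```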
Adjust the model again
Notice that the data coming from the canvas is a 280 × 280 image with 4 (RGBA) channels, passed to the model as one flat array, whereas the model expects a single-channel 28 × 28 image. We therefore modify the forward function of InferenceNet as follows:
```python
import torch
import torch.nn.functional as F

# Method of InferenceNet. MEAN and STANDARD_DEVIATION are the normalization
# constants defined elsewhere in the script.
def forward(self, x):
    # x arrives as the flat RGBA array from the canvas (280 * 280 * 4 values)
    x = x.reshape(280, 280, 4)
    # Keep only the alpha channel (index 3), where the drawn strokes appear
    x = torch.narrow(x, dim=2, start=3, length=1)
    x = x.reshape(1, 1, 280, 280)
    # Downsample 280x280 -> 28x28 by averaging non-overlapping 10x10 blocks
    x = F.avg_pool2d(x, 10, stride=10)
    # Scale from [0, 255] to [0, 1], then normalize as during training
    x = x / 255
    x = (x - MEAN) / STANDARD_DEVIATION
    # The original network layers, unchanged
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    # Probabilities over the 10 digits (explicit dim avoids the deprecation warning)
    return F.softmax(x, dim=1)
```
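To make the new preprocessing concrete, here is a dependency-free Python sketch of what the first few lines of forward do to the flat canvas array. It picks the alpha value of every pixel, averages each 10 × 10 block down to a 28 × 28 grid, scales the result to [0, 1] and normalizes it. The values 0.1307 and 0.3081 for MEAN and STANDARD_DEVIATION are an assumption (the commonly used MNIST statistics). Substitute whatever constants your training script defines. The function name `preprocess` is hypothetical.

```python
# Assumed normalization constants (common MNIST values); use your own.
MEAN = 0.1307
STANDARD_DEVIATION = 0.3081

SIZE = 280   # canvas side length in pixels
POOL = 10    # avg_pool2d kernel size and stride: 280 / 10 = 28
ALPHA = 3    # index of the alpha channel inside each RGBA pixel


def preprocess(flat_rgba):
    """Plain-Python equivalent of the first lines of forward():
    reshape -> keep alpha -> 10x10 average pool -> scale -> normalize."""
    if len(flat_rgba) != SIZE * SIZE * 4:
        raise ValueError("expected a flat 280 * 280 * 4 RGBA array")
    out_size = SIZE // POOL
    image = []
    for row in range(out_size):
        out_row = []
        for col in range(out_size):
            total = 0.0
            for dy in range(POOL):
                for dx in range(POOL):
                    y, x = row * POOL + dy, col * POOL + dx
                    total += flat_rgba[(y * SIZE + x) * 4 + ALPHA]
            mean_alpha = total / (POOL * POOL)
            out_row.append((mean_alpha / 255 - MEAN) / STANDARD_DEVIATION)
        image.append(out_row)
    return image  # 28 x 28 list of lists


blank = [0] * (SIZE * SIZE * 4)   # an empty, fully transparent canvas
img = preprocess(blank)
print(len(img), len(img[0]))      # prints: 28 28
```

On an empty canvas every cell ends up at -MEAN / STANDARD_DEVIATION, and drawn strokes raise the cells they cover.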
Then we export the model again, this time with a dummy input that matches the new input: a flat tensor of 280 × 280 × 4 values.
```python
import torch

model = InferenceNet()
model.load_state_dict(torch.load('pytorch_model.pt'))
model.eval()  # inference mode: disables dropout

# Same shape as the browser input: a flat array of 280 * 280 * 4 RGBA values
dummy_input = torch.zeros(280 * 280 * 4)
torch.onnx.export(model, dummy_input, 'onnx_model.onnx', verbose=True)
```
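Before exporting, you can sanity-check the new forward by following the dummy input's shape through it by hand. The sketch below assumes the layers follow the classic PyTorch MNIST example: conv1 with 10 output channels, conv2 with 20, both with kernel size 5, and 10 output classes. These settings are our assumption, but they are consistent with the 320 in `x.view(-1, 320)`. The helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1):
    """Output side length of an unpadded convolution or pooling window."""
    return (size - kernel) // stride + 1


def trace_forward_shapes(conv1_channels=10, conv2_channels=20, kernel=5,
                         num_classes=10):
    """Shape after each step of forward() for the flat 280*280*4 dummy input.
    Channel counts and kernel size are assumed, not taken from the article."""
    shapes = [(280 * 280 * 4,), (280, 280, 4), (280, 280, 1), (1, 1, 280, 280)]
    side = conv_out(280, 10, stride=10)                   # avg_pool2d -> 28
    shapes.append((1, 1, side, side))
    side = conv_out(conv_out(side, kernel), 2, stride=2)  # conv1 + max_pool2d
    shapes.append((1, conv1_channels, side, side))
    side = conv_out(conv_out(side, kernel), 2, stride=2)  # conv2 + max_pool2d
    shapes.append((1, conv2_channels, side, side))
    shapes.append((1, conv2_channels * side * side))      # x.view(-1, 320)
    shapes.append((1, num_classes))                       # fc1, fc2, softmax
    return shapes


for shape in trace_forward_shapes():
    print(shape)
```

Under these assumptions the last two shapes are (1, 320) and (1, 10), which matches the view in forward and the ten prediction columns on the page.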
Demo run
With the new model in place, refresh the page and the error no longer appears. You can try the following demo.
Source code
The source code of the article is provided here
Conclusion
In this article, we converted a PyTorch model to ONNX and ran it in the browser with ONNX.js. I hope it helps you better understand how this works, as well as the pros and cons of deploying models on the client side. See you in the next posts.