PyTorch Hooks – A tool that you should not ignore when working with PyTorch

Tram Ho

Foreword

PyTorch is one of the most powerful frameworks for deep learning tasks. It is easy to understand and explicit enough for beginners to get started quickly, yet flexible enough to extend and customize for complex models or for research on architectures that have never been seen before. PyTorch is also powerful because of its rich ecosystem of libraries, covering everything from image, audio, and natural language data to digital signal processing, and because of its very large development and sharing community. In this article, I would like to share with you PyTorch Hooks – a very effective tool for programming and debugging with PyTorch. We will talk about what this tool does and its different uses.

Misery when debugging models

If you’ve worked with deep learning models before, you’ll know that debugging them has never been easy. Many things can go wrong: exploding or vanishing gradients, tensor shape mismatches, and plenty of other issues that arise while building and training models. Solving them sometimes forces you to grope through each layer, each node, each feature to really see what the model is doing inside. For those of you who have used static-graph frameworks like TensorFlow 1, debugging through a session is difficult and painful. I believe you’ve run into this kind of debugging more than once.

All calculations must go through the session, so debugging becomes very difficult.

That’s why working with TensorFlow 2 or PyTorch is a much happier experience. Debugging follows almost the same patterns as a regular Python program. In PyTorch, we can print the value of a variable at each step of the forward function with an ordinary Python print statement. However, even with small models, doing this quickly makes your code confusing and bloated. You’ve probably written a PyTorch program like this before.
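The original snippet is not shown; the sketch below is a hypothetical toy model (the class name `TinyNet` and its layers are my own) illustrating the print-statement style of debugging described above.

```python
import torch
import torch.nn as nn

# Hypothetical toy model: debug prints scattered through forward()
# quickly clutter the code, which is the problem hooks solve.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(8 * 16 * 16, 10)

    def forward(self, x):
        print("input:", x.shape)       # debug print
        x = self.conv(x)
        print("after conv:", x.shape)  # debug print
        x = self.pool(x)
        print("after pool:", x.shape)  # debug print
        x = x.flatten(1)
        return self.fc(x)

out = TinyNet()(torch.randn(1, 3, 32, 32))
```

Every new layer you want to inspect means another print statement inside the model's own code.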

Debugging this way on larger models like ResNet-50 is practically impossible. It requires you to guess, like a prophet, where the model is going wrong – which layer is working well and which is not. That’s when you need PyTorch Hooks.

What are PyTorch hooks?

Put simply, hooks are a way to listen for events. Similar to webhooks, in deep learning we have model hooks that listen for events happening inside the model. When a trigger method (such as forward) runs, it sends its inputs and outputs to the registered hooks. The nice thing is that we can record the value of a variable before the forward pass, after the forward pass, and after the backward pass has computed the gradients used by the Gradient Descent algorithm. It sounds complicated, but using it is actually very simple. We will start using it shortly.
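To make this concrete, here is a minimal, hypothetical illustration (not from the original article): a forward hook registered on a single layer is called with `(module, input, output)` every time that layer's forward runs.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
seen = []

# A forward hook receives the module, its inputs, and its output.
def hook(module, inputs, output):
    seen.append(output.shape)

handle = layer.register_forward_hook(hook)
_ = layer(torch.randn(3, 4))  # the forward pass triggers the hook
handle.remove()               # stop listening when done
```

The layer's own code never mentions the hook; all the "listening" happens from the outside.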

Example: saving the output after each convolution layer

Suppose you want to save the feature maps produced as an image passes through each convolution layer. This can be done with visualization tools, but we can also do it with hooks.

First, we define the model.
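The article's model code is not shown; as a stand-in, here is a minimal sketch of a small CNN with two convolution layers whose feature maps we can later capture with hooks (the class name `ConvNet` and the layer sizes are assumptions, not the article's exact model).

```python
import torch
import torch.nn as nn

# Small hypothetical CNN: two conv layers followed by a classifier head.
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = ConvNet()
logits = model(torch.randn(2, 3, 64, 64))  # sanity-check the forward pass
```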

Next, we define a hook whose purpose is to save the outputs of the layers. Defining one is fairly simple, just like a regular function: when the trigger function (for example, forward) produces a value, the hook is called to save it.

After defining it, we instantiate an instance of this hook.
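Since the hook is instantiated, it was presumably written as a callable class; the sketch below is my reconstruction (the names `SaveOutput` and `save_output` are assumptions). A forward hook is called with `(module, input, output)`, and here it simply stores each output.

```python
import torch
import torch.nn as nn

# Hook written as a callable class so it can carry state (the saved outputs).
class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out.detach())

    def clear(self):
        self.outputs = []

# Instantiate the hook.
save_output = SaveOutput()

# It behaves like a regular function: calling it records an output.
conv = nn.Conv2d(3, 4, kernel_size=3)
x = torch.randn(1, 3, 8, 8)
y = conv(x)
save_output(conv, (x,), y)
```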

Our goal is to see the outputs of these conv layers, so the hook is registered with the register_forward_hook(hook) function to listen for changes and receive events during the forward pass. There are also other registration functions, such as register_backward_hook and register_forward_pre_hook, for processing during the backward pass and before the forward pass, respectively. Next, we register the hook on each conv layer and save the returned handles to a list, hook_handle.
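A self-contained sketch of this registration step, assuming the callable-class hook described above (the variable names `hook_output` and `hook_handle` follow the text, but the model is a stand-in):

```python
import torch
import torch.nn as nn

# Same hook shape as before: stores every output it is called with.
class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out.detach())

# Stand-in model with two conv layers.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)

hook_output = SaveOutput()
hook_handle = []
# Register the hook on every conv layer and keep the handles.
for layer in model.modules():
    if isinstance(layer, nn.Conv2d):
        handle = layer.register_forward_hook(hook_output)
        hook_handle.append(handle)

# One forward pass fires the hook once per registered conv layer.
_ = model(torch.randn(1, 3, 32, 32))

# Later, the handles let us detach the hooks again:
# for h in hook_handle: h.remove()
```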

Next, we load the dog photo to use as sample data and apply transforms to it.

After this step, the hook values are stored in our hook_output array.

And finally, we visualize our array of images.
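A sketch of this visualization step, capturing the first conv layer's feature maps with a forward hook and plotting each channel in a matplotlib grid (the layer sizes and file name are illustrative, not the article's exact code):

```python
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")  # headless backend so the script runs anywhere
import matplotlib.pyplot as plt

# Capture the feature maps of one conv layer via a forward hook.
outputs = []
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv.register_forward_hook(lambda m, i, o: outputs.append(o.detach()))
_ = conv(torch.randn(1, 3, 64, 64))

# Plot each of the 8 channels as a grayscale image.
feature_maps = outputs[0][0]  # shape: (channels, H, W)
fig, axes = plt.subplots(2, 4, figsize=(8, 4))
for ax, fmap in zip(axes.flat, feature_maps):
    ax.imshow(fmap.numpy(), cmap="gray")
    ax.axis("off")
fig.savefig("feature_maps.png")
```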

These are the feature maps after the first conv layer. You can see how each kernel acts on the input image to form a different feature map:

Looking into the layers makes it easy to see what the kernels in this convolution layer have actually done and what they produce.

Conclusion

PyTorch Hooks is a very useful tool for debugging AI models when using PyTorch. Combined with other visualization tools like TensorBoard, it helps us better understand how a model works. I wish you happy coding, and see you in the following articles.


Source: Viblo