Multitask learning: training on multiple combined datasets with TensorFlow

Tram Ho

1. Preface

Multitask models perform several tasks at the same time and are widely used in computer vision. For example, face analysis can predict age, emotion and gender from the same image, and a flower model might predict both the type of flower and how many years it has been growing. However, multitask models usually need a training set in which every sample carries all of the required labels, and such a dataset is often hard to find. So in this article I introduce a CNN Shared Network model, built on the TensorFlow framework, that helps solve this data shortage problem.

2. How does the CNN Shared Network model work?

We first build a training set composed of several datasets, chosen according to the tasks we want to learn. The combined dataset is fed into a shared CNN, which then splits into separate branches, one per task. The number of branches equals the number of outputs the model should produce.
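A minimal sketch of the combining step in plain Python: each single-task dataset is tagged with the task index introduced in section 3.1, and all samples are shuffled together so that one mini-batch can mix tasks. The file names, labels and the helper names `combine_datasets` and `batches` are hypothetical; the original project loads real images instead.

```python
import random

SMILE, AGE, GENDER = 1, 3, 4  # task indices from section 3.1


def combine_datasets(datasets, seed=0):
    """Merge single-task datasets into one shuffled list.

    `datasets` maps a task index to a list of (image, label) pairs;
    every output sample is an (image, label, task_index) triple.
    """
    combined = [
        (image, label, task_index)
        for task_index, samples in datasets.items()
        for image, label in samples
    ]
    random.Random(seed).shuffle(combined)  # mix tasks across batches
    return combined


def batches(samples, batch_size):
    """Yield consecutive mini-batches from the combined list."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]


# Hypothetical stand-ins for GENKI-4K (smile) and IMDB-WIKI (age, gender).
smile_data = [("genki_0.jpg", 1), ("genki_1.jpg", 0)]
age_data = [("imdb_0.jpg", 2), ("imdb_1.jpg", 4)]
gender_data = [("imdb_0.jpg", 1)]

train_set = combine_datasets({SMILE: smile_data, AGE: age_data, GENDER: gender_data})
for batch in batches(train_set, batch_size=2):
    print([task_index for _, _, task_index in batch])
```

Because every sample keeps its task index, the loss function can later tell which branch each sample should train.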

The advantage of the CNN Shared Network model is that the shared network learns low-level features from many different datasets, which can improve accuracy, especially for tasks with limited data. The same model can also be used for both classification and regression.

3. Building a multitask learning model to predict age, gender and smile

To illustrate the effectiveness of the CNN Shared Network model, I built a demo that predicts age, gender and smile using BKNET as the backbone network. For more information about BKNET, see the BKNET paper.

3.1. Load data

Here I use two main datasets: the IMDB-WIKI age and gender dataset and the GENKI-4K smile dataset. The data-loading code is available in Multitask Age-Gender-Smile. One detail is essential for letting the model receive data combined from different datasets: each sample is tagged with an index that identifies its label type, so the samples can be told apart during processing. Index = 1 for smile, index = 3 for age and index = 4 for gender. The data is normalized and the labels are stored as one-hot vectors. Note: every one-hot vector has the length of the largest number of classes any label can have. For example, here the age task has a maximum of 7 classes, so we use 7.
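The tagging and padding described above can be sketched as follows. This assumes a sample's class occupies the first positions of the padded vector and that the vector width is the 7 mentioned in the text; `encode_label` is a hypothetical helper, not a function from the original code.

```python
NUM_SLOTS = 7  # width of every one-hot label vector (the maximum used in the text)
SMILE, AGE, GENDER = 1, 3, 4  # task indices from this section


def encode_label(class_id, task_index, num_slots=NUM_SLOTS):
    """Return (one_hot, task_index), padding the one-hot vector to num_slots."""
    if not 0 <= class_id < num_slots:
        raise ValueError(f"class_id {class_id} does not fit in {num_slots} slots")
    one_hot = [0.0] * num_slots
    one_hot[class_id] = 1.0
    return one_hot, task_index


# A smile label (class 1) and an age label (class 3) get the same width.
smile_vec, smile_idx = encode_label(1, SMILE)
age_vec, age_idx = encode_label(3, AGE)
print(smile_vec, smile_idx)  # [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0] 1
```

Giving every label the same width lets samples from different datasets be stacked into one batch tensor.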

3.2. Model

In this demo I train the BKNET model. The input first passes through a shared network of 4 VGG blocks (VGG_BLOCK), which then splits into 3 branches corresponding to the 3 tasks: a smile branch, a gender branch and an age branch. Each branch ends with a softmax classifier that performs multiclass classification for its label based on the extracted features.
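To make the shared-trunk-plus-branches structure concrete, here is a small NumPy forward pass. It is only a sketch under stated assumptions: a single dense ReLU layer stands in for the 4 shared VGG blocks, the 48x48 input size and feature width are hypothetical, and the weights are random rather than trained. What it does show is the shape of the design: the shared features are computed once per batch, and every branch applies its own softmax classifier to them.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Hypothetical sizes; class counts per branch follow section 3.3.
INPUT_DIM, FEATURE_DIM = 48 * 48, 64
HEAD_CLASSES = {"smile": 2, "gender": 2, "age": 5}

# One dense ReLU layer stands in for BKNET's 4 shared VGG blocks.
W_shared = rng.normal(scale=0.01, size=(INPUT_DIM, FEATURE_DIM))
# Each branch has its own classifier weights on top of the shared features.
W_heads = {name: rng.normal(scale=0.01, size=(FEATURE_DIM, n))
           for name, n in HEAD_CLASSES.items()}


def forward(images):
    """Run the shared trunk once, then every task branch on its output."""
    features = np.maximum(images.reshape(len(images), -1) @ W_shared, 0.0)
    return {name: softmax(features @ W) for name, W in W_heads.items()}


probs = forward(rng.random((4, 48, 48)))
for name, p in probs.items():
    print(name, p.shape)  # branch name and (batch, classes)
```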

3.3. Loss function

To train on data combined from many different datasets, the way the loss function handles the data is crucial.

First, we build three masks from the index passed with each sample, as described in section 3.1. The masks tell which task each sample in a batch belongs to.
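The masks can be built by comparing each sample's index with the task indices from section 3.1. Below is a NumPy sketch (`task_masks` is a hypothetical name); in TensorFlow the same comparison could be written with `tf.equal` followed by `tf.cast` to float.

```python
import numpy as np

SMILE, AGE, GENDER = 1, 3, 4  # task indices from section 3.1


def task_masks(indices):
    """Return one float mask per task: 1.0 where the sample belongs to it."""
    indices = np.asarray(indices)
    return {
        "smile": (indices == SMILE).astype(np.float32),
        "age": (indices == AGE).astype(np.float32),
        "gender": (indices == GENDER).astype(np.float32),
    }


masks = task_masks([1, 3, 4, 1, 4])
print(masks["smile"])  # [1. 0. 0. 1. 0.]
```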

Next, from each label vector we take only as many entries as that task has classes. Smile has 2 classes (smile, not smile); age has 5 classes covering ages 1-13, 14-23, 24-39, 40-55 and 56-80; and gender has 2 classes (male, female).
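Here is a small sketch of both steps: mapping a raw age to one of the five bins listed above, and cutting a padded label vector down to a task's class count. It assumes the class entries sit at the start of the padded vector; the helper names are hypothetical.

```python
# Classes per task and the age bins listed in this section.
NUM_CLASSES = {"smile": 2, "age": 5, "gender": 2}
AGE_BINS = [(1, 13), (14, 23), (24, 39), (40, 55), (56, 80)]


def age_to_class(age):
    """Map an integer age in years to its class index among the five bins."""
    for class_id, (low, high) in enumerate(AGE_BINS):
        if low <= age <= high:
            return class_id
    raise ValueError(f"age {age} is outside 1-80")


def task_label(padded_one_hot, task):
    """Keep only the first NUM_CLASSES[task] entries of a padded label vector."""
    return padded_one_hot[:NUM_CLASSES[task]]


print(age_to_class(30))                            # 2
print(task_label([0, 1, 0, 0, 0, 0, 0], "smile"))  # [0, 1]
```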

When computing the correct predictions for each task (smile_true_pred, age_true_pred, gender_true_pred), we multiply by the mask: a single batch can mix smile, age and gender samples, so multiplying by the mask keeps only the predictions that belong to each task. Since the model uses softmax activations, the loss for each task is cross-entropy. Note: tf.clip_by_value keeps the predicted probabilities away from zero, so the log function never produces infinite or extremely large errors.
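A NumPy sketch of the masked computations, assuming predictions are softmax probabilities and labels are one-hot. Here `np.clip` plays the role of `tf.clip_by_value`, and the bound `EPS` is a hypothetical choice. This version averages the loss over the samples that belong to the task; the original code may normalise differently.

```python
import numpy as np

EPS = 1e-10  # hypothetical lower bound for probabilities, like tf.clip_by_value


def masked_cross_entropy(y_true, y_pred, mask):
    """Mean cross-entropy over the samples whose mask is 1.

    y_true, y_pred: arrays of shape (batch, num_classes); mask: shape (batch,).
    """
    y_pred = np.clip(y_pred, EPS, 1.0)             # keep log() finite
    per_sample = -np.sum(y_true * np.log(y_pred), axis=1)
    count = max(float(mask.sum()), 1.0)            # avoid 0/0 if the task is absent
    return float(np.sum(per_sample * mask) / count)


def masked_correct(y_true, y_pred, mask):
    """Number of correct argmax predictions among the task's own samples."""
    hits = np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1)
    return float(np.sum(hits * mask))


# Toy smile batch: the second sample comes from another task (mask 0).
y_true = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
y_pred = np.array([[0.2, 0.8], [0.0, 1.0], [0.5, 0.5]])
mask = np.array([1.0, 0.0, 1.0])

print(masked_cross_entropy(y_true, y_pred, mask))  # (-log 0.8 - log 0.5) / 2
print(masked_correct(y_true, y_pred, mask))        # 1.0
```

In the toy batch, the masked-out sample gives probability 0 to its labelled class. Without clipping, its per-sample loss would be infinite, and multiplying infinity by a mask value of 0 gives NaN, which would poison the whole sum.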

Finally, the total loss is the sum of the losses of the individual tasks, with each task given the same weight. Keeping a separate loss per task like this lets us use different kinds of loss functions in one model without them affecting each other. Because we use an L2 regularizer, l2_loss is also added to the total loss.

self.total_loss = self.smile_cross_entropy + self.gender_cross_entropy + self.l2_loss + self.age_cross_entropy
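The line above can be mirrored in a short NumPy sketch. It assumes the L2 term follows the `tf.nn.l2_loss` convention (half the sum of squared weights) scaled by a weight-decay factor; the factor 1e-4 and the function names are hypothetical.

```python
import numpy as np


def l2_loss(weights, weight_decay):
    """weight_decay * sum over arrays of sum(w**2) / 2 (the tf.nn.l2_loss convention)."""
    return weight_decay * sum(float(np.sum(w ** 2)) / 2.0 for w in weights)


def total_loss(smile_ce, gender_ce, age_ce, weights, weight_decay=1e-4):
    """Unweighted sum of the per-task losses plus the L2 penalty."""
    return smile_ce + gender_ce + age_ce + l2_loss(weights, weight_decay)


weights = [np.ones((2, 2)), np.full(3, 2.0)]  # toy weight arrays
print(total_loss(0.5, 0.25, 1.0, weights))
```

Because the task losses are simply added, each task contributes with equal weight; per-task coefficients could be introduced if one task's loss dominates the others.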

4. Result

You can find the full data processing, model training and prediction code, along with the model's accuracy, in Multitask learning Age-Gender-Smile. Here are some results I obtained:

I hope this article helps you with the problem of limited data and with implementing multitask models. Thanks to everyone for taking the time to read this post.

References


Source: Viblo