Introduction
Hello everyone, today I'm going to talk a little about pose classification. As you know, detecting body movement, or locating keypoints on the human body, is an important problem in the ML industry because its applications are quite diverse: detecting movement in supermarkets, guiding physical therapy exercises in healthcare, supporting personal trainers in the gym, and so on.
I am writing this article to summarize what I have learned; if there are shortcomings, please bear with me.
This article revolves around Google's MediaPipe library; if there is time, I will cover other libraries later.
Pose Detection and Pose Tracking
Before going into the pose classification section, let's take a look at how Google detects keypoints on the human body.
Their solution is based on a paper they published, keyword: BlazePose ( https://export.arxiv.org/pdf/2006.10204 ). With this solution, they extract 33 points corresponding to 33 parts of the human body, or 25 points corresponding to the upper body, in three-dimensional space (x, y, z) from an RGB video.
To explain BlazePose briefly: it is essentially an improvement on the Stacked Hourglass network ( https://export.arxiv.org/pdf/1603.06937 ).
Stacked Hourglass Network
Structure of the Stacked Hourglass network:
The idea of this network is that, instead of one very large encoder-decoder, several Hourglass modules are stacked, and each one is responsible for returning a heat-map that predicts the parts of the body. Because the modules are stacked, each Hourglass can learn from the results of the one before it.
How do you detect body keypoints via a heat-map? Unlike data on human faces (72 landmark keypoints, …), data on body poses is far more diverse, so it is difficult to predict the points on the body directly as coordinates. Researchers instead use a heat-map to represent a region of the image. The heat-map retains information about that region, and our job is to find the peak (the brightest point) in it. For example, for a 256×256 image the heat-map could have a size of 64×64. See the picture for easy understanding:
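To make the "find the peak" step concrete, here is a minimal sketch of how a keypoint could be read out of a heat-map with NumPy. The function name `heatmap_peak_to_image_xy` and the toy heat-map are my own assumptions for illustration; real models usually refine the peak further instead of taking the raw argmax.

```python
import numpy as np


def heatmap_peak_to_image_xy(heatmap, image_size):
    """Return the (x, y) pixel position of the brightest heat-map cell.

    Hypothetical helper: `image_size` is (height, width) of the input image.
    """
    # Index of the maximum value, converted back to (row, col).
    row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    # Scale heat-map cells up to image pixels (64 -> 256 is a factor of 4).
    scale_y = image_size[0] / heatmap.shape[0]
    scale_x = image_size[1] / heatmap.shape[1]
    return float(col * scale_x), float(row * scale_y)


# Toy 64x64 heat-map with a single bright spot at row 10, column 20.
heatmap = np.zeros((64, 64))
heatmap[10, 20] = 1.0
peak_xy = heatmap_peak_to_image_xy(heatmap, image_size=(256, 256))
print(peak_xy)
```

With a 64×64 heat-map and a 256×256 image, every heat-map cell covers a 4×4 patch of pixels, so the peak position is only as precise as that patch.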
As the authors of the paper mention, the loss is calculated on every prediction, which lets them supervise not only the final output but also the output of each Hourglass. For example, while the body is moving, some parts will be hidden from the camera, and it is difficult to tell whether an arm is facing left or right. By taking the position predictions of the previous module as input, the model can attend to these positions and predict new positions at the same time.
Hourglass Module
Now let's take a look at the structure of an Hourglass module.
As you can see from the picture, this is an encoder-decoder architecture that first downsamples the features and then upsamples them to recover the information and convert it to a heat-map. Each layer of the encoder is connected to the corresponding layer of the decoder. Each layer is built on the residual block and bottleneck architecture of ResNet; if you are not familiar with residual blocks, please read this link ( https://towardsdatascience.com/residual-blocks-building-blocks-of-resnet-fd90ca15d6ec ).
Left: Residual Layer. Right: Bottleneck Block
The bottleneck design reduces the amount of computation, which also saves memory.
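To get a feeling for the saving, here is a small back-of-the-envelope calculation. The channel sizes (256 outside the block, 64 inside the bottleneck) are an assumption borrowed from the usual ResNet description, not numbers from this article, and `conv_weights` is a hypothetical helper that ignores biases and BatchNorm.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    """Number of weights in a square 2D convolution (biases ignored)."""
    return in_channels * out_channels * kernel_size * kernel_size


# Assumed channel sizes: 256 outside the block, 64 inside the bottleneck.
plain_block = 2 * conv_weights(256, 256, 3)  # two 3x3 convolutions
bottleneck_block = (
    conv_weights(256, 64, 1)    # 1x1 conv reduces the channels
    + conv_weights(64, 64, 3)   # 3x3 conv works on the reduced channels
    + conv_weights(64, 256, 1)  # 1x1 conv restores the channels
)
print(plain_block, bottleneck_block)
```

Under these assumptions the bottleneck block needs roughly 17 times fewer weights than two plain 3×3 convolutions at full width.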
Now let's zoom in on one box in the image above.
Each box in the picture is the bottleneck block I mentioned. After each bottleneck there is a pooling layer to remove unnecessary features.
However, the first layer is a bit different: it uses a 7×7 convolution, not 3×3.
As shown above, in the first layer the input goes through a 7×7 convolution, BatchNorm, and ReLU activation, and then continues through the bottleneck layer. The output of the bottleneck layer then passes through 2 parallel branches: one goes to the MaxPooling layer for further feature extraction, and the other connects to the corresponding layer in the decoder.
Boxes 2, 3, and 4 share the same structure, which differs from that of the first box.
The ultimate goal of feature extraction is to create feature maps that contain the most image-specific information but minimal spatial information. This part corresponds to the 3 small boxes located between the encoder and the decoder.
After going through the 4 layers of the encoder and the bottom layer, the input becomes feature maps that are ready to go through the decoder.
As I mentioned earlier, the other branch goes through a bottleneck layer and is added element-wise to the output of the upsampling layer of the main branch. This is repeated about 4 times until the end.
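The downsample, upsample, and add pattern described above can be sketched in a few lines of NumPy. This is only a shape-level illustration under strong assumptions: the learned bottleneck convolutions are replaced by the identity, a single 2D feature map is used instead of many channels, and all function names are hypothetical.

```python
import numpy as np


def max_pool_2x2(x):
    """Halve height and width by taking the max of each 2x2 patch."""
    h, w = x.shape
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


def upsample_2x(x):
    """Double height and width with nearest-neighbour upsampling."""
    return x.repeat(2, axis=0).repeat(2, axis=1)


def hourglass(x, depth=4):
    """Recursive hourglass skeleton; identity stands in for the bottlenecks."""
    skip = x                      # branch kept at the current resolution
    down = max_pool_2x2(x)        # main branch goes to a lower resolution
    inner = hourglass(down, depth - 1) if depth > 1 else down
    # Element-wise addition of the skip branch and the upsampled main branch.
    return skip + upsample_2x(inner)


features = np.ones((64, 64))
output = hourglass(features, depth=4)
print(output.shape)
```

The output keeps the input resolution because each of the 4 pooling steps is undone by one upsampling step on the way back.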
In the last layer, we can observe the prediction of the Hourglass module. This is known as intermediate supervision: you calculate the loss at the end of each Hourglass stage instead of only calculating the loss of the entire model.
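As a rough illustration of calculating the loss at every stage, here is a minimal sketch that sums a mean-squared-error loss over the heat-maps predicted by each Hourglass. The function name and the toy arrays are assumptions of mine; a real training loop would use a deep learning framework rather than plain NumPy.

```python
import numpy as np


def intermediate_supervision_loss(stage_heatmaps, target_heatmap):
    """Sum of per-stage MSE losses (hypothetical helper).

    stage_heatmaps: list with one predicted heat-map per Hourglass stage.
    """
    return sum(
        float(np.mean((predicted - target_heatmap) ** 2))
        for predicted in stage_heatmaps
    )


target = np.zeros((64, 64))
# Pretend stage 2 is further from the target than stage 1.
predictions = [np.ones((64, 64)), 2 * np.ones((64, 64))]
total_loss = intermediate_supervision_loss(predictions, target)
print(total_loss)
```

Because every stage contributes its own term, the early Hourglass modules receive a training signal directly instead of only through the last one.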
The output of an Hourglass module goes through a 1×1 convolution and is then split into 2 parallel branches. One is used for prediction and the other passes its result on as input to the next Hourglass module. Finally, we perform element-wise addition between the input of the Hourglass module and both of these outputs. P.S.: The prediction goes through a 1×1 convolution to get the correct shape before the element-wise addition.
Finally, to build a Stacked Hourglass Network, we repeat these Hourglass modules one after another.
BlazePose
OK, back to Google's BlazePose algorithm: you can learn about it by reading the article “Learn about BlazePose” by Pham Van Toan.
It's not that I'm lazy, guys; someone simply wrote it first.
Pose classification
OK, now that you understand BlazePose: after the Pose Landmark Model (BlazePose GHUM 3D) detects the pose in an image, it returns 33 points on the body, as shown below:
Python Solution API
Thankfully, Google provides a Python Solution API: you only need to import the mediapipe library and write a few lines of code. MediaPipe provides pose detection for both still images and video. Below is the code for still images:
```python
import cv2
import mediapipe as mp

mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose

# Paths of the images to process (placeholder, fill in your own files).
file_list = []

# For static images:
with mp_pose.Pose(
        static_image_mode=True, min_detection_confidence=0.5) as pose:
    for idx, file in enumerate(file_list):
        image = cv2.imread(file)
        image_height, image_width, _ = image.shape
        # Convert the BGR image to RGB before processing.
        results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        if not results.pose_landmarks:
            continue
        print(
            f'Nose coordinates: ('
            f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, '
            f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})'
        )
        # Draw pose landmarks on the image.
        annotated_image = image.copy()
        # Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when
        # upper_body_only is set to True.
        mp_drawing.draw_landmarks(
            annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
        cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image)
```
First import OpenCV (cv2), which is used for image processing, and the mediapipe library. Then set up 2 variables to access the MediaPipe functions:
```python
mp_drawing = mp.solutions.drawing_utils  # draws lines and points on the image
mp_pose = mp.solutions.pose  # detects the pose
```
For still images, set the parameters of the Pose class as follows:
```python
mp_pose.Pose(
    static_image_mode=True, min_detection_confidence=0.5)
```
MediaPipe ships with a pre-trained model, so you don't need to train one yourself.
```python
results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
```
If you want to print the detection output, remember that MediaPipe returns the coordinates as a ratio of the image size, so you need to multiply them by the image width and height to get the correct pixel coordinates.
```python
print(
    f'Nose coordinates: ('
    f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, '
    f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})'
)
```
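The multiplication by width and height can be wrapped in a small helper. This is only a sketch: `to_pixel_coordinates` is a hypothetical function of mine, not part of MediaPipe, and clamping to the image border is my own design choice for landmarks that fall slightly outside the frame.

```python
def to_pixel_coordinates(x_norm, y_norm, image_width, image_height):
    """Convert normalized landmark coordinates to integer pixel coordinates."""
    # Clamp so that values such as 1.0 (or slightly outside [0, 1]) still
    # map to a valid pixel inside the image.
    x_px = max(0, min(int(x_norm * image_width), image_width - 1))
    y_px = max(0, min(int(y_norm * image_height), image_height - 1))
    return x_px, y_px


# A landmark in the horizontal middle, a quarter of the way down a 640x480 image.
print(to_pixel_coordinates(0.5, 0.25, image_width=640, image_height=480))
```

Integer pixel coordinates are what the OpenCV drawing functions expect, which is why the helper truncates instead of returning floats.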
If you want to draw points or lines on an image, use the draw_landmarks() function.
```python
mp_drawing.draw_landmarks(
    annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
```
OK, for video, use cv2 (OpenCV) to capture the video and create a while loop that processes the frames one by one; anyone who works a lot with OpenCV will be familiar with this.
```python
# For webcam input:
cap = cv2.VideoCapture(0)
with mp_pose.Pose(
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5) as pose:
    while cap.isOpened():
        success, image = cap.read()
        if not success:
            print("Ignoring empty camera frame.")
            # If loading a video, use 'break' instead of 'continue'.
            continue

        # Flip the image horizontally for a later selfie-view display, and convert
        # the BGR image to RGB.
        image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB)
        # To improve performance, optionally mark the image as not writeable to
        # pass by reference.
        image.flags.writeable = False
        results = pose.process(image)

        # Draw the pose annotation on the image.
        image.flags.writeable = True
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        mp_drawing.draw_landmarks(
            image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
        cv2.imshow('MediaPipe Pose', image)
        if cv2.waitKey(5) & 0xFF == 27:
            break
cap.release()
```
Basically, this is not much different from static image processing, except that static_image_mode is not needed when creating the MediaPipe Pose instance. Instead, there is an additional min_tracking_confidence parameter, which is used when tracking the person from frame to frame. Its value is between 0 and 1; the higher it is, the more robust the result, but at the cost of more delay (more processing time).
Based on BlazePose, MediaPipe returns each point with x, y, z coordinates. As I said above, x and y are returned as a ratio of the image width and height. z is the depth of the landmark, with the depth at the midpoint of the hips as the origin: the smaller the value, the closer the landmark is to the camera. z uses roughly the same scale as x. z is only predicted in full-body mode (33 points); it is not supported in upper-body mode (25 points).
Colab link for Pose Classification: https://colab.research.google.com/drive/19txHpN8exWhstO6WVkfmYYVC6uug_oVR
Prepare dataset
OK, let's return to pose classification. They simply use KNN (k-nearest neighbors) to group similar poses. For an explanation of KNN, readers can see ( https://machinelearningcoban.com/2017/01/08/knn/ ).
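For readers who want the gist of KNN without leaving the page, here is a minimal pure-Python sketch. It is a generic textbook version with Euclidean distance and majority voting, not the classifier used in the Colab; the function `knn_classify` and the 2D toy features are assumptions for illustration.

```python
from collections import Counter
import math


def knn_classify(query, samples, k=3):
    """Classify `query` by majority vote among its k nearest samples.

    samples: list of (feature_vector, label) pairs.
    """
    # Sort the samples by Euclidean distance to the query.
    by_distance = sorted(samples, key=lambda s: math.dist(query, s[0]))
    # Majority vote among the k nearest samples.
    votes = Counter(label for _, label in by_distance[:k])
    return votes.most_common(1)[0][0]


# Toy 2D features standing in for pose embeddings.
samples = [
    ((0.9, 0.1), 'pushups_up'),
    ((1.0, 0.2), 'pushups_up'),
    ((0.8, 0.0), 'pushups_up'),
    ((0.1, 0.9), 'pushups_down'),
    ((0.2, 1.0), 'pushups_down'),
    ((0.0, 0.8), 'pushups_down'),
]
print(knn_classify((0.85, 0.15), samples))
```

KNN needs no training step: the labelled samples themselves are the model, which is why preparing a clean dataset matters so much in the next steps.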
First you need to prepare the training data. For example, a push-up has two states, up and down, as shown below, so I need to prepare images of both states.
Google provides the code to generate the data; all you have to do is understand it.
You need to create a folder with the following structure:
```
fitness_poses_images_in/
  pushups_up/
    image_001.jpg
    image_002.jpg
    ...
  pushups_down/
    image_001.jpg
    image_002.jpg
    ...
fitness_poses_images_out/
fitness_poses_csvs_out/
```
The photos here are images of the up and down states of the push-up.
Create the bootstrap_helper instance from the BootstrapHelper class. This class is responsible for detecting the 33 points on the body, drawing the points and connecting lines on the image, and writing the output to a CSV file.
```python
bootstrap_images_in_folder = 'fitness_poses_images_in'

# Output folders for bootstrapped images and CSVs.
bootstrap_images_out_folder = 'fitness_poses_images_out'
bootstrap_csvs_out_folder = 'fitness_poses_csvs_out'

bootstrap_helper = BootstrapHelper(
    images_in_folder=bootstrap_images_in_folder,
    images_out_folder=bootstrap_images_out_folder,
    csvs_out_folder=bootstrap_csvs_out_folder,
)
```
Check the status of these folders:
```python
# Use the bootstrap_helper instance created above.
bootstrap_helper.print_images_in_statistics()
bootstrap_helper.print_images_out_statistics()
```
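If you are curious what such a folder check could look like, here is a rough sketch that counts the images in each class folder. It is my own approximation, not the code of BootstrapHelper; the function name `count_images_per_class` and the list of accepted file extensions are assumptions.

```python
from pathlib import Path


def count_images_per_class(images_folder):
    """Return {class_name: number_of_images} for each sub-folder."""
    image_suffixes = {'.jpg', '.jpeg', '.png'}  # assumed image extensions
    counts = {}
    for class_dir in sorted(Path(images_folder).iterdir()):
        if not class_dir.is_dir():
            continue
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.suffix.lower() in image_suffixes
        )
    return counts


# Only runs if the input folder exists in the working directory.
if Path('fitness_poses_images_in').is_dir():
    print(count_images_per_class('fitness_poses_images_in'))
```

Comparing these counts before and after bootstrapping is a quick way to see how many images the pose detector failed on.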
Remove the samples where detection failed by aligning the images with the CSV files:
```python
# Use the bootstrap_helper instance created above.
bootstrap_helper.align_images_and_csvs(print_removed_items=False)
bootstrap_helper.print_images_out_statistics()
```
Pose Embedding
Embed the 33 points on the body by calculating the distances between pairs of points, such as left shoulder to left elbow, right shoulder to right elbow, left knee to left ankle, right knee to right ankle, …
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | <span class="token keyword">def</span> <span class="token function">_get_pose_distance_embedding</span> <span class="token punctuation">(</span> self <span class="token punctuation">,</span> landmarks <span class="token punctuation">)</span> <span class="token punctuation">:</span> <span class="token triple-quoted-string string">"""Converts pose landmarks into 3D embedding. We use several pairwise 3D distances to form pose embedding. All distances include X and Y components with sign. We differnt types of pairs to cover different pose classes. Feel free to remove some or add new. Args: landmarks - NumPy array with 3D landmarks of shape (N, 3). Result: Numpy array with pose embedding of shape (M, 3) where `M` is the number of pairwise distances. """</span> embedding <span class="token operator">=</span> np <span class="token punctuation">.</span> array <span class="token punctuation">(</span> <span class="token punctuation">[</span> <span class="token comment"># One joint.</span> self <span class="token punctuation">.</span> _get_distance <span class="token punctuation">(</span> self <span class="token punctuation">.</span> _get_average_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_hip'</span> <span class="token punctuation">,</span> <span class="token string">'right_hip'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_average_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span> <span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span> <span 
class="token punctuation">)</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span> <span class="token punctuation">,</span> <span class="token string">'left_elbow'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span> <span class="token punctuation">,</span> <span class="token string">'right_elbow'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_elbow'</span> <span class="token punctuation">,</span> <span class="token string">'left_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_elbow'</span> <span class="token punctuation">,</span> <span class="token string">'right_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_hip'</span> <span class="token punctuation">,</span> <span class="token string">'left_knee'</span> <span class="token punctuation">)</span> <span class="token 
punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_hip'</span> <span class="token punctuation">,</span> <span class="token string">'right_knee'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_knee'</span> <span class="token punctuation">,</span> <span class="token string">'left_ankle'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_knee'</span> <span class="token punctuation">,</span> <span class="token string">'right_ankle'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> <span class="token comment"># Two joints.</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span> <span class="token punctuation">,</span> <span class="token string">'left_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span> <span class="token punctuation">,</span> <span class="token string">'right_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token 
punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_hip'</span> <span class="token punctuation">,</span> <span class="token string">'left_ankle'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_hip'</span> <span class="token punctuation">,</span> <span class="token string">'right_ankle'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> <span class="token comment"># Four joints.</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_hip'</span> <span class="token punctuation">,</span> <span class="token string">'left_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_hip'</span> <span class="token punctuation">,</span> <span class="token string">'right_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> <span class="token comment"># Five joints.</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span> <span class="token punctuation">,</span> <span class="token string">'left_ankle'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token 
punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span> <span class="token punctuation">,</span> <span class="token string">'right_ankle'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_hip'</span> <span class="token punctuation">,</span> <span class="token string">'left_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'right_hip'</span> <span class="token punctuation">,</span> <span class="token string">'right_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> <span class="token comment"># Cross body.</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_elbow'</span> <span class="token punctuation">,</span> <span class="token string">'right_elbow'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_knee'</span> <span class="token punctuation">,</span> <span class="token string">'right_knee'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span 
class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_wrist'</span> <span class="token punctuation">,</span> <span class="token string">'right_wrist'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> self <span class="token punctuation">.</span> _get_distance_by_names <span class="token punctuation">(</span> landmarks <span class="token punctuation">,</span> <span class="token string">'left_ankle'</span> <span class="token punctuation">,</span> <span class="token string">'right_ankle'</span> <span class="token punctuation">)</span> <span class="token punctuation">,</span> <span class="token keyword">return</span> embedding |
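To make the embedding more concrete, here is a minimal sketch of how a "distance" between two named landmarks can be computed. The shortened landmark list, the helper name `distance_by_names` and the toy coordinates are assumptions for illustration only; the sketch takes the distance to be the 3D difference vector between the two points, which may differ from the helper used in the class above.

```python
import numpy as np

# Hypothetical, shortened landmark list; the real model outputs 33 landmarks.
LANDMARK_NAMES = ['left_hip', 'right_hip', 'left_ankle', 'right_ankle']


def distance_by_names(landmarks, name_from, name_to, names=LANDMARK_NAMES):
    """Returns the 3D difference vector between two named landmarks."""
    lmk_from = landmarks[names.index(name_from)]
    lmk_to = landmarks[names.index(name_to)]
    return lmk_to - lmk_from


# Toy (x, y, z) coordinates, one row per landmark name.
landmarks = np.array([
    [0.0, 0.0, 0.0],  # left_hip
    [2.0, 0.0, 0.0],  # right_hip
    [0.0, 3.0, 4.0],  # left_ankle
    [2.0, 3.0, 4.0],  # right_ankle
], dtype=np.float32)

# Each row of the embedding is one pairwise difference vector.
embedding = np.array([
    distance_by_names(landmarks, 'left_hip', 'left_ankle'),
    distance_by_names(landmarks, 'right_hip', 'right_ankle'),
    distance_by_names(landmarks, 'left_ankle', 'right_ankle'),
])
```

Stacking these vectors gives a fixed-size array per frame, which is what makes two poses directly comparable.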
Pose Classifier
Create an instance of the PoseClassifier class, which classifies a pose by comparing the distances between points on the body in the image with those of the samples in the database.
pose_classifier_instance = PoseClassifier(
    pose_samples_folder=bootstrap_csvs_out_folder,
    pose_embedder=pose_embed_instance,
    top_n_by_max_distance=30,
    top_n_by_mean_distance=10)
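As far as I understand, the two `top_n_...` parameters describe a two-stage nearest-neighbour search: first keep the samples whose largest difference to the query is smallest, then keep the ones with the smallest mean difference and count their classes. The sketch below is a simplified illustration of that idea, not the real PoseClassifier; the function `classify_pose` and the toy data are assumptions.

```python
from collections import Counter

import numpy as np


def classify_pose(query, samples, top_n_by_max_distance=30, top_n_by_mean_distance=10):
    """Two-stage nearest-neighbour vote (simplified, hypothetical helper).

    query:   embedding array of shape (n_distances, 3).
    samples: list of (class_name, embedding) pairs from the database.
    """
    # Stage 1: keep samples whose worst (maximum) coordinate difference is smallest.
    by_max = sorted(samples, key=lambda s: np.max(np.abs(s[1] - query)))
    candidates = by_max[:top_n_by_max_distance]

    # Stage 2: among those, keep samples with the smallest mean difference.
    by_mean = sorted(candidates, key=lambda s: np.mean(np.abs(s[1] - query)))
    nearest = by_mean[:top_n_by_mean_distance]

    # Count how many of the nearest samples belong to each class.
    return dict(Counter(name for name, _ in nearest))


# Toy database: five "up" samples near 0 and five "down" samples near 1.
samples = [('pushups_up', np.zeros((4, 3)) + 0.01 * i) for i in range(5)]
samples += [('pushups_down', np.ones((4, 3)) + 0.01 * i) for i in range(5)]

query = np.full((4, 3), 0.9)  # Close to the "down" samples.
result = classify_pose(query, samples, top_n_by_max_distance=6, top_n_by_mean_distance=3)
```

The first stage removes samples that are very different in at least one joint, and the second stage picks the closest ones on average.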
Pose Smoothing
When detecting the 33 points on the body, the points jitter from frame to frame even though the naked eye sees no difference, so I need to smooth the predictions across frames using the EMA (exponential moving average) algorithm. The theory can be found here: ( https://www.tohaitrieu.net/exponential-moving-average-ema/ ).
class EMADictSmoothing(object):
    """Smoothes pose classification."""

    def __init__(self, window_size=10, alpha=0.2):
        self._window_size = window_size
        self._alpha = alpha
        self._data_in_window = []

    def __call__(self, data):
        """Smoothes given pose classification.

        Smoothing is done by computing Exponential Moving Average for every pose
        class observed in the given time window. Missed pose classes are replaced
        with 0.

        Args:
          data: Dictionary with pose classification. Sample:
              {
                'pushups_down': 8,
                'pushups_up': 2,
              }

        Result:
          Dictionary in the same format but with smoothed and float instead of
          integer values. Sample:
              {
                'pushups_down': 8.3,
                'pushups_up': 1.7,
              }
        """
        # Add new data to the beginning of the window for simpler code.
        self._data_in_window.insert(0, data)
        self._data_in_window = self._data_in_window[:self._window_size]

        # Get all keys.
        keys = set([key for data in self._data_in_window for key, _ in data.items()])

        # Get smoothed values.
        smoothed_data = dict()
        for key in keys:
            factor = 1.0
            top_sum = 0.0
            bottom_sum = 0.0
            for data in self._data_in_window:
                value = data[key] if key in data else 0.0

                top_sum += factor * value
                bottom_sum += factor

                # Update factor.
                factor *= (1.0 - self._alpha)

            smoothed_data[key] = top_sum / bottom_sum

        return smoothed_data
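To see what the weights look like, this small sketch computes the same weighted average by hand for a single class. The function name `ema_newest_first` and the sample values are made up for illustration.

```python
def ema_newest_first(values, alpha=0.2):
    """Weighted average of `values` (newest first) with weights (1 - alpha) ** i."""
    factor = 1.0
    top_sum = 0.0
    bottom_sum = 0.0
    for value in values:
        top_sum += factor * value
        bottom_sum += factor
        # Older frames get exponentially smaller weights.
        factor *= 1.0 - alpha
    return top_sum / bottom_sum


# Votes for one class over the last three frames, newest first.
# Weights are 1, 0.8 and 0.64, so the result is 17.68 / 2.44.
smoothed = ema_newest_first([10, 8, 2], alpha=0.2)
```

A larger `alpha` makes old frames fade faster, so the output reacts more quickly but is also less smooth.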
Testing
OK, now let's test on a video that is not in the current dataset by putting together all the code above. I removed the repetition counter and the visualizer from the code below to make it easier to read; you can add them back to get a more intuitive result.
# Run classification on a video.
# Assumes video_cap, video_fps, video_width, video_height, video_n_frames,
# out_video_path, pose_tracker, pose_classifier, pose_classification_filter
# and show_image are already defined.
import cv2
import numpy as np
import tqdm

from mediapipe.python.solutions import drawing_utils as mp_drawing
from mediapipe.python.solutions import pose as mp_pose

# Open output video.
out_video = cv2.VideoWriter(out_video_path, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (video_width, video_height))

frame_idx = 0
output_frame = None
with tqdm.tqdm(total=video_n_frames, position=0, leave=True) as pbar:
    while True:
        # Get next frame of the video.
        success, input_frame = video_cap.read()
        if not success:
            break

        # Run pose tracker.
        input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
        result = pose_tracker.process(image=input_frame)
        pose_landmarks = result.pose_landmarks

        # Draw pose prediction.
        output_frame = input_frame.copy()
        if pose_landmarks is not None:
            mp_drawing.draw_landmarks(
                image=output_frame,
                landmark_list=pose_landmarks,
                connections=mp_pose.POSE_CONNECTIONS)

        if pose_landmarks is not None:
            # Get landmarks in pixel coordinates.
            frame_height, frame_width = output_frame.shape[0], output_frame.shape[1]
            pose_landmarks = np.array(
                [[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                 for lmk in pose_landmarks.landmark],
                dtype=np.float32)
            assert pose_landmarks.shape == (33, 3), 'Unexpected landmarks shape: {}'.format(pose_landmarks.shape)

            # Classify the pose on the current frame.
            pose_classification = pose_classifier(pose_landmarks)

            # Smooth classification using EMA.
            pose_classification_filtered = pose_classification_filter(pose_classification)
        else:
            # No pose => no classification on current frame.
            pose_classification = None

            # Still add empty classification to the filter to maintain correct
            # smoothing for future frames.
            pose_classification_filter(dict())
            pose_classification_filtered = None

        # The classification plot and repetition counter (removed here) would be
        # drawn on output_frame at this point.

        # Save the output frame.
        out_video.write(cv2.cvtColor(np.array(output_frame), cv2.COLOR_RGB2BGR))

        # Show intermediate frames of the video to track progress.
        if frame_idx % 50 == 0:
            show_image(output_frame)

        frame_idx += 1
        pbar.update()

# Close output video.
out_video.release()

# Release MediaPipe resources.
pose_tracker.close()

# Show the last frame of the video.
if output_frame is not None:
    show_image(output_frame)
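If you want to add a repetition counter back, one simple way is to use two thresholds on the smoothed confidence of a single class: the pose is "entered" when the confidence rises above the first threshold and a repetition is counted when it drops below the second. The class below is a hypothetical sketch of that idea, not the code I removed; the class name, the threshold values 6 and 4, and the sample confidences are assumptions.

```python
class SimpleRepetitionCounter:
    """Counts repetitions of one pose class using two thresholds (hypothetical helper)."""

    def __init__(self, class_name, enter_threshold=6, exit_threshold=4):
        self._class_name = class_name
        self._enter_threshold = enter_threshold
        self._exit_threshold = exit_threshold
        self._pose_entered = False
        self.n_repeats = 0

    def __call__(self, pose_classification):
        """Updates the counter with one smoothed classification dict."""
        confidence = pose_classification.get(self._class_name, 0.0)
        if not self._pose_entered:
            # Enter the pose once its confidence is high enough.
            self._pose_entered = confidence > self._enter_threshold
        elif confidence < self._exit_threshold:
            # Leaving the pose completes one repetition.
            self.n_repeats += 1
            self._pose_entered = False
        return self.n_repeats


# Smoothed confidences of 'pushups_down' over eight frames (made-up values).
counter = SimpleRepetitionCounter('pushups_down')
for confidence in [1.0, 7.5, 8.0, 3.0, 2.0, 9.0, 5.0, 1.5]:
    counter({'pushups_down': confidence})
```

Using two different thresholds avoids counting extra repetitions when the confidence oscillates around a single value.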
Epilogue
I'll stop here, I'm too lazy to write any more. Thank you for reading this far.
Reference links
https://export.arxiv.org/pdf/2006.10204
https://export.arxiv.org/pdf/1603.06937
https://towardsdatascience.com/using-hourglass-networks-to-understand-human-poses-1e40e349fa15
https://colab.research.google.com/drive/19txHpN8exWhstO6WVkfmYYVC6uug_oVR#scrollTo=4lXymkneOjgZ
https://google.github.io/mediapipe/solutions/pose_classification.html
https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html
https://www.tohaitrieu.net/exponential-moving-average-ema/
…