Lời mở đầu
Chào các bạn, hôm nay tôi sẽ lảm nhảm một chút về Pose Classification. Như các bạn đã biết hiện giờ bài toán về chuyển động của cơ thể hay phát hiện các điểm trên cơ thể người là một bài toán quan trọng trong ngành ML này, do tính ứng dụng của bài toán này khá đa dạng: như phát hiện động tác trong siêu thị, mô phỏng các bài tập vật lý trị liệu trong y tế, hỗ trợ PT trong các bài tập gym, bla…bla…
Tôi viết bài này cũng để tổng kết lại những gì mình đã tìm hiểu, nếu có thiếu sót mong m.n bỏ qua cho
Bài viết này xoay quanh thư viện Mediapipe của Google, nếu còn thời gian tôi sẽ đề cập tới các thư viện khác sau.
Pose Detection and Pose Tracking
Trước khi đi vào phần Pose Classification, chúng ta nhìn qua một chút bên google làm thế nào detect các điểm trên cơ thể người.
Giải pháp của họ dựa trên một paper mà họ đề xuất, keyword: BlazePose (https://export.arxiv.org/pdf/2006.10204). Dựa trên giải pháp này, họ sẽ bóc tách ra 33 điểm tương ứng 33 bộ phận của cơ thể người hoặc 25 điểm tương ứng phần trên của cơ thể trong không gian 3 chiều (x, y, z) từ một RGB video.
Giải thích sơ qua về BlazePose, về cơ bản BlazePose là bản cải tiến của mạng Stacked Hourglass (https://export.arxiv.org/pdf/1603.06937)
Stacked Hourglass Network
Cấu trúc mạng Stacked Hourglass:
Ý tưởng của mạng này thay vì ta có một bộ encoder-decoder siêu to thì mỗi Hourglass (đồng hồ cát) có nhiệm vụ trả về một heat-map dự đoán các phần trên cơ thể. Do đây là mô hình chồng lên nhau nên thằng Hourglass sau có thể học hỏi kết quả của thằng trước.
Làm thế nào để phát hiện chuyển động người qua heat-map (bản đồ nhiệt) ? Khác với dữ liệu về khuôn mặt con người (72 landmark keypoint, …) thì dữ liệu về chuyển động của người đa dạng hơn nên rất khó tìm các điểm trên cơ thể người dựa vào tọa độ. Các nhà khoa học đã nghĩ ra phương thức sử dụng heat-map để đại diện cho một vùng trên ảnh. Heat-map giúp giữ lại thông tin vùng đó và việc của chúng ta là tìm ra cao điểm (điểm sáng nhất) trong vùng. Lấy ví dụ với một ảnh 256×256 thì heat-map có thể có kích thước 64×64. Bạn xem hình cho dễ hiểu:
Theo như tác giả paper nói qua thì họ sẽ tính loss ở mỗi lần predict, điều này giúp họ giám sát không chỉ kết quả trả về cuối cùng mà còn giám sát đầu ra của mỗi Hourglass. Lấy ví dụ, trong khi chuyển động cơ thể sẽ có các bộ phận bị che đi trước camera, rất khó để phân biệt cánh tay hướng bên trái hay bên phải. Với việc sử dụng kết quả dự đoán vị trí của thằng trước làm input, mô hình không chỉ chú ý các vị trí này mà trong lúc đó còn dự đoán thêm vị trí mới.
Hourglass Module
Giờ chúng ta nhìn qua cấu trúc của một Hourglass
Thông qua hình vẽ chắc các bạn cũng thấy đây là kiến trúc encode-decoder có nhiệm vụ downsample các features trước rồi upsample để hồi phục thông tin và chuyển thành heat-map. Mỗi một layer của encoder kết nối với một layer của decoder tương ứng. Còn mỗi layer xây dựng dựa trên kiến trúc residual block và bottleneck của resnet, bạn nào chưa biết về residual block thì vào link này đọc nhé (https://towardsdatascience.com/residual-blocks-building-blocks-of-resnet-fd90ca15d6ec).
Bên trái: Residual Layer. Bên phải: Bottleneck Block
Bottleneck giúp cho việc tính toán dễ dàng hơn, tương ứng với việc tiết kiệm bộ nhớ.
Giờ chúng ta thử phóng to một hộp trong hình trên
Mỗi hộp trong hình là lớp bottleneck mà tôi đã nhắc đến. Sau mỗi bottleneck sẽ có một lớp pooling để loại bỏ các feature không cần thiết.
Tuy nhiên, lớp đầu tiên có một chút khác biệt, lớp này dùng Convolution 7×7 chứ không phải 3×3.
Như hình trên, ở lớp đầu tiên, input đầu tiên đi qua tổ hợp giữa Convolution 7×7, BatchNorm, Activation Relu, tiếp tục đi qua lớp bottleneck. Ở đây output từ lớp bottleneck sẽ di qua 2 nhánh song song. Một cái đi lớp MaxPooling và thực hiện việc trích xuất đặc trưng, cái còn lại sẽ kết nối với lớp tương ứng ở phần decoder.
Các hộp 2, 3, 4 khác có cấu tạo giống nhau và khác hộp đầu tiên.
Mục đích cuối cùng của việc trích xuất đặc trưng là tạo ra feature maps, chứa thông tin đặc trưng của ảnh nhiều nhất nhưng thông tin không gian thấp nhất. Phần này chính là 3 cái hộp nhỏ nằm giữa encode và decoder.
Input sau khi đi qua 4 lớp của encoder và một lớp bottom trả về feature maps đã sẵn sàng đi qua decoder.
Như tôi đã nhắc tới trên kia thì nhánh còn lại sẽ đi qua lớp bottleneck và thực hiện việc cộng element-wise với output của lớp upsample của nhánh chính. Việc này lặp đi lặp lại khoảng 4 lần cho tới khi kết thúc.
Ở lớp cuối cùng, chúng ta có thể quan sát độ chính xác của mỗi dự đoán (prediction of Hourglass module). Cái này còn gọi là immediate supervision, bạn tính toán loss ở cuối mỗi Hourglass stage thay vì tính loss của cả mô hình.
Output của một Hourglass module đi qua Convolution 1×1, sau đó chia làm 2 nhánh song song. Một dùng để dự đoán và một trả về kết quả sẽ làm đầu vào cho Hourglass module tiếp theo. Cuối cùng, chúng ta thực hiện việc cộng từng phần tử (element-wise addition) giữa đầu vào của network (heatmap) và cả 2 đầu ra của Hourglass module. P/S: kết quả dự đoán đi qua Convolution 1×1 để cho đúng shape thì mới cộng từng phần tử được.
Cuối cùng để xây dựng Stacked Hourglass Network, chúng ta cần thực hiện lặp đi lặp lại các Hourglass module này.
BlazePose
Ok quay lại thuật toán BlazePose của Google, các bạn có thể tìm hiểu bằng cách đọc bài “Tìm hiểu về BlazePose” của tác giả Phạm Văn Toàn
Không phải tôi lười đâu các bạn, chỉ là có người viết trước rồi thôi.
Pose classification
Ok, coi như các bạn đã hiểu BlazePose rồi. Sau khi dùng Pose Landmark Model (BlazePose GHUM 3D) để detect chuyển động trên một ảnh sẽ trả về 33 điểm trên cơ thể như hình dưới đây:
Python Solution API
Rất may Google cung cấp giải pháp Python API mà bạn chỉ cần import thư viện mediapipe, code vài dòng là chạy được. Mediapipe cung cấp giải pháp detect pose trên ảnh tĩnh và trên video. Dưới đây là code dành cho ảnh tĩnh
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | <span class="token keyword">import</span> cv2 <span class="token keyword">import</span> mediapipe <span class="token keyword">as</span> mp mp_drawing <span class="token operator">=</span> mp<span class="token punctuation">.</span>solutions<span class="token punctuation">.</span>drawing_utils mp_pose <span class="token operator">=</span> mp<span class="token punctuation">.</span>solutions<span class="token punctuation">.</span>pose <span class="token comment"># For static images:</span> <span class="token keyword">with</span> mp_pose<span class="token punctuation">.</span>Pose<span class="token punctuation">(</span> static_image_mode<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> min_detection_confidence<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">)</span> <span class="token keyword">as</span> pose<span class="token punctuation">:</span> <span class="token keyword">for</span> idx<span class="token punctuation">,</span> <span class="token builtin">file</span> <span class="token keyword">in</span> <span class="token builtin">enumerate</span><span class="token punctuation">(</span>file_list<span class="token punctuation">)</span><span class="token punctuation">:</span> image <span class="token operator">=</span> cv2<span class="token punctuation">.</span>imread<span class="token punctuation">(</span><span class="token builtin">file</span><span class="token punctuation">)</span> image_height<span class="token punctuation">,</span> image_width<span class="token punctuation">,</span> _ <span class="token operator">=</span> image<span class="token punctuation">.</span>shape <span class="token comment"># Convert the BGR image to RGB before processing.</span> results <span class="token operator">=</span> pose<span class="token punctuation">.</span>process<span class="token punctuation">(</span>cv2<span class="token punctuation">.</span>cvtColor<span class="token punctuation">(</span>image<span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>COLOR_BGR2RGB<span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">if</span> <span class="token keyword">not</span> results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">:</span> <span class="token keyword">continue</span> <span class="token keyword">print</span><span class="token punctuation">(</span> <span class="token string-interpolation"><span class="token string">f'Nose coordinates: ('</span></span> <span class="token string-interpolation"><span class="token string">f'</span><span class="token interpolation"><span class="token punctuation">{</span>results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">.</span>landmark<span class="token punctuation">[</span>mp_pose<span class="token punctuation">.</span>PoseLandmark<span class="token punctuation">.</span>NOSE<span class="token punctuation">]</span><span class="token punctuation">.</span>x <span class="token operator">*</span> image_width<span class="token punctuation">}</span></span><span class="token string">, '</span></span> <span class="token string-interpolation"><span class="token string">f'</span><span class="token interpolation"><span class="token punctuation">{</span>results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">.</span>landmark<span class="token punctuation">[</span>mp_pose<span class="token punctuation">.</span>PoseLandmark<span class="token punctuation">.</span>NOSE<span class="token punctuation">]</span><span class="token punctuation">.</span>y <span class="token operator">*</span> image_height<span class="token punctuation">}</span></span><span class="token string">)'</span></span> <span class="token punctuation">)</span> <span class="token comment"># Draw pose landmarks on the image.</span> annotated_image <span class="token operator">=</span> image<span class="token punctuation">.</span>copy<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when</span> <span class="token comment"># upper_body_only is set to True.</span> mp_drawing<span class="token punctuation">.</span>draw_landmarks<span class="token punctuation">(</span> annotated_image<span class="token punctuation">,</span> results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">,</span> mp_pose<span class="token punctuation">.</span>POSE_CONNECTIONS<span class="token punctuation">)</span> cv2<span class="token punctuation">.</span>imwrite<span class="token punctuation">(</span><span class="token string">'/tmp/annotated_image'</span> <span class="token operator">+</span> <span class="token builtin">str</span><span class="token punctuation">(</span>idx<span class="token punctuation">)</span> <span class="token operator">+</span> <span class="token string">'.png'</span><span class="token punctuation">,</span> annotated_image<span class="token punctuation">)</span> |
Import thư viện dùng để xử lý ảnh opencv và thư viện mediapipe. Sau đó đặt 2 biến để dùng các hàm của mediapipe
1 2 3 | mp_drawing <span class="token operator">=</span> mp<span class="token punctuation">.</span>solutions<span class="token punctuation">.</span>drawing_utils <span class="token comment"># instance vẽ line và point lên ảnh</span> mp_pose <span class="token operator">=</span> mp<span class="token punctuation">.</span>solutions<span class="token punctuation">.</span>pose <span class="token comment"># instance detect pose</span> |
Với ảnh tĩnh, các bạn phải chỉnh lại tham số cho class Pose như sau
1 2 | mp_pose<span class="token punctuation">.</span>Pose<span class="token punctuation">(</span>static_image_mode<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> min_detection_confidence<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">)</span> |
Trong thư viện của mediapipe có sẵn model nên bạn không cần train lại.
1 2 | results <span class="token operator">=</span> pose<span class="token punctuation">.</span>process<span class="token punctuation">(</span>cv2<span class="token punctuation">.</span>cvtColor<span class="token punctuation">(</span>image<span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>COLOR_BGR2RGB<span class="token punctuation">)</span><span class="token punctuation">)</span> |
Nếu bạn muốn in giá trị detect ra thì hãy nhớ mediapose trả về tọa độ dưới dạng tỉ lệ so với ảnh nên bạn cần phải nhân với chiều cao và chiều rộng để vẽ đúng tọa độ trên ảnh.
1 2 3 4 5 6 7 | <span class="token keyword">print</span><span class="token punctuation">(</span> <span class="token string-interpolation"><span class="token string">f'Nose coordinates: ('</span></span> <span class="token string-interpolation"><span class="token string">f'</span><span class="token interpolation"><span class="token punctuation">{</span>results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">.</span>landmark<span class="token punctuation">[</span>mp_pose<span class="token punctuation">.</span>PoseLandmark<span class="token punctuation">.</span>NOSE<span class="token punctuation">]</span><span class="token punctuation">.</span>x <span class="token operator">*</span> image_width<span class="token punctuation">}</span></span><span class="token string">, '</span></span> f'<span class="token punctuation">{</span>results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">.</span>landmark<span class="token punctuation">[</span>mp_pose<span class="token punctuation">.</span>PoseLandmark<span class="token punctuation">.</span>NOSE<span class="token punctuation">]</span><span class="token punctuation">.</span>y <span class="token operator">*</span> image_height<span class="token punctuation">}</span> <span class="token punctuation">)</span>' <span class="token punctuation">)</span> |
Nếu bạn muốn vẽ point hoặc line lên ảnh thì dùng hàm draw_landmarks()
1 2 | mp_drawing<span class="token punctuation">.</span>draw_landmarks<span class="token punctuation">(</span>annotated_image<span class="token punctuation">,</span> results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">,</span> mp_pose<span class="token punctuation">.</span>POSE_CONNECTIONS<span class="token punctuation">)</span> |
Ok đối với video thì dùng cv2 (opencv) để capture video, tạo vòng lặp while True để xử lý từng frame một, ai làm việc nhiều với opencv chắc là quen rồi.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | <span class="token comment"># For webcam input:</span> cap <span class="token operator">=</span> cv2<span class="token punctuation">.</span>VideoCapture<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span> <span class="token keyword">with</span> mp_pose<span class="token punctuation">.</span>Pose<span class="token punctuation">(</span> min_detection_confidence<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">,</span> min_tracking_confidence<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">)</span> <span class="token keyword">as</span> pose<span class="token punctuation">:</span> <span class="token keyword">while</span> cap<span class="token punctuation">.</span>isOpened<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span> success<span class="token punctuation">,</span> image <span class="token operator">=</span> cap<span class="token punctuation">.</span>read<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">if</span> <span class="token keyword">not</span> success<span class="token punctuation">:</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Ignoring empty camera frame."</span><span class="token punctuation">)</span> <span class="token comment"># If loading a video, use 'break' instead of 'continue'.</span> <span class="token keyword">continue</span> <span class="token comment"># Flip the image horizontally for a later selfie-view display, and convert</span> <span class="token comment"># the BGR image to RGB.</span> image <span class="token operator">=</span> cv2<span class="token punctuation">.</span>cvtColor<span class="token punctuation">(</span>cv2<span class="token punctuation">.</span>flip<span class="token punctuation">(</span>image<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>COLOR_BGR2RGB<span class="token punctuation">)</span> <span class="token comment"># To improve performance, optionally mark the image as not writeable to</span> <span class="token comment"># pass by reference.</span> image<span class="token punctuation">.</span>flags<span class="token punctuation">.</span>writeable <span class="token operator">=</span> <span class="token boolean">False</span> results <span class="token operator">=</span> pose<span class="token punctuation">.</span>process<span class="token punctuation">(</span>image<span class="token punctuation">)</span> <span class="token comment"># Draw the pose annotation on the image.</span> image<span class="token punctuation">.</span>flags<span class="token punctuation">.</span>writeable <span class="token operator">=</span> <span class="token boolean">True</span> image <span class="token operator">=</span> cv2<span class="token punctuation">.</span>cvtColor<span class="token punctuation">(</span>image<span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>COLOR_RGB2BGR<span class="token punctuation">)</span> mp_drawing<span class="token punctuation">.</span>draw_landmarks<span class="token punctuation">(</span> image<span class="token punctuation">,</span> results<span class="token punctuation">.</span>pose_landmarks<span class="token punctuation">,</span> mp_pose<span class="token punctuation">.</span>POSE_CONNECTIONS<span class="token punctuation">)</span> cv2<span class="token punctuation">.</span>imshow<span class="token punctuation">(</span><span class="token string">'MediaPipe Pose'</span><span class="token punctuation">,</span> image<span class="token punctuation">)</span> <span class="token keyword">if</span> cv2<span class="token punctuation">.</span>waitKey<span class="token punctuation">(</span><span class="token number">5</span><span class="token punctuation">)</span> <span class="token operator">&</span> <span class="token number">0xFF</span> <span class="token operator">==</span> <span class="token number">27</span><span class="token punctuation">:</span> <span class="token keyword">break</span> cap<span class="token punctuation">.</span>release<span class="token punctuation">(</span><span class="token punctuation">)</span> |
Về cơ bản cũng không khác gì nhiều với việc xử lý ảnh tĩnh, chỉ có cái khi gọi instance của mediapipe pose thì không cần param static_image_mode
. Ngược lại có thêm param min_tracking_confidence
dùng để tracking người từng frame, giá trị trong khoảng từ 0 đến 1, càng cao thì càng bắt đúng nhưng tăng độ trễ (tốn thời gian xử lý hơn).
Dựa trên BlazaPose, mediapipe trả về một điểm có tọa độ x, y, z. x, y thì như tôi nói ở trên được trả về dưới dạng tỉ lệ của chiều rộng và chiều cao còn z là độ sâu của điểm trọng tâm cơ thể. Giá trị càng nhỏ càng gần camera, z được trả về dưới dạng tỉ lệ với chiều rộng như x. z chỉ dự đoán được trong cơ chế full-body (33 điểm) còn cơ chế upper-body (25 điểm) thì không hỗ trợ.
Link colab của Pose Classification: https://colab.research.google.com/drive/19txHpN8exWhstO6WVkfmYYVC6uug_oVR
Prepare dataset
Ok quay lại về Pose Classification, đơn giản là họ dùng KNN (K-nearest neightbors) để group lại các động tác giống nhau. Về phần KNN là gì mời bạn đọc (https://machinelearningcoban.com/2017/01/08/knn/).
Đầu tiên cần chuẩn bị dữ liệu huấn luyện đã. Giả dụ trong động tác chống đẩy có 2 trạng thái: lên và xuống như hình dưới. Như vậy tôi cần chuẩn bị ảnh của cả 2 trạng thái này.
Google hỗ trợ code để gen dữ liệu, việc còn lại của các bạn là hiểu nó.
Bạn cần tạo 1 folder có cấu trúc như sau:
1 2 3 4 5 6 7 8 9 10 11 12 | fitness_poses_images_in/ pushups_up/ image_001.jpg image_002.jpg ... pushups_down/ image_001.jpg image_002.jpg ... fitness_poses_images_out/ fitness_poses_csvs_out/ |
Ảnh ở đây là ảnh của trạng thái chống đẩy lên và trạng thái chống đẩy xuống.
Gọi instance bootstrap_helper từ class BootstrapHelper(). Class này có nhiệm vụ detect 33 điểm trên cơ thể, vẽ các điểm và nối thành đường lên ảnh, xuất ra file csv.
1 2 3 4 5 6 7 8 9 10 11 12 | bootstrap_images_in_folder <span class="token operator">=</span> <span class="token string">'fitness_poses_images_in'</span> <span class="token comment"># Output folders for bootstrapped images and CSVs.</span> bootstrap_images_out_folder <span class="token operator">=</span> <span class="token string">'fitness_poses_images_out'</span> bootstrap_csvs_out_folder <span class="token operator">=</span> <span class="token string">'fitness_poses_csvs_out'</span> bootstrap_helper <span class="token operator">=</span> BootstrapHelper<span class="token punctuation">(</span> images_in_folder<span class="token operator">=</span>bootstrap_images_in_folder<span class="token punctuation">,</span> images_out_folder<span class="token operator">=</span>bootstrap_images_out_folder<span class="token punctuation">,</span> csvs_out_folder<span class="token operator">=</span>bootstrap_csvs_out_folder<span class="token punctuation">,</span> <span class="token punctuation">)</span> |
Kiểm tra trạng thái các folder này
1 2 3 | bootstrap_instance<span class="token punctuation">.</span>print_images_in_statistics<span class="token punctuation">(</span><span class="token punctuation">)</span> bootstrap_instance<span class="token punctuation">.</span>print_images_out_statistics<span class="token punctuation">(</span><span class="token punctuation">)</span> |
Xóa đi các động tác bị detect lỗi bằng cách so ảnh với file csv
1 2 3 | bootstrap_instance<span class="token punctuation">.</span>align_images_and_csvs<span class="token punctuation">(</span>print_removed_items<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span> bootstrap_instance<span class="token punctuation">.</span>print_images_out_statistics<span class="token punctuation">(</span><span class="token punctuation">)</span> |
Pose Embedding
Embed 33 điểm trên cơ thể bằng cách tính khoảng cách giữa các điểm như: khoảng cách vai trái – hông trái, vai phải – hông phải, đầu gối trái – gót chân trái, đầu gối phải – gót chân phải, …
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | <span class="token keyword">def</span> <span class="token function">_get_pose_distance_embedding</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> landmarks<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token triple-quoted-string string">"""Converts pose landmarks into 3D embedding. We use several pairwise 3D distances to form pose embedding. All distances include X and Y components with sign. We differnt types of pairs to cover different pose classes. Feel free to remove some or add new. Args: landmarks - NumPy array with 3D landmarks of shape (N, 3). Result: Numpy array with pose embedding of shape (M, 3) where `M` is the number of pairwise distances. """</span> embedding <span class="token operator">=</span> np<span class="token punctuation">.</span>array<span class="token punctuation">(</span><span class="token punctuation">[</span> <span class="token comment"># One joint.</span> self<span class="token punctuation">.</span>_get_distance<span class="token punctuation">(</span> self<span class="token punctuation">.</span>_get_average_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_hip'</span><span class="token punctuation">,</span> <span class="token string">'right_hip'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_average_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span><span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span><span class="token punctuation">,</span> <span class="token string">'left_elbow'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span><span class="token punctuation">,</span> <span class="token string">'right_elbow'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_elbow'</span><span class="token punctuation">,</span> <span class="token string">'left_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'right_elbow'</span><span class="token punctuation">,</span> <span class="token string">'right_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_hip'</span><span class="token punctuation">,</span> <span class="token string">'left_knee'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'right_hip'</span><span class="token punctuation">,</span> <span class="token string">'right_knee'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_knee'</span><span class="token punctuation">,</span> <span class="token string">'left_ankle'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'right_knee'</span><span class="token punctuation">,</span> <span class="token string">'right_ankle'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token comment"># Two joints.</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span><span class="token punctuation">,</span> <span class="token string">'left_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span><span class="token punctuation">,</span> <span class="token string">'right_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_hip'</span><span class="token punctuation">,</span> <span class="token string">'left_ankle'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'right_hip'</span><span class="token punctuation">,</span> <span class="token string">'right_ankle'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token comment"># Four joints.</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_hip'</span><span class="token punctuation">,</span> <span class="token string">'left_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'right_hip'</span><span class="token punctuation">,</span> <span class="token string">'right_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token comment"># Five joints.</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'left_shoulder'</span><span class="token punctuation">,</span> <span class="token string">'left_ankle'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'right_shoulder'</span><span class="token punctuation">,</span> <span class="token string">'right_ankle'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_hip'</span><span class="token punctuation">,</span> <span class="token string">'left_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'right_hip'</span><span class="token punctuation">,</span> <span class="token string">'right_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token comment"># Cross body.</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'left_elbow'</span><span class="token punctuation">,</span> <span class="token string">'right_elbow'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span>landmarks<span class="token punctuation">,</span> <span class="token string">'left_knee'</span><span class="token punctuation">,</span> <span class="token string">'right_knee'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'left_wrist'</span><span class="token punctuation">,</span> <span class="token string">'right_wrist'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>_get_distance_by_names<span class="token punctuation">(</span> landmarks<span class="token punctuation">,</span> <span class="token string">'left_ankle'</span><span class="token punctuation">,</span> <span class="token string">'right_ankle'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token keyword">return</span> embedding |
Pose Classifier
Gọi instance của class PoseClassifier(), có tác dụng phân loại class bằng cách so sánh độ tương đồng giữa chiều dài các điểm trên cơ thể người của ảnh và database.
1 2 3 4 5 6 | pose_classifier_instance <span class="token operator">=</span> PoseClassifier<span class="token punctuation">(</span> pose_samples_folder<span class="token operator">=</span>bootstrap_csvs_out_folder<span class="token punctuation">,</span> pose_embedder<span class="token operator">=</span>pose_embed_instance<span class="token punctuation">,</span> top_n_by_max_distance<span class="token operator">=</span><span class="token number">30</span><span class="token punctuation">,</span> top_n_by_mean_distance<span class="token operator">=</span><span class="token number">10</span><span class="token punctuation">)</span> |
Pose Smoothing
Khi detect 33 điểm trên cơ thể, các điểm sẽ dịch chuyển rất loạn ở từng frame dù cho nhìn bằng mắt thường không thấy có gì khác biệt, vì vậy tôi cần làm mịn dữ liệu dự đoán ở các frame bằng thuật toán EMA (Exponential moving average). Lý thuyết các bạn có thể tham khảo ở đây: (https://www.tohaitrieu.net/exponential-moving-average-ema/).
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | <span class="token keyword">class</span> <span class="token class-name">EMADictSmoothing</span><span class="token punctuation">(</span><span class="token builtin">object</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token triple-quoted-string string">"""Smoothes pose classification."""</span> <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> window_size<span class="token operator">=</span><span class="token number">10</span><span class="token punctuation">,</span> alpha<span class="token operator">=</span><span class="token number">0.2</span><span class="token punctuation">)</span><span class="token punctuation">:</span> self<span class="token punctuation">.</span>_window_size <span class="token operator">=</span> window_size self<span class="token punctuation">.</span>_alpha <span class="token operator">=</span> alpha self<span class="token punctuation">.</span>_data_in_window <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span> <span class="token keyword">def</span> <span class="token function">__call__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> data<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token triple-quoted-string string">"""Smoothes given pose classification. Smoothing is done by computing Exponential Moving Average for every pose class observed in the given time window. Missed pose classes arre replaced with 0. Args: data: Dictionary with pose classification. Sample: { 'pushups_down': 8, 'pushups_up': 2, } Result: Dictionary in the same format but with smoothed and float instead of integer values. Sample: { 'pushups_down': 8.3, 'pushups_up': 1.7, } """</span> <span class="token comment"># Add new data to the beginning of the window for simpler code.</span> self<span class="token punctuation">.</span>_data_in_window<span class="token punctuation">.</span>insert<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> data<span class="token punctuation">)</span> self<span class="token punctuation">.</span>_data_in_window <span class="token operator">=</span> self<span class="token punctuation">.</span>_data_in_window<span class="token punctuation">[</span><span class="token punctuation">:</span>self<span class="token punctuation">.</span>_window_size<span class="token punctuation">]</span> <span class="token comment"># Get all keys.</span> keys <span class="token operator">=</span> <span class="token builtin">set</span><span class="token punctuation">(</span><span class="token punctuation">[</span>key <span class="token keyword">for</span> data <span class="token keyword">in</span> self<span class="token punctuation">.</span>_data_in_window <span class="token keyword">for</span> key<span class="token punctuation">,</span> _ <span class="token keyword">in</span> data<span class="token punctuation">.</span>items<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token comment"># Get smoothed values.</span> smoothed_data <span class="token operator">=</span> <span class="token builtin">dict</span><span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">for</span> key <span class="token keyword">in</span> keys<span class="token punctuation">:</span> factor <span class="token operator">=</span> <span class="token number">1.0</span> top_sum <span class="token operator">=</span> <span class="token number">0.0</span> bottom_sum <span class="token operator">=</span> <span class="token number">0.0</span> <span class="token keyword">for</span> data <span class="token keyword">in</span> self<span class="token punctuation">.</span>_data_in_window<span class="token punctuation">:</span> value <span class="token operator">=</span> data<span class="token punctuation">[</span>key<span class="token punctuation">]</span> <span class="token keyword">if</span> key <span class="token keyword">in</span> data <span class="token keyword">else</span> <span class="token number">0.0</span> top_sum <span class="token operator">+=</span> factor <span class="token operator">*</span> value bottom_sum <span class="token operator">+=</span> factor <span class="token comment"># Update factor.</span> factor <span class="token operator">*=</span> <span class="token punctuation">(</span><span class="token number">1.0</span> <span class="token operator">-</span> self<span class="token punctuation">.</span>_alpha<span class="token punctuation">)</span> smoothed_data<span class="token punctuation">[</span>key<span class="token punctuation">]</span> <span class="token operator">=</span> top_sum <span class="token operator">/</span> bottom_sum <span class="token keyword">return</span> smoothed_data |
Test
Ok, test với một video không có trong bộ dữ liệu hiện tại bằng cách tổng hợp tất cả các đoạn code ở trên. Code bên dưới tôi đã loại bỏ các phần repetition counter và visualizer để dễ nhìn. Các bạn có thể thêm vào để kết quả trả về trực quan nhất.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 | <span class="token comment"># Run classification on a video.</span> <span class="token keyword">import</span> os <span class="token keyword">import</span> tqdm <span class="token keyword">from</span> mediapipe<span class="token punctuation">.</span>python<span class="token punctuation">.</span>solutions <span class="token keyword">import</span> drawing_utils <span class="token keyword">as</span> mp_drawing <span class="token comment"># Open output video.</span> out_video <span class="token operator">=</span> cv2<span class="token punctuation">.</span>VideoWriter<span class="token punctuation">(</span>out_video_path<span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>VideoWriter_fourcc<span class="token punctuation">(</span><span class="token operator">*</span><span class="token string">'mp4v'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> video_fps<span class="token punctuation">,</span> <span class="token punctuation">(</span>video_width<span class="token punctuation">,</span> video_height<span class="token punctuation">)</span><span class="token punctuation">)</span> frame_idx <span class="token operator">=</span> <span class="token number">0</span> output_frame <span class="token operator">=</span> <span class="token boolean">None</span> <span class="token keyword">with</span> tqdm<span class="token punctuation">.</span>tqdm<span class="token punctuation">(</span>total<span class="token operator">=</span>video_n_frames<span class="token punctuation">,</span> position<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> leave<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span> <span class="token keyword">as</span> pbar<span class="token punctuation">:</span> <span class="token keyword">while</span> <span class="token boolean">True</span><span class="token punctuation">:</span> <span class="token comment"># Get next frame of the video.</span> success<span class="token punctuation">,</span> input_frame <span class="token operator">=</span> video_cap<span class="token punctuation">.</span>read<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">if</span> <span class="token keyword">not</span> success<span class="token punctuation">:</span> <span class="token keyword">break</span> <span class="token comment"># Run pose tracker.</span> input_frame <span class="token operator">=</span> cv2<span class="token punctuation">.</span>cvtColor<span class="token punctuation">(</span>input_frame<span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>COLOR_BGR2RGB<span class="token punctuation">)</span> result <span class="token operator">=</span> pose_tracker<span class="token punctuation">.</span>process<span class="token punctuation">(</span>image<span class="token operator">=</span>input_frame<span class="token punctuation">)</span> pose_landmarks <span class="token operator">=</span> result<span class="token punctuation">.</span>pose_landmarks <span class="token comment"># Draw pose prediction.</span> output_frame <span class="token operator">=</span> input_frame<span class="token punctuation">.</span>copy<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">if</span> pose_landmarks <span class="token keyword">is</span> <span class="token keyword">not</span> <span class="token boolean">None</span><span class="token punctuation">:</span> mp_drawing<span class="token punctuation">.</span>draw_landmarks<span class="token punctuation">(</span> image<span class="token operator">=</span>output_frame<span class="token punctuation">,</span> landmark_list<span class="token operator">=</span>pose_landmarks<span class="token punctuation">,</span> connections<span class="token operator">=</span>mp_pose<span class="token punctuation">.</span>POSE_CONNECTIONS<span class="token punctuation">)</span> <span class="token keyword">if</span> pose_landmarks <span class="token keyword">is</span> <span class="token keyword">not</span> <span class="token boolean">None</span><span class="token punctuation">:</span> <span class="token comment"># Get landmarks.</span> frame_height<span class="token punctuation">,</span> frame_width <span class="token operator">=</span> output_frame<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> output_frame<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span> pose_landmarks <span class="token operator">=</span> np<span class="token punctuation">.</span>array<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token punctuation">[</span>lmk<span class="token punctuation">.</span>x <span class="token operator">*</span> frame_width<span class="token punctuation">,</span> lmk<span class="token punctuation">.</span>y <span class="token operator">*</span> frame_height<span class="token punctuation">,</span> lmk<span class="token punctuation">.</span>z <span class="token operator">*</span> frame_width<span class="token punctuation">]</span> <span class="token keyword">for</span> lmk <span class="token keyword">in</span> pose_landmarks<span class="token punctuation">.</span>landmark<span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>np<span class="token punctuation">.</span>float32<span class="token punctuation">)</span> <span class="token keyword">assert</span> pose_landmarks<span class="token punctuation">.</span>shape <span class="token operator">==</span> <span class="token punctuation">(</span><span class="token number">33</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token string">'Unexpected landmarks shape: {}'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>pose_landmarks<span class="token punctuation">.</span>shape<span class="token punctuation">)</span> <span class="token comment"># Classify the pose on the current frame.</span> pose_classification <span class="token operator">=</span> pose_classifier<span class="token punctuation">(</span>pose_landmarks<span class="token punctuation">)</span> <span class="token comment"># Smooth classification using EMA.</span> pose_classification_filtered <span class="token operator">=</span> pose_classification_filter<span class="token punctuation">(</span>pose_classification<span class="token punctuation">)</span> <span class="token keyword">else</span><span class="token punctuation">:</span> <span class="token comment"># No pose => no classification on current frame.</span> pose_classification <span class="token operator">=</span> <span class="token boolean">None</span> <span class="token comment"># Still add empty classification to the filter to maintaing correct</span> <span class="token comment"># smoothing for future frames.</span> pose_classification_filtered <span class="token operator">=</span> pose_classification_filter<span class="token punctuation">(</span><span class="token builtin">dict</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span> pose_classification_filtered <span class="token operator">=</span> <span class="token boolean">None</span> <span class="token comment"># Draw classification plot and repetition counter.</span> output_frame <span class="token operator">=</span> pose_classification_visualizer<span class="token punctuation">(</span> frame<span class="token operator">=</span>output_frame<span class="token punctuation">,</span> pose_classification<span class="token operator">=</span>pose_classification<span class="token punctuation">,</span> pose_classification_filtered<span class="token operator">=</span>pose_classification_filtered<span class="token punctuation">,</span> repetitions_count<span class="token operator">=</span>repetitions_count<span class="token punctuation">)</span> <span class="token comment"># Save the output frame.</span> out_video<span class="token punctuation">.</span>write<span class="token punctuation">(</span>cv2<span class="token punctuation">.</span>cvtColor<span class="token punctuation">(</span>np<span class="token punctuation">.</span>array<span class="token punctuation">(</span>output_frame<span class="token punctuation">)</span><span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>COLOR_RGB2BGR<span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token comment"># Show intermediate frames of the video to track progress.</span> <span class="token keyword">if</span> frame_idx <span class="token operator">%</span> <span class="token number">50</span> <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span> show_image<span class="token punctuation">(</span>output_frame<span class="token punctuation">)</span> frame_idx <span class="token operator">+=</span> <span class="token number">1</span> pbar<span class="token punctuation">.</span>update<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># Close output video.</span> out_video<span class="token punctuation">.</span>release<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># Release MediaPipe resources.</span> pose_tracker<span class="token punctuation">.</span>close<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># Show the last frame of the video.</span> <span class="token keyword">if</span> output_frame <span class="token keyword">is</span> <span class="token keyword">not</span> <span class="token boolean">None</span><span class="token punctuation">:</span> show_image<span class="token punctuation">(</span>output_frame<span class="token punctuation">)</span> |
Lời kết
Thôi ngừng, lười rồi, không viết nữa. Cám ơn các bạn đã đọc đến đây
Link tham khảo
https://export.arxiv.org/pdf/2006.10204
https://export.arxiv.org/pdf/1603.06937
https://towardsdatascience.com/using-hourglass-networks-to-understand-human-poses-1e40e349fa15
https://colab.research.google.com/drive/19txHpN8exWhstO6WVkfmYYVC6uug_oVR#scrollTo=4lXymkneOjgZ
https://google.github.io/mediapipe/solutions/pose_classification.html
https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html
https://www.tohaitrieu.net/exponential-moving-average-ema/
…