Tản mạn một chút về Pose Classification

Tram Ho

Lời mở đầu

Chào các bạn, hôm nay tôi sẽ lảm nhảm một chút về Pose Classification. Như các bạn đã biết hiện giờ bài toán về chuyển động của cơ thể hay phát hiện các điểm trên cơ thể người là một bài toán quan trọng trong ngành ML này, do tính ứng dụng của bài toán này khá đa dạng: như phát hiện động tác trong siêu thị, mô phỏng các bài tập vật lý trị liệu trong y tế, hỗ trợ PT trong các bài tập gym, bla…bla…

Tôi viết bài này cũng để tổng kết lại những gì mình đã tìm hiểu, nếu có thiếu sót mong m.n bỏ qua cho

Bài viết này xoay quanh thư viện Mediapipe của Google, nếu còn thời gian tôi sẽ đề cập tới các thư viện khác sau.

Pose Detection and Pose Tracking

Trước khi đi vào phần Pose Classification, chúng ta nhìn qua một chút bên google làm thế nào detect các điểm trên cơ thể người.

Giải pháp của họ dựa trên một paper mà họ đề xuất, keyword: BlazePose (https://export.arxiv.org/pdf/2006.10204). Dựa trên giải pháp này, họ sẽ bóc tách ra 33 điểm tương ứng 33 bộ phận của cơ thể người hoặc 25 điểm tương ứng phần trên của cơ thể trong không gian 3 chiều (x, y, z) từ một RGB video.

Giải thích sơ qua về BlazePose, về cơ bản BlazePose là bản cải tiến của mạng Stacked Hourglass (https://export.arxiv.org/pdf/1603.06937)

Stacked Hourglass Network

Cấu trúc mạng Stacked Hourglass:

Ý tưởng của mạng này thay vì ta có một bộ encoder-decoder siêu to thì mỗi Hourglass (đồng hồ cát) có nhiệm vụ trả về một heat-map dự đoán các phần trên cơ thể. Do đây là mô hình chồng lên nhau nên thằng Hourglass sau có thể học hỏi kết quả của thằng trước.

Làm thế nào để phát hiện chuyển động người qua heat-map (bản đồ nhiệt) ? Khác với dữ liệu về khuôn mặt con người (72 landmark keypoint, …) thì dữ liệu về chuyển động của người đa dạng hơn nên rất khó tìm các điểm trên cơ thể người dựa vào tọa độ. Các nhà khoa học đã nghĩ ra phương thức sử dụng heat-map để đại diện cho một vùng trên ảnh. Heat-map giúp giữ lại thông tin vùng đó và việc của chúng ta là tìm ra cao điểm (điểm sáng nhất) trong vùng. Lấy ví dụ với một ảnh 256×256 thì heat-map có thể có kích thước 64×64. Bạn xem hình cho dễ hiểu:

Theo như tác giả paper nói qua thì họ sẽ tính loss ở mỗi lần predict, điều này giúp họ giám sát không chỉ kết quả trả về cuối cùng mà còn giám sát đầu ra của mỗi Hourglass. Lấy ví dụ, trong khi chuyển động cơ thể sẽ có các bộ phận bị che đi trước camera, rất khó để phân biệt cánh tay hướng bên trái hay bên phải. Với việc sử dụng kết quả dự đoán vị trí của thằng trước làm input, mô hình không chỉ chú ý các vị trí này mà trong lúc đó còn dự đoán thêm vị trí mới.

Hourglass Module

Giờ chúng ta nhìn qua cấu trúc của một Hourglass

Thông qua hình vẽ chắc các bạn cũng thấy đây là kiến trúc encode-decoder có nhiệm vụ downsample các features trước rồi upsample để hồi phục thông tin và chuyển thành heat-map. Mỗi một layer của encoder kết nối với một layer của decoder tương ứng. Còn mỗi layer xây dựng dựa trên kiến trúc residual block và bottleneck của resnet, bạn nào chưa biết về residual block thì vào link này đọc nhé (https://towardsdatascience.com/residual-blocks-building-blocks-of-resnet-fd90ca15d6ec).

Bên trái: Residual Layer. Bên phải: Bottleneck Block

Bottleneck giúp cho việc tính toán dễ dàng hơn, tương ứng với việc tiết kiệm bộ nhớ.

Giờ chúng ta thử phóng to một hộp trong hình trên

Mỗi hộp trong hình là lớp bottleneck mà tôi đã nhắc đến. Sau mỗi bottleneck sẽ có một lớp pooling để loại bỏ các feature không cần thiết.

Tuy nhiên, lớp đầu tiên có một chút khác biệt, lớp này dùng Convolution 7×7 chứ không phải 3×3.

Như hình trên, ở lớp đầu tiên, input đầu tiên đi qua tổ hợp giữa Convolution 7×7, BatchNorm, Activation Relu, tiếp tục đi qua lớp bottleneck. Ở đây output từ lớp bottleneck sẽ di qua 2 nhánh song song. Một cái đi lớp MaxPooling và thực hiện việc trích xuất đặc trưng, cái còn lại sẽ kết nối với lớp tương ứng ở phần decoder.

Các hộp 2, 3, 4 khác có cấu tạo giống nhau và khác hộp đầu tiên.

Mục đích cuối cùng của việc trích xuất đặc trưng là tạo ra feature maps, chứa thông tin đặc trưng của ảnh nhiều nhất nhưng thông tin không gian thấp nhất. Phần này chính là 3 cái hộp nhỏ nằm giữa encode và decoder.

Input sau khi đi qua 4 lớp của encoder và một lớp bottom trả về feature maps đã sẵn sàng đi qua decoder.

Như tôi đã nhắc tới trên kia thì nhánh còn lại sẽ đi qua lớp bottleneck và thực hiện việc cộng element-wise với output của lớp upsample của nhánh chính. Việc này lặp đi lặp lại khoảng 4 lần cho tới khi kết thúc.

Ở lớp cuối cùng, chúng ta có thể quan sát độ chính xác của mỗi dự đoán (prediction of Hourglass module). Cái này còn gọi là immediate supervision, bạn tính toán loss ở cuối mỗi Hourglass stage thay vì tính loss của cả mô hình.

Output của một Hourglass module đi qua Convolution 1×1, sau đó chia làm 2 nhánh song song. Một dùng để dự đoán và một trả về kết quả sẽ làm đầu vào cho Hourglass module tiếp theo. Cuối cùng, chúng ta thực hiện việc cộng từng phần tử (element-wise addition) giữa đầu vào của network (heatmap) và cả 2 đầu ra của Hourglass module. P/S: kết quả dự đoán đi qua Convolution 1×1 để cho đúng shape thì mới cộng từng phần tử được.

Cuối cùng để xây dựng Stacked Hourglass Network, chúng ta cần thực hiện lặp đi lặp lại các Hourglass module này.

BlazePose

Ok quay lại thuật toán BlazePose của Google, các bạn có thể tìm hiểu bằng cách đọc bài “Tìm hiểu về BlazePose” của tác giả Phạm Văn Toàn
Không phải tôi lười đâu các bạn, chỉ là có người viết trước rồi thôi.

Pose classification

Ok, coi như các bạn đã hiểu BlazePose rồi. Sau khi dùng Pose Landmark Model (BlazePose GHUM 3D) để detect chuyển động trên một ảnh sẽ trả về 33 điểm trên cơ thể như hình dưới đây:

Python Solution API

Rất may Google cung cấp giải pháp Python API mà bạn chỉ cần import thư viện mediapipe, code vài dòng là chạy được. Mediapipe cung cấp giải pháp detect pose trên ảnh tĩnh và trên video. Dưới đây là code dành cho ảnh tĩnh

Import thư viện dùng để xử lý ảnh opencv và thư viện mediapipe. Sau đó đặt 2 biến để dùng các hàm của mediapipe

Với ảnh tĩnh, các bạn phải chỉnh lại tham số cho class Pose như sau

Trong thư viện của mediapipe có sẵn model nên bạn không cần train lại.

Nếu bạn muốn in giá trị detect ra thì hãy nhớ mediapose trả về tọa độ dưới dạng tỉ lệ so với ảnh nên bạn cần phải nhân với chiều cao và chiều rộng để vẽ đúng tọa độ trên ảnh.

Nếu bạn muốn vẽ point hoặc line lên ảnh thì dùng hàm draw_landmarks()

Ok đối với video thì dùng cv2 (opencv) để capture video, tạo vòng lặp while True để xử lý từng frame một, ai làm việc nhiều với opencv chắc là quen rồi.

Về cơ bản cũng không khác gì nhiều với việc xử lý ảnh tĩnh, chỉ có cái khi gọi instance của mediapipe pose thì không cần param static_image_mode. Ngược lại có thêm param min_tracking_confidence dùng để tracking người từng frame, giá trị trong khoảng từ 0 đến 1, càng cao thì càng bắt đúng nhưng tăng độ trễ (tốn thời gian xử lý hơn).

Dựa trên BlazaPose, mediapipe trả về một điểm có tọa độ x, y, z. x, y thì như tôi nói ở trên được trả về dưới dạng tỉ lệ của chiều rộng và chiều cao còn z là độ sâu của điểm trọng tâm cơ thể. Giá trị càng nhỏ càng gần camera, z được trả về dưới dạng tỉ lệ với chiều rộng như x. z chỉ dự đoán được trong cơ chế full-body (33 điểm) còn cơ chế upper-body (25 điểm) thì không hỗ trợ.

Link colab của Pose Classification: https://colab.research.google.com/drive/19txHpN8exWhstO6WVkfmYYVC6uug_oVR

Prepare dataset

Ok quay lại về Pose Classification, đơn giản là họ dùng KNN (K-nearest neightbors) để group lại các động tác giống nhau. Về phần KNN là gì mời bạn đọc (https://machinelearningcoban.com/2017/01/08/knn/).

Đầu tiên cần chuẩn bị dữ liệu huấn luyện đã. Giả dụ trong động tác chống đẩy có 2 trạng thái: lên và xuống như hình dưới. Như vậy tôi cần chuẩn bị ảnh của cả 2 trạng thái này.

Google hỗ trợ code để gen dữ liệu, việc còn lại của các bạn là hiểu nó.

Bạn cần tạo 1 folder có cấu trúc như sau:

Ảnh ở đây là ảnh của trạng thái chống đẩy lên và trạng thái chống đẩy xuống.

Gọi instance bootstrap_helper từ class BootstrapHelper(). Class này có nhiệm vụ detect 33 điểm trên cơ thể, vẽ các điểm và nối thành đường lên ảnh, xuất ra file csv.

Kiểm tra trạng thái các folder này

Xóa đi các động tác bị detect lỗi bằng cách so ảnh với file csv

Pose Embedding

Embed 33 điểm trên cơ thể bằng cách tính khoảng cách giữa các điểm như: khoảng cách vai trái – hông trái, vai phải – hông phải, đầu gối trái – gót chân trái, đầu gối phải – gót chân phải, …

Pose Classifier

Gọi instance của class PoseClassifier(), có tác dụng phân loại class bằng cách so sánh độ tương đồng giữa chiều dài các điểm trên cơ thể người của ảnh và database.

Pose Smoothing

Khi detect 33 điểm trên cơ thể, các điểm sẽ dịch chuyển rất loạn ở từng frame dù cho nhìn bằng mắt thường không thấy có gì khác biệt, vì vậy tôi cần làm mịn dữ liệu dự đoán ở các frame bằng thuật toán EMA (Exponential moving average). Lý thuyết các bạn có thể tham khảo ở đây: (https://www.tohaitrieu.net/exponential-moving-average-ema/).

Test

Ok, test với một video không có trong bộ dữ liệu hiện tại bằng cách tổng hợp tất cả các đoạn code ở trên. Code bên dưới tôi đã loại bỏ các phần repetition counter và visualizer để dễ nhìn. Các bạn có thể thêm vào để kết quả trả về trực quan nhất.

Lời kết

Thôi ngừng, lười rồi, không viết nữa. Cám ơn các bạn đã đọc đến đây

Link tham khảo

https://export.arxiv.org/pdf/2006.10204

https://export.arxiv.org/pdf/1603.06937

https://towardsdatascience.com/using-hourglass-networks-to-understand-human-poses-1e40e349fa15

https://colab.research.google.com/drive/19txHpN8exWhstO6WVkfmYYVC6uug_oVR#scrollTo=4lXymkneOjgZ

https://google.github.io/mediapipe/solutions/pose_classification.html

https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html

https://www.tohaitrieu.net/exponential-moving-average-ema/

Chia sẻ bài viết ngay

Nguồn bài viết : Viblo