As in the previous posts, Part 1 covered the theory, and this Part 2 walks through a demo of the algorithm. Let's dive in!
1. Building the MF class
The constructor
Input parameters:
- Y: the utility matrix with 3 columns. Each row holds one rating as (user_id, item_id, rating).
- n_factors: the number of latent factors shared by users and items. Default: n_factors = 2.
- X: the item latent-factor matrix, of shape (items_count, n_factors). It is randomly initialized if not given.
- W: the user latent-factor matrix, of shape (n_factors, users_count). It is randomly initialized if not given.
- lamda: the regularization weight in the loss function, used to prevent overfitting. Default: lamda = 0.1.
- learning_rate: the Gradient Descent learning rate, which controls how fast the model learns. Default: learning_rate = 2.
- n_epochs: the number of training epochs. Default: n_epochs = 50.
- top: the number of items recommended to each user. Default: top = 10.
- filename: the file where evaluation results are saved.
```python
import numpy as np

class MF(object):
    def __init__(self, Y, n_factors=2, X=None, W=None, lamda=0.1,
                 learning_rate=2, n_epochs=50, top=10, filename=None):
        # Optional log file for evaluation results, opened in append mode
        self.f = open(filename, 'a+') if filename else None
        self.Y = Y
        self.lamda = lamda
        self.n_factors = n_factors
        self.learning_rate = learning_rate
        self.n_epochs = n_epochs
        self.top = top
        # ids are used directly as array indices, so each count is the largest id + 1
        self.users_count = int(np.max(self.Y[:, 0])) + 1
        self.items_count = int(np.max(self.Y[:, 1])) + 1
        self.ratings_count = Y.shape[0]
        # X: item factors (items_count x n_factors), W: user factors (n_factors x users_count).
        # "is None" is required: "== None" on a NumPy array is element-wise and breaks the test.
        self.X = np.random.randn(self.items_count, n_factors) if X is None else X
        self.W = np.random.randn(n_factors, self.users_count) if W is None else W
        self.Ybar = self.Y.copy()
        # Item biases and user biases
        self.bi = np.random.randn(self.items_count)
        self.bu = np.random.randn(self.users_count)
        self.n_ratings = self.Y.shape[0]
```
Try changing these hyperparameters to see how they affect the algorithm's evaluation results.
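The sketch below shows the shape conventions used by the constructor. It is a standalone example that assumes integer ids usable as array indices, which is what the "largest id + 1" counting implies. The toy ratings are made up for illustration.

```python
import numpy as np

# Toy utility data: one row per rating as (user_id, item_id, rating).
# Values are made up for illustration only.
Y = np.array([
    [0, 0, 5],
    [0, 2, 3],
    [1, 1, 4],
    [2, 0, 1],
    [2, 2, 2],
], dtype=float)

n_factors = 2
users_count = int(np.max(Y[:, 0])) + 1   # 3 users
items_count = int(np.max(Y[:, 1])) + 1   # 3 items

rng = np.random.default_rng(0)
X = rng.standard_normal((items_count, n_factors))   # item factors, one row per item
W = rng.standard_normal((n_factors, users_count))   # user factors, one column per user

# X @ W is the full predicted matrix: one row per item, one column per user.
print(X.shape, W.shape, (X @ W).shape)  # (3, 2) (2, 3) (3, 3)
```

Because X is items × factors and W is factors × users, the entry (i, u) of X @ W is the latent part of the predicted rating of user u for item i. The biases are added on top of it.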
The get_user_rated_item() and get_item_rated_by_user() functions
The function get_user_rated_item(i) returns the users who rated item i, together with their ratings.
```python
    def get_user_rated_item(self, i):
        # Row indices of Ybar whose item_id column equals i
        ids = np.where(i == self.Ybar[:, 1])[0].astype(int)
        users = self.Ybar[ids, 0].astype(int)
        ratings = self.Ybar[ids, 2]
        return (users, ratings)
```
The function get_item_rated_by_user(u) returns the items rated by user u, together with their ratings.
```python
    def get_item_rated_by_user(self, u):
        # Row indices of Ybar whose user_id column equals u
        ids = np.where(u == self.Ybar[:, 0])[0].astype(int)
        items = self.Ybar[ids, 1].astype(int)
        ratings = self.Ybar[ids, 2]
        return (items, ratings)
```
These two functions are used to optimize the two matrices X and W.
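Here is a standalone sketch of the same np.where lookup on a toy triplet array. The function names users_who_rated and items_rated_by are hypothetical stand-ins for the two class methods, and the data is made up.

```python
import numpy as np

# Toy (user_id, item_id, rating) triplets, made up for illustration.
Ybar = np.array([
    [0, 0, 5],
    [0, 2, 3],
    [1, 1, 4],
    [2, 0, 1],
    [2, 2, 2],
], dtype=float)

def users_who_rated(Ybar, i):
    # Rows whose item_id column equals i
    ids = np.where(Ybar[:, 1] == i)[0]
    return Ybar[ids, 0].astype(int), Ybar[ids, 2]

def items_rated_by(Ybar, u):
    # Rows whose user_id column equals u
    ids = np.where(Ybar[:, 0] == u)[0]
    return Ybar[ids, 1].astype(int), Ybar[ids, 2]

users, ratings = users_who_rated(Ybar, 0)
print(users, ratings)   # [0 2] [5. 1.]
items, ratings = items_rated_by(Ybar, 2)
print(items, ratings)   # [0 2] [1. 2.]
```

Each call scans the whole rating array, so it costs O(number of ratings). That is simple but slow on large datasets, where precomputing the groups once would be faster.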
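The updateX() and updateW() functions
These two functions optimize X and W. For each item (in updateX) or each user (in updateW), the number of inner iterations is fixed at 50. Each step uses an AdaGrad-style adaptive learning rate: the gradient is divided by the square root of the accumulated squared gradients.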
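The per-item step can be read directly off updateX(). Write s for n_ratings, λ for lamda, η for learning_rate, b_m for bi[m] and d_n for bu[n]. These symbols are chosen here for illustration and may differ from the notation of Part 1. The gradients are those of the objective (1/(2s)) Σ e² + (λ/2)‖x_m‖², where U(m) is the set of users who rated item m:

$$
e_{mn} = \mathbf{x}_m \mathbf{w}_n + b_m + d_n - y_{mn}, \qquad n \in \mathcal{U}(m)
$$

$$
\mathbf{g}_x = \frac{1}{s}\sum_{n \in \mathcal{U}(m)} e_{mn}\,\mathbf{w}_n^T + \lambda\,\mathbf{x}_m, \qquad
g_b = \frac{1}{s}\sum_{n \in \mathcal{U}(m)} e_{mn}
$$

$$
\mathbf{G}_x \leftarrow \mathbf{G}_x + \mathbf{g}_x^{2}, \quad
G_b \leftarrow G_b + g_b^{2}, \quad
\mathbf{x}_m \leftarrow \mathbf{x}_m - \frac{\eta\,\mathbf{g}_x}{\sqrt{\mathbf{G}_x}}, \quad
b_m \leftarrow b_m - \frac{\eta\,g_b}{\sqrt{G_b}}
$$

Squares and square roots are element-wise, and G_x and G_b start at 1e-8. updateW() is symmetric, with the roles of items and users swapped. The bias b_m is not regularized.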
```python
    def updateX(self):
        for m in range(self.items_count):
            # Users who rated item m, their factor columns and their biases
            users, ratings = self.get_user_rated_item(m)
            Wm = self.W[:, users]
            b = self.bu[users]
            # AdaGrad accumulators (a small constant avoids division by zero)
            sum_grad_xm = np.full(shape=self.X[m].shape, fill_value=1e-8)
            sum_grad_bm = 1e-8
            for i in range(50):
                xm = self.X[m]
                error = xm.dot(Wm) + self.bi[m] + b - ratings
                grad_xm = error.dot(Wm.T)/self.n_ratings + self.lamda*xm
                grad_bm = np.sum(error)/self.n_ratings
                sum_grad_xm += grad_xm**2
                sum_grad_bm += grad_bm**2
                # gradient descent with an AdaGrad-style adaptive step
                self.X[m] -= self.learning_rate*grad_xm.reshape(-1)/np.sqrt(sum_grad_xm)
                self.bi[m] -= self.learning_rate*grad_bm/np.sqrt(sum_grad_bm)

    def updateW(self):
        for n in range(self.users_count):
            # Items rated by user n, their factor rows and their biases
            items, ratings = self.get_item_rated_by_user(n)
            Xn = self.X[items, :]
            b = self.bi[items]
            sum_grad_wn = np.full(shape=self.W[:, n].shape, fill_value=1e-8)
            sum_grad_bn = 1e-8
            for i in range(50):
                wn = self.W[:, n]
                error = Xn.dot(wn) + self.bu[n] + b - ratings
                grad_wn = Xn.T.dot(error)/self.n_ratings + self.lamda*wn
                grad_bn = np.sum(error)/self.n_ratings
                sum_grad_wn += grad_wn**2
                sum_grad_bn += grad_bn**2
                # gradient descent with an AdaGrad-style adaptive step
                self.W[:, n] -= self.learning_rate*grad_wn.reshape(-1)/np.sqrt(sum_grad_wn)
                self.bu[n] -= self.learning_rate*grad_bn/np.sqrt(sum_grad_bn)
```
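The following sketch runs the inner loop of updateX() in isolation for a single item and checks that the objective goes down. The helpers adagrad_item_update and item_loss are hypothetical, and the toy numbers are random or made up. Here n_ratings is set to the number of toy ratings, whereas in the class it is the total number of ratings.

```python
import numpy as np

def adagrad_item_update(xm, bm, Wm, bu, ratings, lamda, lr, n_ratings, n_iters=50):
    # Hypothetical standalone version of updateX's inner loop for one item m.
    # xm: (K,) item factors, bm: item bias, Wm: (K, k) factors of the k users
    # who rated m, bu: (k,) their biases, ratings: (k,) their ratings.
    xm = xm.copy()
    G_x = np.full(xm.shape, 1e-8)   # accumulated squared gradients of xm
    G_b = 1e-8                      # accumulated squared gradient of the bias
    for _ in range(n_iters):
        error = xm.dot(Wm) + bm + bu - ratings          # (k,) prediction errors
        grad_x = error.dot(Wm.T) / n_ratings + lamda * xm
        grad_b = np.sum(error) / n_ratings
        G_x += grad_x ** 2
        G_b += grad_b ** 2
        # AdaGrad step: per-coordinate step size lr / sqrt(G)
        xm -= lr * grad_x / np.sqrt(G_x)
        bm -= lr * grad_b / np.sqrt(G_b)
    return xm, bm

def item_loss(xm, bm, Wm, bu, ratings, lamda, n_ratings):
    # Objective whose gradient the loop above follows
    error = xm.dot(Wm) + bm + bu - ratings
    return 0.5 * np.sum(error ** 2) / n_ratings + 0.5 * lamda * np.sum(xm ** 2)

rng = np.random.default_rng(1)
K, k = 2, 4
Wm = rng.standard_normal((K, k))
bu = rng.standard_normal(k)
ratings = np.array([5.0, 3.0, 4.0, 1.0])
xm0, bm0 = rng.standard_normal(K), 0.0

before = item_loss(xm0, bm0, Wm, bu, ratings, lamda=0.1, n_ratings=k)
xm1, bm1 = adagrad_item_update(xm0, bm0, Wm, bu, ratings, lamda=0.1, lr=0.5, n_ratings=k)
after = item_loss(xm1, bm1, Wm, bu, ratings, lamda=0.1, n_ratings=k)
print(before, after)
```

Since every step is at most lr per coordinate and shrinks as the squared gradients accumulate, a fairly large learning rate (such as the default of 2) is less likely to diverge than with plain gradient descent.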
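Main algorithm
The main methods work as follows:
- fit() alternates updateW() and updateX() for n_epochs epochs and reports the RMSE on the test set every x epochs.
- pred() predicts user u's rating for item i, clipped to the range [0, 5].
- recommend() suggests items that user u has not rated yet.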
```python
    def fit(self, x, data_size, Data_test, test_size=0):
        for i in range(self.n_epochs):
            # Alternate: update users with items fixed, then items with users fixed
            self.updateW()
            self.updateX()
            # Report RMSE on the test set every x epochs
            if (i + 1) % x == 0:
                self.RMSE(Data_test, data_size=data_size, test_size=test_size, p=i+1)
                # self.evaluate(data_size, Data_test, test_size = 0)

    def pred(self, u, i):
        u = int(u)
        i = int(i)
        pred = self.X[i, :].dot(self.W[:, u]) + self.bi[i] + self.bu[u]
        # Clip the prediction to the rating scale [0, 5]
        return max(0, min(5, pred))

    def recommend(self, u):
        # Items that user u has already rated are excluded
        ids = np.where(self.Y[:, 0] == u)[0].astype(int)
        items_rated_by_user = set(self.Y[ids, 1].astype(int).tolist())
        # Predicted score x_i . w_u + b_i + b_u for every item i
        scores = self.X.dot(self.W[:, u]) + self.bi + self.bu[u]
        # Keep unrated items with a positive predicted score
        candidates = [i for i in range(self.items_count)
                      if i not in items_rated_by_user and scores[i] > 0]
        # Highest predicted score first, at most self.top items
        candidates.sort(key=lambda i: scores[i], reverse=True)
        return np.array(candidates[:self.top], dtype=int)
```
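Here is a standalone sketch of the prediction rule and a top-N selection. It masks already-rated items with -inf and sorts by score. The helpers predict and recommend_top are hypothetical, and the small factor matrices are made up so that the ranking is easy to check by hand.

```python
import numpy as np

def predict(X, W, bi, bu, u, i, lo=0.0, hi=5.0):
    # Predicted rating x_i . w_u + b_i + b_u, clipped to [lo, hi]
    return float(np.clip(X[i].dot(W[:, u]) + bi[i] + bu[u], lo, hi))

def recommend_top(X, W, bi, bu, u, rated_items, top):
    # Score every item for user u, hide the rated ones, return the best `top`
    scores = X.dot(W[:, u]) + bi + bu[u]
    scores[np.array(sorted(rated_items), dtype=int)] = -np.inf
    order = np.argsort(-scores)                 # highest score first
    order = order[np.isfinite(scores[order])]   # drop the masked (rated) items
    return order[:top].tolist()

# 4 items x 2 factors and 2 factors x 2 users, made-up values
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.25]])
W = np.array([[2.0, 0.0], [1.0, 3.0]])
bi = np.array([0.0, 0.5, 0.0, 0.0])
bu = np.array([0.5, 0.0])

print(predict(X, W, bi, bu, u=0, i=2))               # 3.5
print(recommend_top(X, W, bi, bu, 0, {2}, top=2))    # [0, 1]
```

For user 0, the scores are 2.5, 2.0, 3.5 and 1.75. Item 2 has the highest score, but the user has already rated it, so the top two recommendations are items 0 and 1.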
2. Evaluation
As with the two previous methods, two metrics are used here: RMSE and PR (Precision/Recall).
```python
import numpy as np

# The two methods below belong to class MF (continued).
def RMSE(self, Data_test, test_size=0, data_size='100K', p=10):
    # Root mean squared error of the predicted ratings on the test set
    n_tests = Data_test.shape[0]
    SE = 0  # sum of squared errors
    for n in range(n_tests):
        pred = self.pred(Data_test[n, 0], Data_test[n, 1])
        SE += (pred - Data_test[n, 2])**2
    RMSE = np.sqrt(SE/n_tests)
    # Printed fields: data_size::1::n_factors::n_epochs::lamda::learning_rate::RMSE
    print('%s::1::%d::%d::%r::%r::%r' % (str(data_size), self.n_factors, self.n_epochs,
                                         self.lamda, self.lr, RMSE))
    if getattr(self, 'f', None):  # log to file only when a filename was given
        self.f.write('%s::1::%d::%d::%d::%r::%r::%r\r\n' % (str(data_size), self.n_factors,
                     self.n_epochs, p, self.lamda, self.lr, RMSE))
    return RMSE

def evaluate(self, data_size, Data_test, test_size=0):
    # Precision and recall of the recommendations over all users
    sum_p = 0  # total number of recommended items found in the test set
    self.Pu = np.zeros((self.users_count,))  # hits per user
    for u in range(self.users_count):
        recommended_items = self.recommend(u)
        ids = np.where(Data_test[:, 0] == u)[0]
        rated_items = Data_test[ids, 1]  # items user u rated in the test set
        for i in recommended_items:
            if i in rated_items:
                self.Pu[u] += 1
        sum_p += self.Pu[u]
    # precision = hits / (number of users * items recommended per user)
    p = sum_p/(self.users_count * self.top)
    # recall = hits / number of test ratings
    r = sum_p/(Data_test.shape[0])
    if getattr(self, 'f', None):
        self.f.write('%s::1::%d::%d::%d::%r::%r::%r\r\n' % (str(data_size), self.top,
                     self.n_factors, self.n_epochs, test_size, p, r))
    return p, r
```
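To make the two metrics easier to check in isolation, here is a standalone, vectorized sketch that computes the same quantities directly from arrays, without the MF class. The function names `rmse` and `precision_recall`, and passing recommendations as a dict of user id to item list, are assumptions made for illustration; they are not part of the original class.

```python
import numpy as np


def rmse(y_true, y_pred):
    """Root mean squared error between two equal-length rating arrays."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))


def precision_recall(recommendations, test_data, top):
    """Hypothetical standalone version of the precision/recall in MF.evaluate.

    recommendations: dict mapping user id -> list of recommended item ids
    test_data: array with columns (user_id, item_id, rating)
    top: number of items recommended per user
    """
    test_data = np.asarray(test_data)
    hits = 0
    for u, items in recommendations.items():
        # items user u actually rated in the test set
        rated = set(test_data[test_data[:, 0] == u, 1].tolist())
        hits += sum(1 for i in items if i in rated)
    precision = hits / (len(recommendations) * top)  # hits / all recommendations
    recall = hits / test_data.shape[0]                # hits / all test ratings
    return precision, recall


# Tiny example: 2 users, 3 test ratings, 2 recommendations per user
test_ratings = np.array([[0, 1, 5], [0, 2, 3], [1, 0, 4]])
recs = {0: [1, 3], 1: [0, 2]}
p, r = precision_recall(recs, test_ratings, top=2)
```

Each user has one hit here, so precision is 2 / (2 × 2) and recall is 2 / 3. The loop over users mirrors `evaluate`; only the per-user membership test is done with a set instead of scanning the array.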
3. Demo with the MovieLens dataset
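The demo uses `rate_train` and `rate_test`, arrays whose columns are (user_id, item_id, rating). Because `evaluate` iterates users with `range(self.users_count)`, the ids need to be 0-based. A minimal loading sketch, assuming the tab-separated ml-100k layout (`user_id item_id rating timestamp`, 1-based ids); `load_ratings` is a hypothetical helper, not part of the original code:

```python
import io
import numpy as np


def load_ratings(source):
    """Load a MovieLens-100K-style ratings file into an (n, 3) int array.

    Assumes whitespace/tab-separated lines "user_id item_id rating timestamp"
    with 1-based ids; the timestamp column is dropped.
    """
    data = np.loadtxt(source, dtype=int, usecols=(0, 1, 2), ndmin=2)
    data[:, :2] -= 1  # convert 1-based user/item ids to 0-based indices
    return data


# Example on an in-memory sample instead of a real file
sample = io.StringIO("1\t242\t3\t881250949\n2\t51\t2\t880606923\n")
sample_ratings = load_ratings(sample)
```

With the archive extracted locally, something like `rate_train = load_ratings('ml-100k/ua.base')` and `rate_test = load_ratings('ml-100k/ua.test')` would produce the two arrays (the paths are assumptions about where the files live).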
```python
# The constructor parameter is named learning_rate (see __init__)
rs = MF(rate_train, n_factors=2, lamda=0.01, learning_rate=0.1, n_epochs=20,
        filename='RMSE_100K_MF.dat')
rs.fit(10, "100K", rate_test)
rs.f.close()
```
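The results I obtained (each line follows the print format in RMSE: data_size::1::n_factors::n_epochs::lamda::learning_rate::RMSE):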
```
100K::1::2::20::0.01::0.1::0.9634817342439627
100K::1::2::20::0.01::0.1::0.9634984986336697
```
Vary the parameters (n_factors, lamda, learning_rate) to find the best combination:
```python
# Grid search: i = n_factors, j = lamda, k = learning_rate
for i in [50, 60]:
    for j in [0.01, 0.1, 0.5, 1]:
        for k in [0.1, 0.5, 0.75, 1, 2]:
            # hypothetical log file name, so that rs.f exists before close()
            rs = MF(rate_train, n_factors=i, lamda=j, learning_rate=k, n_epochs=10,
                    filename='RMSE_1M_MF.dat')
            rs.fit(10, data_size="1M", Data_test=rate_test, test_size=0.1)
            rs.f.close()
```
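Instead of reading the results back from the log file, the search can keep every RMSE in memory and return the minimum directly. A minimal sketch using `itertools.product`; `score` is a hypothetical callable that trains a model with the given n_factors, lamda and learning_rate and returns its test RMSE (for instance a small wrapper around MF and its RMSE method):

```python
from itertools import product


def grid_search(score, factors, lamdas, learning_rates):
    """Return (best_rmse, best_params) over all hyperparameter combinations.

    `score(n_factors, lamda, learning_rate)` is a hypothetical callable
    returning the RMSE on the test set for that combination.
    """
    best = None
    for n_factors, lamda, lr in product(factors, lamdas, learning_rates):
        rmse = score(n_factors, lamda, lr)
        if best is None or rmse < best[0]:  # keep the lowest RMSE seen so far
            best = (rmse, {'n_factors': n_factors, 'lamda': lamda,
                           'learning_rate': lr})
    return best


# Toy scoring function standing in for a real training run
def toy_score(n_factors, lamda, lr):
    return abs(lamda - 0.1) + abs(lr - 0.5) + n_factors / 1000


best = grid_search(toy_score, [50, 60], [0.01, 0.1, 0.5, 1], [0.1, 0.5, 0.75, 1, 2])
```

`product` enumerates the same 2 × 4 × 5 combinations as the three nested loops, so swapping the toy function for a real training run gives the best setting without parsing the `.dat` file.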
Source code and references:
https://machinelearningcoban.com/2017/05/31/matrixfactorization/