Giới thiệu
Siêu tham số là gì ?
Đối với những bạn đã xây dựng các mô hình deep learning thì chắc hẳn không còn xa lạ gì với khái niệm siêu tham số này nữa.Trong học máy và học sâu, siêu tham số là những tham số mà giá trị của nó có thể điều khiển quá trình huấn luyện của mô hình. Nó không cố định, thay đổi theo từng bài toán, từng quá trình huấn luyện và có thể thay đổi trong quá trình huấn luyện đó. Đặc biệt, nó đóng vai trò quan trọng trong việc quyết định hiệu suất của mô hình. Ví dụ về siêu tham số như là số lượng các units và các layers trong neural network, hay trong network trainer có các siêu tham số như là batchsize, learning rate schedule, optimizer, momentum, adam alpha,…
Optuna là gì ?
Optuna là 1 framwork hỗ trợ việc tự động điều chỉnh tham số mô hình để mô hình có thể đạt được hiệu năng tốt nhất ứng.
Khó khăn gặp phải khi phải điều chỉnh tham số thủ công
Đặt vấn đề : các bạn cần tạo một mô hình MLP, làm sao để chọn số hidden layers, số units trong 1 layers một cách phù hợp và chính xác ?
Việc chọn siêu tham số khi thực hiện training mô hình dường như đều dựa vào cảm tính và kinh nghiệm của các bạn. Và điều nay không phải lúc nào cũng có thể làm cho mô hình đạt performance tốt nhất . Còn đối với những bạn chưa có kinh nghiệm lựa chọn siêu tham số thì đây thực sự là một thử thách.
Cách sử dụng thư viện Optuna trong Pytorch
Install
1 2 | pip install optuna |
Định nghĩa objective function
1 2 3 4 | <span class="token keyword">def</span> <span class="token function">objective</span> <span class="token punctuation">(</span> trial <span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">return</span> |
Trial object definition
Trial là một đối tượng thể hiện của một class được thực thi trong Optuna.Nó được sử dụng để định nghĩa các siêu tham số được tối ưu.Giá trị được lấy trong phạm vi được bạn định nghĩa, thông tin của các tham số được tìm kiếm trong quá khứ vẫn được giữ lại và giá trị mới sẽ dựa trên những thông tin đó.
Trial được định nghĩa như sau:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | <span class="token comment"># Tham số thực hiện chọn loại </span> param1 <span class="token operator">=</span> Trial<span class="token punctuation">.</span>Suggest_categorical<span class="token punctuation">(</span> Name <span class="token punctuation">,</span> Choices <span class="token punctuation">)</span> <span class="token comment"># Định nghĩa tham số nguyên cho trial</span> param2 <span class="token operator">=</span> trial<span class="token punctuation">.</span>Suggest_int<span class="token punctuation">(</span> name <span class="token punctuation">,</span> low <span class="token punctuation">,</span> high <span class="token punctuation">)</span> <span class="token comment"># Biểu diễn tham số có giá trị liên tục</span> param3 <span class="token operator">=</span> trial<span class="token punctuation">.</span>Suggest_uniform<span class="token punctuation">(</span> name <span class="token punctuation">,</span> low <span class="token punctuation">,</span> high <span class="token punctuation">)</span> <span class="token comment"># Biểu diễn tham số có giá trị rời rạc</span> param4 <span class="token operator">=</span> trial<span class="token punctuation">.</span>Suggest_discrete_uniform<span class="token punctuation">(</span> name <span class="token punctuation">,</span> low <span class="token punctuation">,</span> high <span class="token punctuation">,</span> q <span class="token punctuation">)</span> <span class="token comment"># Biểu diễn các tham số logarithm</span> param5 <span class="token operator">=</span> trial<span class="token punctuation">.</span>Suggest_loguniform<span class="token punctuation">(</span> name <span class="token punctuation">,</span> low <span class="token punctuation">,</span> high <span class="token punctuation">)</span> |
name là kiểu string và giá trị là tên tham số đó. Choicesis là dạng list và nó được biểu diễn là sự lựa chọn tên của nhiều loại. Low và high là giá trị của giá trị cực tiểu và cực đại của tham số
Ví dụ sau:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | <span class="token keyword">def</span> <span class="token function">objective</span><span class="token punctuation">(</span>trial<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token comment"># Categorical parameter</span> optimizer <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_categorical<span class="token punctuation">(</span><span class="token string">'optimizer'</span><span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token string">'MomentumSGD'</span><span class="token punctuation">,</span> <span class="token string">'Adam'</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token comment"># Int parameter</span> num_layers <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_int<span class="token punctuation">(</span><span class="token string">'num_layers'</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span> <span class="token comment"># Uniform parameter</span> dropout_rate <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_uniform<span class="token punctuation">(</span><span class="token string">'dropout_rate'</span><span class="token punctuation">,</span> <span class="token number">0.0</span><span class="token punctuation">,</span> <span class="token number">1.0</span><span class="token punctuation">)</span> <span class="token comment"># Loguniform parameter</span> learning_rate <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_loguniform<span class="token punctuation">(</span><span class="token string">'learning_rate'</span><span class="token punctuation">,</span> <span class="token number">1e</span><span class="token operator">-</span><span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">1e</span><span class="token operator">-</span><span class="token number">2</span><span class="token punctuation">)</span> <span class="token comment"># Discrete-uniform parameter</span> drop_path_rate <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_discrete_uniform<span class="token punctuation">(</span><span class="token string">'drop_path_rate'</span><span class="token punctuation">,</span> <span class="token number">0.0</span><span class="token punctuation">,</span> <span class="token number">1.0</span><span class="token punctuation">,</span> <span class="token number">0.1</span><span class="token punctuation">)</span> |
Định nghĩa study object
Để tìm kiếm siêu tham số, bạn cần khởi tạo một đối tượng là study
Đối tượng này lưu kết quả tối ưu của bạn.
1 2 | Study <span class="token operator">=</span> Optuna<span class="token punctuation">.</span>Creat_study <span class="token punctuation">(</span><span class="token punctuation">)</span> |
Sau đó, sử dụng phương thức optimize:
1 2 | Study<span class="token punctuation">.</span>Optimize <span class="token punctuation">(</span> Objective <span class="token punctuation">,</span> N_trials <span class="token operator">=</span> <span class="token number">100</span> <span class="token punctuation">)</span> |
Trong đó, tham số thứ 1 là hàm Objective , tham số thứ 2 là số lượng thử nghiệm. Quá trình tối ưu này được thực hiện trong đối tượng Study và sẽ thực hiện tìm giá trị cực tiểu của các tham số trong hàm Objective bằng phương pháp tối ưu hóa.
Khi kết thúc quá trình trên, giá trị tối ưu sẽ được lưu lại và bạn có thể xem nó bằng câu lệnh sau:
1 2 3 4 5 6 7 8 9 10 11 | <span class="token comment"># Giá trị tối ưu của các tham số khởi tạo trong hàm **Objective**</span> Study<span class="token punctuation">.</span>Best_params <span class="token comment"># Kết quả mô hình tương ứng với các tham số tối ưu trên</span> Study<span class="token punctuation">.</span>Best_value <span class="token comment"># Xem tất cả trạng thái của các lần thử nghiệm</span> Study<span class="token punctuation">.</span>Trials |
Thử nghiệm và so sánh giữa mô hình sử dụng Optuna và chọn cơm siêu tham số trên bộ Iris
Install Optuna
Có thể install Optuna bằng pip or conda
1 2 | !pip install <span class="token operator">-</span><span class="token operator">-</span>quiet optuna |
Kiểm tra phiên bản
1 2 3 | <span class="token keyword">import</span> optuna optuna<span class="token punctuation">.</span>__version__ |
Tối ưu hóa siêu tham số
Thử nghiệm với mô hình chọn tham số thử công
Mình sẽ thực hiện thử nghiệm mô hình random forest để thực hiện phân loại trên bộ dataset Iris và chọn tham số thủ công như sau:
1 2 3 4 5 6 7 8 9 10 | <span class="token keyword">import</span> sklearn<span class="token punctuation">.</span>datasets <span class="token keyword">import</span> sklearn<span class="token punctuation">.</span>ensemble <span class="token keyword">import</span> sklearn<span class="token punctuation">.</span>model_selection iris <span class="token operator">=</span> sklearn<span class="token punctuation">.</span>datasets<span class="token punctuation">.</span>load_iris<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># Prepare the data.</span> clf <span class="token operator">=</span> sklearn<span class="token punctuation">.</span>ensemble<span class="token punctuation">.</span>RandomForestClassifier<span class="token punctuation">(</span> n_estimators<span class="token operator">=</span><span class="token number">5</span><span class="token punctuation">,</span> max_depth<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">)</span> <span class="token comment"># Define the model.</span> sklearn<span class="token punctuation">.</span>model_selection<span class="token punctuation">.</span>cross_val_score<span class="token punctuation">(</span> clf<span class="token punctuation">,</span> iris<span class="token punctuation">.</span>data<span class="token punctuation">,</span> iris<span class="token punctuation">.</span>target<span class="token punctuation">,</span> n_jobs<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> cv<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span> |
Kết quả mà mô hình trên đạt được là : 0.966 . Thật may là kết quả đạt được khá tốt với n_estimators = 5 và max_depth =3.
Thử nghiệm sử dụng Optuna để tự động điều chỉnh tham số
Siêu tham số của mô hình trên là n_estimators và max_depth. Sử dụng đối tượng trial để định nghĩa chúng. Sau đó tạo đối tượng study để tối ưu hóa siêu tham số này và cuối cùng lấy ra siêu tham số tốt nhất.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | <span class="token keyword">import</span> optuna <span class="token keyword">def</span> <span class="token function">objective</span><span class="token punctuation">(</span>trial<span class="token punctuation">)</span><span class="token punctuation">:</span> iris <span class="token operator">=</span> sklearn<span class="token punctuation">.</span>datasets<span class="token punctuation">.</span>load_iris<span class="token punctuation">(</span><span class="token punctuation">)</span> n_estimators <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_int<span class="token punctuation">(</span><span class="token string">'n_estimators'</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">20</span><span class="token punctuation">)</span> max_depth <span class="token operator">=</span> <span class="token builtin">int</span><span class="token punctuation">(</span>trial<span class="token punctuation">.</span>suggest_float<span class="token punctuation">(</span><span class="token string">'max_depth'</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">32</span><span class="token punctuation">,</span> log<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">)</span> clf <span class="token operator">=</span> sklearn<span class="token punctuation">.</span>ensemble<span class="token punctuation">.</span>RandomForestClassifier<span class="token punctuation">(</span> n_estimators<span class="token operator">=</span>n_estimators<span class="token punctuation">,</span> max_depth<span class="token operator">=</span>max_depth<span class="token punctuation">)</span> <span class="token keyword">return</span> sklearn<span class="token punctuation">.</span>model_selection<span class="token punctuation">.</span>cross_val_score<span class="token punctuation">(</span> clf<span class="token punctuation">,</span> iris<span class="token punctuation">.</span>data<span class="token punctuation">,</span> iris<span class="token punctuation">.</span>target<span class="token punctuation">,</span> n_jobs<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> cv<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span> study <span class="token operator">=</span> optuna<span class="token punctuation">.</span>create_study<span class="token punctuation">(</span>direction<span class="token operator">=</span><span class="token string">'maximize'</span><span class="token punctuation">)</span> study<span class="token punctuation">.</span>optimize<span class="token punctuation">(</span>objective<span class="token punctuation">,</span> n_trials<span class="token operator">=</span><span class="token number">100</span><span class="token punctuation">)</span> |
Sau khi xong quá trình tối ưu sau 100 lần thử nghiệm thì chúng ta có được thử nghiệm tốt nhất.
1 2 3 4 5 | trial <span class="token operator">=</span> study<span class="token punctuation">.</span>best_trial <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'Accuracy: {}'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>trial<span class="token punctuation">.</span>value<span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Best hyperparameters: {}"</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>trial<span class="token punctuation">.</span>params<span class="token punctuation">)</span><span class="token punctuation">)</span> |
Kết quả như sau :
1 2 3 | Accuracy<span class="token punctuation">:</span> <span class="token number">0.9733333333333333</span> Best hyperparameters<span class="token punctuation">:</span> <span class="token punctuation">{</span><span class="token string">'n_estimators'</span><span class="token punctuation">:</span> <span class="token number">7</span><span class="token punctuation">,</span> <span class="token string">'max_depth'</span><span class="token punctuation">:</span> <span class="token number">8.773682352088931</span><span class="token punctuation">}</span> |
Như vậy, accuracy tăng hơn 1%, và các siêu tham số cũng đã thay đổi và không giống với mình chọn thủ công ở trên.
Thử nghiệm bộ dữ liệu trên các thuật toán khác nhau
Từ trước đến nay, việc thử nghiệm trên các thuật toán khác nhau để đưa ra được đánh giá khách quan nhất tốn rất nhiều thời gian code và huấn luyện thì nay đã có Optuna hỗ trợ cho bạn việc đó. Let’s get it!
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | <span class="token keyword">import</span> sklearn<span class="token punctuation">.</span>svm <span class="token keyword">def</span> <span class="token function">objective</span><span class="token punctuation">(</span>trial<span class="token punctuation">)</span><span class="token punctuation">:</span> iris <span class="token operator">=</span> sklearn<span class="token punctuation">.</span>datasets<span class="token punctuation">.</span>load_iris<span class="token punctuation">(</span><span class="token punctuation">)</span> classifier <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_categorical<span class="token punctuation">(</span><span class="token string">'classifier'</span><span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token string">'RandomForest'</span><span class="token punctuation">,</span> <span class="token string">'SVC'</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token keyword">if</span> classifier <span class="token operator">==</span> <span class="token string">'RandomForest'</span><span class="token punctuation">:</span> n_estimators <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_int<span class="token punctuation">(</span><span class="token string">'n_estimators'</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">20</span><span class="token punctuation">)</span> max_depth <span class="token operator">=</span> <span class="token builtin">int</span><span class="token punctuation">(</span>trial<span class="token punctuation">.</span>suggest_float<span class="token punctuation">(</span><span class="token string">'max_depth'</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">32</span><span class="token punctuation">,</span> log<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">)</span> clf <span class="token operator">=</span> sklearn<span class="token punctuation">.</span>ensemble<span class="token punctuation">.</span>RandomForestClassifier<span class="token punctuation">(</span> n_estimators<span class="token operator">=</span>n_estimators<span class="token punctuation">,</span> max_depth<span class="token operator">=</span>max_depth<span class="token punctuation">)</span> <span class="token keyword">else</span><span class="token punctuation">:</span> c <span class="token operator">=</span> trial<span class="token punctuation">.</span>suggest_float<span class="token punctuation">(</span><span class="token string">'svc_c'</span><span class="token punctuation">,</span> <span class="token number">1e</span><span class="token operator">-</span><span class="token number">10</span><span class="token punctuation">,</span> <span class="token number">1e10</span><span class="token punctuation">,</span> log<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span> clf <span class="token operator">=</span> sklearn<span class="token punctuation">.</span>svm<span class="token punctuation">.</span>SVC<span class="token punctuation">(</span>C<span class="token operator">=</span>c<span class="token punctuation">,</span> gamma<span class="token operator">=</span><span class="token string">'auto'</span><span class="token punctuation">)</span> <span class="token keyword">return</span> sklearn<span class="token punctuation">.</span>model_selection<span class="token punctuation">.</span>cross_val_score<span class="token punctuation">(</span> clf<span class="token punctuation">,</span> iris<span class="token punctuation">.</span>data<span class="token punctuation">,</span> iris<span class="token punctuation">.</span>target<span class="token punctuation">,</span> n_jobs<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> cv<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span> study <span class="token operator">=</span> optuna<span class="token punctuation">.</span>create_study<span class="token punctuation">(</span>direction<span class="token operator">=</span><span class="token string">'maximize'</span><span class="token punctuation">)</span> study<span class="token punctuation">.</span>optimize<span class="token punctuation">(</span>objective<span class="token punctuation">,</span> n_trials<span class="token operator">=</span><span class="token number">100</span><span class="token punctuation">)</span> |
Cùng xem kết quả thôi :
1 2 3 4 5 | trial <span class="token operator">=</span> study<span class="token punctuation">.</span>best_trial <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'Accuracy: {}'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>trial<span class="token punctuation">.</span>value<span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Best hyperparameters: {}"</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>trial<span class="token punctuation">.</span>params<span class="token punctuation">)</span><span class="token punctuation">)</span> |
1 2 3 | Accuracy<span class="token punctuation">:</span> <span class="token number">0.9866666666666667</span> Best hyperparameters<span class="token punctuation">:</span> <span class="token punctuation">{</span><span class="token string">'classifier'</span><span class="token punctuation">:</span> <span class="token string">'SVC'</span><span class="token punctuation">,</span> <span class="token string">'svc_c'</span><span class="token punctuation">:</span> <span class="token number">4.448968980739045</span><span class="token punctuation">}</span> |
Sau chỉ với mấy dòng code sử dụng cô nàng Optuna thì độ chính xác đã đạt 0.987%, và như trên thì mô hình SVC đạt kết quả tốt hơn Random forest
Visualization với Optuna
Không chỉ hỗ trợ việc tự động tối ưu các siêu tham số, cô nàng này còn hỗ trợ chúng ta việc visualize ra các trạng thái của các lần thử nghiệm
- Visualize lịch sử của đối tượng study
1 2 | optuna<span class="token punctuation">.</span>visualization<span class="token punctuation">.</span>plot_optimization_history<span class="token punctuation">(</span>study<span class="token punctuation">)</span> |
2. Visualize độ chính xác của một siêu tham số tại mỗi lần thử nghiệm
1 2 | optuna<span class="token punctuation">.</span>visualization<span class="token punctuation">.</span>plot_contour<span class="token punctuation">(</span>study<span class="token punctuation">,</span> params<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">'n_estimators'</span><span class="token punctuation">,</span> <span class="token string">'max_depth'</span><span class="token punctuation">]</span><span class="token punctuation">)</span> |
- Visualize accuracy surface cho các siều tham số của mô hình random forest
1 2 | optuna<span class="token punctuation">.</span>visualization<span class="token punctuation">.</span>plot_contour<span class="token punctuation">(</span>study<span class="token punctuation">,</span> params<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">'n_estimators'</span><span class="token punctuation">,</span> <span class="token string">'max_depth'</span><span class="token punctuation">]</span><span class="token punctuation">)</span> |
Kết luận
Nếu bạn đang băn khoăn hay mất thời gian cho việc lựa chọn tham số thì hãy thử dùng Optuna xem sao nhé. Nếu bài viết này được mọi người quan tâm thì mình sẽ cố gắng tìm hiểu và viết thêm bài viết về prunning model sử dụng Optuna. Cảm ơn các bạn đã theo dõi, hãy để lại 1 vote cho mình nhé.