Giới thiệu
Trong quá trình training model ML hay mạng neural, một bước cực kỳ quan trọng và không thể thiếu là lựa chọn giá trị cho các tham số như learning rate, epochs, số layers, hidden units,… Việc lựa chọn các tham số hợp lý thường dựa trên kinh nghiệm và với mỗi bộ tham số như vậy ta phải huấn luyện model, quan sát kết quả đạt được, đánh giá kết quả, điều chỉnh tham số và lặp lại. Để tự động hóa quy trình trên, các thuật toán tìm kiếm như Grid Search hay Random Search được sử dụng. Tuy nhiên các thuật toán này chỉ hoạt động hiệu quả với số lượng tham số ít, thường là nhỏ hơn 5 vì không gian tìm kiếm sẽ tăng nhanh khi số lượng tham số lớn khiến cho thời gian tìm kiếm trở nên rất lâu. Bayesian Optimization (BO) là một thuật toán giúp tối ưu hiệu quả những hàm mục tiêu có chi phí evaluation lớn (như training 1 mạng neural) dựa trên định lý Bayesian. BO làm giảm đáng kể số lần thử sai khi tune tham số so với Grid Search hay Random Search, điều mà với một mạng deep learning lớn (ResNet, Inception, Xception,..) cùng bộ data khổng lồ (ImageNet) sẽ tốn hàng giờ thậm chí hàng ngày liền. Trong bài này, chúng ta sẽ tìm hiểu cách BO hoạt động và ứng dụng nó tối ưu tham số training model CNN cho bài toán MNIST nhé.
Tổng quan thuật toán
BO tối ưu hàm mục tiêu dựa trên học máy. Viết dưới dạng công thức:
maxx∈Af(x)max _{x in A} f(x)
x∈Amaxf(x)
Trong đó A là tập các tham số cần tune, và số lượng tham số không nên nhiều hơn 20 để đảm bảo thuật toán học tốt nhất; f(x)f(x)f(x) là hàm mục tiêu để tối ưu lớn nhất/nhỏ nhất (ví dụ: accuracy, loss,…), f(x)f(x)f(x) có một vài đặc điểm sau:
- fff là hàm liên tục
- Chi phí evaluation của fff lớn
- fff là một black box, ta không hề biết các tính chất của fff như tính tuyến tính, hàm lồi, hàm lõm,…
- Vì fff là một black box, khi thực hiện evaluate fff chỉ lấy ra giá trị f(x)f(x)f(x), các phương pháp liên quan đến đạo hàm như gradient descent sẽ không khả thi và không được sử dụng.
BO xây dựng một hàm surrogate (hàm thay thế) để model fff, mục tiêu là tối ưu surrogate càng giống fff càng tốt, từ đó có thể tìm ra global minimum/maximum của fff dễ dàng bằng cách lấy minimum/maximum của surrogate. Surrogate này là một Gaussian Process với prior distribution và likelihood được định nghĩa từ đầu. Mỗi khi evaluate fff của điểm x mới, ta sẽ tính được posterior của surrogate bằng quy tắc Bayes. Dựa trên posterior này, một Acquisition function tìm ra điểm x’ có tiềm năng mang lại giá trị fff lớn nhất, x’ tiếp tục được evaluate và update posterior. Cứ lặp như vậy cho đến khi ta đạt được kết quả f(x)f(x)f(x) mong muốn.
Pseudo-code cho thuật toán BO:
Bước chuẩn bị
Trong bài này mình sẽ sử dụng GPyTorch để implement phần chính của BO là Gaussian Process Regression, mục đích là cung cấp cái nhìn rõ hơn về các thành phần bên trong của GP Regression. Ngoài ra bạn cũng có thể sử dụng luôn model GaussianProcessRegressor của scikit-learn nhé.
Cài đặt các thư viện liên quan:
1 2 3 4 5 6 7 | pip install torch torchvision pip install gpytorch pip install numpy pip install scipy pip install tqdm pip install ax-platform |
Import thư viện:
1 2 3 4 5 6 7 8 9 10 11 12 13 | <span class="token keyword">import</span> numpy <span class="token keyword">as</span> np <span class="token keyword">from</span> scipy<span class="token punctuation">.</span>stats <span class="token keyword">import</span> norm<span class="token punctuation">,</span> loguniform <span class="token keyword">from</span> ax<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>tutorials<span class="token punctuation">.</span>cnn_utils <span class="token keyword">import</span> load_mnist<span class="token punctuation">,</span> train<span class="token punctuation">,</span> evaluate<span class="token punctuation">,</span> CNN <span class="token keyword">from</span> warnings <span class="token keyword">import</span> catch_warnings <span class="token keyword">from</span> warnings <span class="token keyword">import</span> simplefilter <span class="token keyword">from</span> matplotlib <span class="token keyword">import</span> pyplot <span class="token keyword">from</span> typing <span class="token keyword">import</span> Dict <span class="token keyword">import</span> torch <span class="token keyword">import</span> torch<span class="token punctuation">.</span>nn <span class="token keyword">as</span> nn <span class="token keyword">import</span> torch<span class="token punctuation">.</span>optim <span class="token keyword">as</span> optim <span class="token keyword">from</span> torch<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>data <span class="token keyword">import</span> DataLoader <span class="token keyword">from</span> tqdm <span class="token keyword">import</span> tqdm |
Tiếp đến là hàm để train CNN model trên tập MNIST. Hàm này nhận 1 dict các tham số và giá trị tương ứng và train model bằng các tham số đó, hàm trả về model đã được train. Ở đây mình sẽ chỉ tune 1 tham số duy nhất là learning rate.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | <span class="token keyword">def</span> <span class="token function">train</span><span class="token punctuation">(</span> net<span class="token punctuation">:</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">,</span> train_loader<span class="token punctuation">:</span> DataLoader<span class="token punctuation">,</span> parameters<span class="token punctuation">:</span> Dict<span class="token punctuation">[</span><span class="token builtin">str</span><span class="token punctuation">,</span> <span class="token builtin">float</span><span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token punctuation">:</span> torch<span class="token punctuation">.</span>dtype<span class="token punctuation">,</span> device<span class="token punctuation">:</span> torch<span class="token punctuation">.</span>device<span class="token punctuation">,</span> <span class="token punctuation">)</span> <span class="token operator">-</span><span class="token operator">></span> nn<span class="token punctuation">.</span>Module<span class="token punctuation">:</span> net<span class="token punctuation">.</span>to<span class="token punctuation">(</span>dtype<span class="token operator">=</span>dtype<span class="token punctuation">,</span> device<span class="token operator">=</span>device<span class="token punctuation">)</span> net<span class="token punctuation">.</span>train<span class="token punctuation">(</span><span class="token punctuation">)</span> criterion <span class="token operator">=</span> nn<span class="token punctuation">.</span>CrossEntropyLoss<span class="token punctuation">(</span><span class="token punctuation">)</span> optimizer <span class="token operator">=</span> optim<span class="token punctuation">.</span>SGD<span class="token punctuation">(</span> net<span class="token punctuation">.</span>parameters<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> lr<span class="token operator">=</span>parameters<span class="token punctuation">.</span>get<span class="token punctuation">(</span><span class="token string">"lr"</span><span class="token punctuation">)</span><span class="token punctuation">,</span> momentum<span class="token operator">=</span><span class="token number">0.9</span> <span class="token punctuation">)</span> num_epochs <span class="token operator">=</span> <span class="token number">20</span> <span class="token keyword">for</span> _ <span class="token keyword">in</span> tqdm<span class="token punctuation">(</span><span class="token builtin">range</span><span class="token punctuation">(</span>num_epochs<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">for</span> inputs<span class="token punctuation">,</span> labels <span class="token keyword">in</span> train_loader<span class="token punctuation">:</span> inputs <span class="token operator">=</span> inputs<span class="token punctuation">.</span>to<span class="token punctuation">(</span>dtype<span class="token operator">=</span>dtype<span class="token punctuation">,</span> device<span class="token operator">=</span>device<span class="token punctuation">)</span> labels <span class="token operator">=</span> labels<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token operator">=</span>device<span class="token punctuation">)</span> optimizer<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span> outputs <span class="token operator">=</span> net<span class="token punctuation">(</span>inputs<span class="token punctuation">)</span> loss <span class="token operator">=</span> criterion<span class="token punctuation">(</span>outputs<span class="token punctuation">,</span> labels<span class="token punctuation">)</span> loss<span class="token punctuation">.</span>backward<span class="token punctuation">(</span><span class="token punctuation">)</span> optimizer<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">return</span> net |
Load MNIST data:
1 2 3 4 5 6 | torch<span class="token punctuation">.</span>manual_seed<span class="token punctuation">(</span><span class="token number">12345</span><span class="token punctuation">)</span> dtype <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">float</span> device <span class="token operator">=</span> torch<span class="token punctuation">.</span>device<span class="token punctuation">(</span><span class="token string">"cuda"</span> <span class="token keyword">if</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>is_available<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">else</span> <span class="token string">"cpu"</span><span class="token punctuation">)</span> BATCH_SIZE <span class="token operator">=</span> <span class="token number">512</span> train_loader<span class="token punctuation">,</span> valid_loader<span class="token punctuation">,</span> test_loader <span class="token operator">=</span> load_mnist<span class="token punctuation">(</span>batch_size<span class="token operator">=</span>BATCH_SIZE<span class="token punctuation">)</span> |
Hàm sample_lr
generate các giá trị learning rate khác nhau trong khoảng 1e-6 đến 0.2
1 2 3 | <span class="token keyword">def</span> <span class="token function">sample_lr</span><span class="token punctuation">(</span>size<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">return</span> loguniform<span class="token punctuation">(</span><span class="token number">0.000001</span><span class="token punctuation">,</span> <span class="token number">0.2</span><span class="token punctuation">)</span><span class="token punctuation">.</span>rvs<span class="token punctuation">(</span>size<span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> np<span class="token punctuation">.</span>newaxis<span class="token punctuation">]</span> |
Cuối cùng, ta định nghĩa hàm objective
chính là hàm số mà ta muốn tối ưu. Cụ thể, hàm này lấy vào x là hyperparameter (learning rate), thực hiện train, evaluate CNN model và trả về giá trị accuracy.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | <span class="token keyword">def</span> <span class="token function">objective</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span><span class="token punctuation">:</span> parameterization <span class="token operator">=</span> <span class="token punctuation">{</span><span class="token string">"lr"</span><span class="token punctuation">:</span> x<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">}</span> net <span class="token operator">=</span> CNN<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Training CNN model...."</span><span class="token punctuation">)</span> net <span class="token operator">=</span> train<span class="token punctuation">(</span>net<span class="token operator">=</span>net<span class="token punctuation">,</span> train_loader<span class="token operator">=</span>train_loader<span class="token punctuation">,</span> parameters<span class="token operator">=</span>parameterization<span class="token punctuation">,</span> dtype<span class="token operator">=</span>dtype<span class="token punctuation">,</span> device<span class="token operator">=</span>device<span class="token punctuation">)</span> acc <span class="token operator">=</span> evaluate<span class="token punctuation">(</span> net<span class="token operator">=</span>net<span class="token punctuation">,</span> data_loader<span class="token operator">=</span>valid_loader<span class="token punctuation">,</span> dtype<span class="token operator">=</span>dtype<span class="token punctuation">,</span> device<span class="token operator">=</span>device<span class="token punctuation">,</span> <span class="token punctuation">)</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Accuracy:"</span><span class="token punctuation">,</span> acc<span class="token punctuation">,</span> <span class="token string">"Hyperparams:"</span><span class="token punctuation">,</span> parameterization<span class="token punctuation">,</span> <span class="token string">'n'</span><span class="token punctuation">)</span> <span class="token keyword">return</span> acc |
Gaussian Process Regression
Gausian Process (GP) Regression là một phương pháp thống kê Bayesian để model các hàm số. GP regression hoạt động hiệu quả trên những tập dataset nhỏ nhờ vào giả định các điểm dữ liệu nằm trên một phân phối nhiều chiều (Multivariate Normal). Đầu tiên ta định nghĩa một GP model, chính là prior distribution trên hàm số. Với một biến ngẫu nhiên X, phân bố của nó được xác định bằng hàm mật độ xác suất (probability distribution function – pdf):
f(x∣μ,σ2)=12πσ2e−(x−μ)22σ2fleft(x mid mu, sigma^{2}right)=frac{1}{sqrt{2 pi sigma^{2}}} e^{-frac{(x-mu)^{2}}{2 sigma^{2}}}
f(x∣μ,σ2)=2πσ21e−2σ2(x−μ)2
với μ,σmu, sigmaμ,σ lần lượt là mean và variance. Khi có nhiều hơn một biến ngẫu nhiên, mean sẽ được biểu diễn dưới dạng 1 vector và variance được biểu diễn dưới dạng ma trận, gọi là covariance matrix. Covariance matrix được xây dựng bằng cách evaluate một hàm covariance hay kernel Σ0Σ_0Σ0 trên từng cặp giá trị xi,xjx_i, x_jxi,xj của các biến ngẫu nhiên. Kernel được chọn để các điểm xi,xjx_i, x_jxi,xj càng gần nhau trong không gian đầu vào càng có giá trị lớn. Các kernel khác nhau biểu diễn prior khác nhau, dẫn đến hàm số kết quả khác nhau.
Để xây dựng một prior (GP model) bằng GPyTorch, mình sẽ kế thừa class ExactGP và custom mean cũng như covariance modules. Cụ thể, mình sử dụng ConstantMean
và RBFKernel
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | <span class="token keyword">import</span> torch <span class="token keyword">import</span> gpytorch <span class="token keyword">class</span> <span class="token class-name">ExactGPModel</span><span class="token punctuation">(</span>gpytorch<span class="token punctuation">.</span>models<span class="token punctuation">.</span>ExactGP<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> train_x<span class="token punctuation">,</span> train_y<span class="token punctuation">,</span> likelihood<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token builtin">super</span><span class="token punctuation">(</span>ExactGPModel<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span>train_x<span class="token punctuation">,</span> train_y<span class="token punctuation">,</span> likelihood<span class="token punctuation">)</span> self<span class="token punctuation">.</span>mean_module <span class="token operator">=</span> gpytorch<span class="token punctuation">.</span>means<span class="token punctuation">.</span>ConstantMean<span class="token punctuation">(</span><span class="token punctuation">)</span> self<span class="token punctuation">.</span>covar_module <span class="token operator">=</span> gpytorch<span class="token punctuation">.</span>kernels<span class="token punctuation">.</span>ScaleKernel<span class="token punctuation">(</span>gpytorch<span class="token punctuation">.</span>kernels<span class="token punctuation">.</span>RBFKernel<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span> mean_x <span class="token operator">=</span> self<span class="token punctuation">.</span>mean_module<span class="token punctuation">(</span>x<span class="token punctuation">)</span> covar_x <span class="token operator">=</span> self<span class="token punctuation">.</span>covar_module<span class="token punctuation">(</span>x<span class="token punctuation">)</span> <span class="token keyword">return</span> gpytorch<span class="token punctuation">.</span>distributions<span class="token punctuation">.</span>MultivariateNormal<span class="token punctuation">(</span>mean_x<span class="token punctuation">,</span> covar_x<span class="token punctuation">)</span> |
Class Regressor
là một GP Regression bao gồm GP model và gaussian likelihood.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | <span class="token keyword">class</span> <span class="token class-name">Regressor</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> train_x<span class="token punctuation">,</span> train_y<span class="token punctuation">)</span><span class="token punctuation">:</span> self<span class="token punctuation">.</span>training_iter <span class="token operator">=</span> <span class="token number">100</span> self<span class="token punctuation">.</span>X <span class="token operator">=</span> train_x self<span class="token punctuation">.</span>y <span class="token operator">=</span> train_y self<span class="token punctuation">.</span>likelihood <span class="token operator">=</span> gpytorch<span class="token punctuation">.</span>likelihoods<span class="token punctuation">.</span>GaussianLikelihood<span class="token punctuation">(</span><span class="token punctuation">)</span> self<span class="token punctuation">.</span>gp <span class="token operator">=</span> ExactGPModel<span class="token punctuation">(</span>self<span class="token punctuation">.</span>X<span class="token punctuation">,</span> self<span class="token punctuation">.</span>y<span class="token punctuation">,</span> self<span class="token punctuation">.</span>likelihood<span class="token punctuation">)</span> <span class="token keyword">def</span> <span class="token function">train</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span> self<span class="token punctuation">.</span>gp<span class="token punctuation">.</span>train<span class="token punctuation">(</span><span class="token punctuation">)</span> self<span class="token punctuation">.</span>likelihood<span class="token punctuation">.</span>train<span class="token punctuation">(</span><span class="token punctuation">)</span> optimizer <span class="token operator">=</span> torch<span class="token punctuation">.</span>optim<span class="token punctuation">.</span>Adam<span class="token punctuation">(</span><span class="token punctuation">[</span> <span class="token punctuation">{</span><span class="token string">'params'</span><span class="token punctuation">:</span> self<span class="token punctuation">.</span>gp<span class="token punctuation">.</span>parameters<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">}</span><span class="token punctuation">,</span> <span class="token punctuation">]</span><span class="token punctuation">,</span> lr<span class="token operator">=</span><span class="token number">0.05</span><span class="token punctuation">)</span> mll <span class="token operator">=</span> gpytorch<span class="token punctuation">.</span>mlls<span class="token punctuation">.</span>ExactMarginalLogLikelihood<span class="token punctuation">(</span>self<span class="token punctuation">.</span>likelihood<span class="token punctuation">,</span> self<span class="token punctuation">.</span>gp<span class="token punctuation">)</span> <span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>training_iter<span class="token punctuation">)</span><span class="token punctuation">:</span> optimizer<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span> output <span class="token operator">=</span> self<span class="token punctuation">.</span>gp<span class="token punctuation">(</span>self<span class="token punctuation">.</span>X<span class="token punctuation">)</span> loss <span class="token operator">=</span> <span class="token operator">-</span>mll<span class="token punctuation">(</span>output<span class="token punctuation">,</span> self<span class="token punctuation">.</span>y<span class="token punctuation">)</span> loss<span class="token punctuation">.</span>backward<span class="token punctuation">(</span><span class="token punctuation">)</span> optimizer<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">def</span> <span class="token function">predict</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> X<span class="token punctuation">)</span><span class="token punctuation">:</span> self<span class="token punctuation">.</span>likelihood<span class="token punctuation">.</span><span class="token builtin">eval</span><span class="token punctuation">(</span><span class="token punctuation">)</span> self<span class="token punctuation">.</span>gp<span class="token punctuation">.</span><span class="token builtin">eval</span><span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token keyword">with</span> torch<span class="token punctuation">.</span>no_grad<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> gpytorch<span class="token punctuation">.</span>settings<span class="token punctuation">.</span>fast_pred_var<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span> pred <span class="token operator">=</span> self<span class="token punctuation">.</span>gp<span class="token punctuation">(</span>X<span class="token punctuation">)</span> <span class="token keyword">return</span> self<span class="token punctuation">.</span>likelihood<span class="token punctuation">(</span>pred<span class="token punctuation">)</span> |
Acquisition Function
Khi đã có posterior từ GP, câu hỏi tiếp theo là làm thế nào để chọn được candidate tiềm năng cho ra kết quả (mean) tốt nhất? Cách đơn giản nhất là lựa chọn vị trí maximum của mean, nhưng còn những vùng có uncertainty lớn (standard deviation lớn) cũng rất tiềm năng cho ra kết quả tốt chứ. Acquisition function dựa trên cả 2 yếu tố là cực trị hiện tại và uncertainty của posterior để đưa ra candidate tiềm năng nhất. Một trong những hàm acquisition hay được sử dụng là Expected Improvement (EI) được tính bằng công thuwcs:
EI(x)={(μ(x)−f(x+)−ξ)Φ(Z)+σ(x)ϕ(Z) , neˆˊu σ(x)>00 , neˆˊu σ(x)=0mathrm{EI}(mathbf{x})=left{begin{array}{ll}
left(mu(mathbf{x})-fleft(mathbf{x}^{+}right)-xiright) Phi(Z)+sigma(mathbf{x}) phi(Z) & text { , nếu } sigma(mathbf{x})>0 \
0 & text { , nếu } sigma(mathbf{x})=0
end{array}right.
EI(x)={(μ(x)−f(x+)−ξ)Φ(Z)+σ(x)ϕ(Z)0 , neˆˊu σ(x)>0 , neˆˊu σ(x)=0
Trong đó x+mathbf{x}^{+}x+ là điểm dữ liệu tốt nhất hiện tại; μ(x)mu(mathbf{x})μ(x) và σ(x)sigma(mathbf{x})σ(x) lần lượt là mean và variance của xxx; Φ(Z)Phi(Z)Φ(Z) và ϕ(Z)phi(Z)ϕ(Z) lần lượt là CDF và PDF của phân phối Z:
Z={μ(x)−f(x+)−ξσ(x) , neˆˊu σ(x)>00 , neˆˊu σ(x)=0Z=left{begin{array}{ll}
frac{mu(mathbf{x})-fleft(mathbf{x}^{+}right)-xi}{sigma(mathbf{x})} & text { , nếu } sigma(mathbf{x})>0 \
0 & text { , nếu } sigma(mathbf{x})=0
end{array}right.
Z={σ(x)μ(x)−f(x+)−ξ0 , neˆˊu σ(x)>0 , neˆˊu σ(x)=0
Số hạng (μ(x)−f(x+)−ξ)Φ(Z)left(mu(mathbf{x})-fleft(mathbf{x}^{+}right)-xiright) Phi(Z)(μ(x)−f(x+)−ξ)Φ(Z) cho biết “improvement” của xmathbf{x}x so với x+mathbf{x}^{+}x+, thể hiện sự exploitation trong khi σ(x)ϕ(Z)sigma(mathbf{x}) phi(Z)σ(x)ϕ(Z) “khám phá” những nơi có uncertainty (variance) cao, thể hiện sự exploration. Tham số ξxiξ điều khiển trade-off giữa exploration và exploitation. Nhìn công thức có vẻ phức tạp nhưng implement nó lại khá dễ dàng như sau:
1 2 3 4 5 6 7 8 9 | <span class="token keyword">def</span> <span class="token function">acquisition</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> Xsamples<span class="token punctuation">,</span> model<span class="token punctuation">)</span><span class="token punctuation">:</span> yhat<span class="token punctuation">,</span> _ <span class="token operator">=</span> surrogate<span class="token punctuation">(</span>model<span class="token punctuation">,</span> X<span class="token punctuation">)</span> best <span class="token operator">=</span> <span class="token builtin">max</span><span class="token punctuation">(</span>yhat<span class="token punctuation">)</span> mu<span class="token punctuation">,</span> std <span class="token operator">=</span> surrogate<span class="token punctuation">(</span>model<span class="token punctuation">,</span> Xsamples<span class="token punctuation">)</span> eps <span class="token operator">=</span> <span class="token number">0.03</span> Z <span class="token operator">=</span> <span class="token punctuation">(</span>mu <span class="token operator">-</span> best <span class="token operator">-</span> eps<span class="token punctuation">)</span> <span class="token operator">/</span> std ei <span class="token operator">=</span> <span class="token punctuation">(</span>mu <span class="token operator">-</span> best <span class="token operator">-</span> eps<span class="token punctuation">)</span><span class="token operator">*</span>norm<span class="token punctuation">.</span>cdf<span class="token punctuation">(</span>Z<span class="token punctuation">)</span> <span class="token operator">+</span> std<span class="token operator">*</span>norm<span class="token punctuation">.</span>pdf<span class="token punctuation">(</span>Z<span class="token punctuation">)</span> <span class="token keyword">return</span> ei |
Sau khi đã có hàm tính EI cho một điểm x bất kỳ, ta sẽ phải sample rất nhiều x và chọn ra điểm có EI lớn nhất. Tuy hàm EI đã có 2 yếu tố exploration và exploitation, người ta cho rằng nhiều khi nó vẫn hơi “tham lam” nghiêng về expoitation. Vậy nên thay vì chọn điểm EI cao nhất, mình sẽ chọn theo ϵepsilonϵ-greedy: chọn điểm tốt nhất với xác suất (1-ϵepsilonϵ) và chọn một điểm bất kỳ với xác suất ϵepsilonϵ, và ϵepsilonϵ sẽ giảm dần theo từng vòng lặp.
1 2 3 4 5 6 7 8 9 10 | <span class="token keyword">def</span> <span class="token function">opt_acquisition</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> y<span class="token punctuation">,</span> model<span class="token punctuation">)</span><span class="token punctuation">:</span> Xsamples <span class="token operator">=</span> sample_data<span class="token punctuation">(</span><span class="token number">1000</span><span class="token punctuation">)</span> scores <span class="token operator">=</span> acquisition<span class="token punctuation">(</span>X<span class="token punctuation">,</span> Xsamples<span class="token punctuation">,</span> model<span class="token punctuation">)</span> e <span class="token operator">=</span> <span class="token number">0.08</span><span class="token operator">**</span><span class="token punctuation">(</span><span class="token builtin">len</span><span class="token punctuation">(</span>X<span class="token punctuation">)</span><span class="token operator">/</span>ITERS<span class="token punctuation">)</span> <span class="token keyword">if</span> random<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span> <span class="token operator">></span> e<span class="token punctuation">:</span> ix <span class="token operator">=</span> argmax<span class="token punctuation">(</span>scores<span class="token punctuation">)</span> <span class="token keyword">else</span><span class="token punctuation">:</span> ix <span class="token operator">=</span> np<span class="token punctuation">.</span>random<span class="token punctuation">.</span>randint<span class="token punctuation">(</span><span class="token builtin">len</span><span class="token punctuation">(</span>scores<span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">return</span> Xsamples<span class="token punctuation">[</span>ix<span class="token punctuation">]</span> |
Main loop
Vòng lặp này theo đúng thuật toán:
- Train GP regression bằng initial data .
- Sử dụng acquisition function chọn candidate (learning rate) tiềm năng nhất.
- Evaluate candidate: train lại CNN model bằng candidate đó
- Append candidate và giá trị (accuracy) của nó vào tập dữ liệu ban đầu và lặp lại bước 1.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Collecting initial observations"</span><span class="token punctuation">)</span> X <span class="token operator">=</span> sample_lr<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">)</span> y <span class="token operator">=</span> np<span class="token punctuation">.</span>asarray<span class="token punctuation">(</span><span class="token punctuation">[</span>objective<span class="token punctuation">(</span>x<span class="token punctuation">)</span> <span class="token keyword">for</span> x <span class="token keyword">in</span> X<span class="token punctuation">]</span><span class="token punctuation">)</span> X <span class="token operator">=</span> torch<span class="token punctuation">.</span>DoubleTensor<span class="token punctuation">(</span>X<span class="token punctuation">)</span> y <span class="token operator">=</span> torch<span class="token punctuation">.</span>DoubleTensor<span class="token punctuation">(</span>y<span class="token punctuation">)</span> model <span class="token operator">=</span> Regressor<span class="token punctuation">(</span>X<span class="token punctuation">,</span> y<span class="token punctuation">)</span> model<span class="token punctuation">.</span>train<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># for visualization purposes</span> new_X <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span> new_y <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"START OPTIMIZATION"</span><span class="token punctuation">)</span> <span class="token keyword">for</span> i <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>ITERS<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">"Iteration "</span><span class="token punctuation">,</span> i<span class="token punctuation">)</span> x <span class="token operator">=</span> opt_acquisition<span class="token punctuation">(</span>X<span class="token punctuation">,</span> model<span class="token punctuation">)</span> actual <span class="token operator">=</span> objective<span class="token punctuation">(</span>x<span class="token punctuation">)</span> est<span class="token punctuation">,</span> _ <span class="token operator">=</span> surrogate<span class="token punctuation">(</span>model<span class="token punctuation">,</span> torch<span class="token punctuation">.</span>DoubleTensor<span class="token punctuation">(</span><span class="token punctuation">[</span>x<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span> X <span class="token operator">=</span> np<span class="token punctuation">.</span>vstack<span class="token punctuation">(</span><span class="token punctuation">(</span>X<span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">)</span> y <span class="token operator">=</span> np<span class="token punctuation">.</span>hstack<span class="token punctuation">(</span><span class="token punctuation">(</span>y<span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> actual<span class="token punctuation">)</span><span class="token punctuation">)</span> new_X<span class="token punctuation">.</span>append<span class="token punctuation">(</span>x<span class="token punctuation">)</span> new_y<span class="token punctuation">.</span>append<span class="token punctuation">(</span>actual<span class="token punctuation">)</span> X <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span>X<span class="token punctuation">)</span> y <span class="token operator">=</span> torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span>y<span class="token punctuation">)</span> model <span class="token operator">=</span> Regressor<span class="token punctuation">(</span>X<span class="token punctuation">,</span> y<span class="token punctuation">)</span> model<span class="token punctuation">.</span>train<span class="token punctuation">(</span><span class="token punctuation">)</span> ix <span class="token operator">=</span> np<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>y<span class="token punctuation">)</span> <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'Best Result:'</span><span class="token punctuation">,</span> X<span class="token punctuation">[</span>ix<span class="token punctuation">,</span> <span class="token punctuation">:</span><span class="token punctuation">]</span><span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> y<span class="token punctuation">[</span>ix<span class="token punctuation">]</span><span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span> |
Chạy thuật toán (sẽ mất một chút thời gian đó):
Qua 10 vòng lặp, kết quả là giá trị learning rate tốt nhất là 0.012973 với accuracy bằng 0.9823
Visualize posterior
Ta sẽ visualize posterior của GP regression để có cái nhìn trực quan về cách mà BO lựa chọn candidate.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | <span class="token keyword">def</span> <span class="token function">plot_post</span><span class="token punctuation">(</span>X<span class="token punctuation">,</span> y<span class="token punctuation">,</span> gp<span class="token punctuation">,</span> new_X<span class="token punctuation">,</span> new_y<span class="token punctuation">)</span><span class="token punctuation">:</span> pyplot<span class="token punctuation">.</span>figure<span class="token punctuation">(</span>figsize<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">8</span><span class="token punctuation">,</span> <span class="token number">8</span><span class="token punctuation">)</span><span class="token punctuation">)</span> X_ <span class="token operator">=</span> np<span class="token punctuation">.</span>logspace<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">6</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1000</span><span class="token punctuation">)</span> y_mean<span class="token punctuation">,</span> y_std <span class="token operator">=</span> surrogate<span class="token punctuation">(</span>gp<span class="token punctuation">,</span> X_<span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>xscale<span class="token punctuation">(</span><span class="token string">'log'</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>plot<span class="token punctuation">(</span>X_<span class="token punctuation">,</span> y_mean<span class="token punctuation">,</span> <span class="token string">'r'</span><span class="token punctuation">,</span> lw<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> zorder<span class="token operator">=</span><span class="token number">9</span><span class="token punctuation">,</span> label<span class="token operator">=</span><span class="token string">'surrogate function'</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>fill_between<span class="token punctuation">(</span>X_<span class="token punctuation">,</span> y_mean <span class="token operator">-</span> y_std<span class="token punctuation">,</span> y_mean <span class="token operator">+</span> y_std<span class="token punctuation">,</span> alpha<span class="token operator">=</span><span class="token number">0.2</span><span class="token punctuation">,</span> color<span class="token operator">=</span><span class="token string">'k'</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>scatter<span class="token punctuation">(</span>X<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> y<span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token string">'r'</span><span class="token punctuation">,</span> s<span class="token operator">=</span><span class="token number">20</span><span class="token punctuation">,</span> zorder<span class="token operator">=</span><span class="token number">10</span><span class="token punctuation">,</span> edgecolors<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">,</span> label<span class="token operator">=</span><span class="token string">'initial point'</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>scatter<span class="token punctuation">(</span>new_X<span class="token punctuation">,</span> new_y<span class="token punctuation">,</span> c<span class="token operator">=</span><span class="token string">'g'</span><span class="token punctuation">,</span> s<span class="token operator">=</span><span class="token number">40</span><span class="token punctuation">,</span> zorder<span class="token operator">=</span><span class="token number">10</span><span class="token punctuation">,</span> edgecolors<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">,</span> label<span class="token operator">=</span><span class="token string">'query point'</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>legend<span class="token punctuation">(</span>loc<span class="token operator">=</span><span class="token string">'upper left'</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>tight_layout<span class="token punctuation">(</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>xlabel<span class="token punctuation">(</span><span class="token string">"learning rate"</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>ylabel<span class="token punctuation">(</span><span class="token string">"accuracy"</span><span class="token punctuation">)</span> pyplot<span class="token punctuation">.</span>show<span class="token punctuation">(</span><span class="token punctuation">)</span> plot_post<span class="token punctuation">(</span>X<span class="token punctuation">,</span> y<span class="token punctuation">,</span> model<span class="token punctuation">,</span> np<span class="token punctuation">.</span>array<span class="token punctuation">(</span>new_X<span class="token punctuation">)</span><span class="token punctuation">,</span> np<span class="token punctuation">.</span>array<span class="token punctuation">(</span>new_y<span class="token punctuation">)</span><span class="token punctuation">)</span> |
Và đây là kết qủa. Có thể thấy BO hầu như chọn điểm xung quanh vùng mang lại giá trị accuracy cao nhất nhưng đôi khi nó cũng chọn những vùng khác có uncertainty (variance) cao (vùng màu xám thể hiện variance của posterior).
Kết luận
Bài này mình đã giới thiệu và implement với bạn thuật toán tối ưu Bayesian Optimization. Đây là một thuật toán khá hay và hữu dụng trong việc tối ưu, đặc biệt là tối ưu tham số model deep learning. Nếu bài viết có gì sai sót, hãy comment cho mình biết nhé; còn nếu thấy hay thì ngại gì mà không upvote/clip.
Tài liệu tham khảo:
- http://krasserm.github.io/2018/03/21/bayesian-optimization/
- https://www.borealisai.com/en/blog/tutorial-8-bayesian-optimization/
- A Tutorial on Bayesian Optimization – Peter I. Frazier