Introduction
In the previous parts, I covered the basic theory of building a quantum neural network, along with example code using the Paddle Quantum library. Continuing the series on quantum neural networks, this part shows how to build a QNN with PyTorch, a library that is very familiar in Deep Learning, combined with the Qiskit framework provided by IBM. Note that all of the code here runs on a Colab notebook (using a simulator), not on a real quantum computer, so feel free to try it out!
How it works
A QNN (Quantum Neural Network) model here still has two parts: the main part is made of classical layers, and the auxiliary part is a quantum layer. At this point some of you may wonder why the main components of a quantum neural network are not quantum layers. We could build a purely quantum neural network, but the amount of computation would be enormous and training on an ordinary computer would be very slow, which doesn't suit what I want to show here. So for now the quantum layer plays a minor role. I'll also compare the model with a conventional neural network so you can see both the benefits and the "drawbacks" of going quantum :v
If you are already familiar with PyTorch, you can think of designing a QNN exactly like designing an ordinary NN. The only difference is that we insert a quantum layer somewhere in the network. The quantum layer I use here is a Parameterized Quantum Circuit (PQC). This circuit takes the output of a classical layer in the network, performs its computation in $\mathcal{H}$ (Hilbert space), and then measures out ordinary classical values. So how does backpropagation work through the quantum layer? This post is about building a QNN and leans more toward code than theory, so I won't go into backpropagation for the quantum layer in detail. Based on the Parameter Shift Rule, the gradient of the quantum layer at a point $\theta$ can be computed as:

$$\nabla_{\theta}\, \text{Quantum Circuit} = \text{Quantum Circuit}(\theta + s) - \text{Quantum Circuit}(\theta - s)$$

For a single Pauli rotation gate with $s = \pi/2$, the exact parameter shift rule carries an extra factor of $1/2$ on the right-hand side. Leaving it out only rescales the gradient by 2, and the code in this post follows the formula as written above.
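A quick numerical check of the shift rule on the simplest possible circuit: one qubit, an $R_Y(\theta)$ rotation, and a $Z$ measurement, whose expectation value is $\cos\theta$. This is a minimal NumPy sketch, not Qiskit code, and the helper names (`ry`, `expval_z`, `shift_difference`) are hypothetical. It shows that with $s = \pi/2$ the shifted difference equals twice the exact derivative.

```python
import numpy as np

def ry(theta):
    """Single-qubit RY rotation matrix, RY(theta) = exp(-i * theta * Y / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

def expval_z(theta):
    """<Z> after applying RY(theta) to |0>; analytically this equals cos(theta)."""
    state = ry(theta) @ np.array([1.0, 0.0])
    z = np.array([1.0, -1.0])  # Z eigenvalues for |0> and |1>
    return float(np.sum(z * np.abs(state) ** 2))

def shift_difference(f, theta, s=np.pi / 2):
    """f(theta + s) - f(theta - s), the quantity in the formula above."""
    return f(theta + s) - f(theta - s)

theta = 0.3
diff = shift_difference(expval_z, theta)
exact = -np.sin(theta)  # d/dtheta of cos(theta)
print(diff / 2, exact)  # the halved difference matches the exact gradient
```

Here $\cos(\theta + \pi/2) - \cos(\theta - \pi/2) = -2\sin\theta$, so the unscaled difference is exactly twice the true gradient.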
Hands-on
Step 1: first, install the libraries we need:

```
%pip install qiskit==0.30.0 pylatexenc
```
Step 2: then import the required libraries:

```python
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.autograd import Function
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import qiskit
from qiskit import transpile, assemble
from qiskit.visualization import plot_state_qsphere, plot_histogram, plot_bloch_multivector, plot_state_city
from IPython.display import display  # display() is used in SimpleCircuit.plot
```
Step 3: one of the most important steps, creating the Quantum Circuit. The circuit I build is fairly simple. First, it takes the input through $R_Y$ gates, which turn a real value into a quantum state. This way of encoding data is called "Angle Encoding". With this method, we can in principle use any of the three rotation gates $R_X$, $R_Y$, $R_Z$, although $R_Z$ applied directly to $|0\rangle$ only adds a global phase, so it would need to come after an $H$ gate to encode anything. Next, I apply $H$ gates followed by a chain of $CNOT$ gates to create quantum entanglement. Finally, I measure in the $Z$ basis: $\sigma_\mathbf{z} = \sum_i z_i\, p(z_i)$, where $z_i \in \{+1, -1\}$ are the eigenvalues of $Z$ and $p(z_i)$ is the probability of measuring each one. That gives us a simple Quantum Circuit that acts as the quantum layer in the network. Because the circuit uses the two-qubit $CNOT$ gate, it only works with 2 or more qubits. In this post I use 3 qubits. The circuit is created as follows:
```python
n_qubits = 3
circuit = qiskit.QuantumCircuit(n_qubits)
theta = qiskit.circuit.Parameter('theta')  # bound to each input value at run time

all_qubits = [i for i in range(n_qubits)]
# Angle encoding: RY(theta) on every qubit
circuit.ry(theta, all_qubits)
circuit.barrier()
# Hadamard on every qubit
circuit.h(all_qubits)
circuit.barrier()
# CNOT chain (0 -> 1, 1 -> 2) to entangle neighbouring qubits
for k in range(n_qubits - 1):
    circuit.cx(k, k + 1)
# Measure every qubit in the Z basis
circuit.measure_all()
circuit.draw("mpl")
```
Result:
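To see what this three-qubit circuit does without running Qiskit, here is a small NumPy statevector sketch of the same gate sequence: $R_Y$ on every qubit, then $H$ on every qubit, then the $CNOT$ chain. The helper names (`ry`, `kron_all`, `cnot`, `circuit_probabilities`, `z_expectation`) are hypothetical. The bit ordering is assumed to follow Qiskit's little-endian convention, in which qubit 0 is the rightmost bit of a bitstring.

```python
import numpy as np

def ry(theta):
    """Single-qubit RY rotation matrix."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

H = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2)

def kron_all(gate, n):
    """The same single-qubit gate applied to all n qubits (tensor product)."""
    op = np.array([[1.0]])
    for _ in range(n):
        op = np.kron(op, gate)
    return op

def cnot(control, target, n):
    """CNOT as a permutation matrix; bit q of a basis index is qubit q (little-endian)."""
    dim = 2 ** n
    op = np.zeros((dim, dim))
    for i in range(dim):
        j = i ^ (1 << target) if (i >> control) & 1 else i
        op[j, i] = 1.0
    return op

def circuit_probabilities(theta, n=3):
    """{bitstring: probability} for RY(theta) on all qubits -> H on all -> CNOT chain."""
    state = np.zeros(2 ** n)
    state[0] = 1.0                            # start in |00...0>
    state = kron_all(ry(theta), n) @ state    # angle encoding
    state = kron_all(H, n) @ state
    for k in range(n - 1):
        state = cnot(k, k + 1, n) @ state     # entangle neighbouring qubits
    probs = np.abs(state) ** 2
    # format(i, "03b") puts qubit n-1 on the left, like Qiskit's count keys
    return {format(i, f"0{n}b"): p for i, p in enumerate(probs)}

def z_expectation(probs, qubit):
    """<Z> on one qubit: sum of z * p(z), with z = +1 for bit 0 and -1 for bit 1."""
    return sum((1 if bits[-1 - qubit] == "0" else -1) * p for bits, p in probs.items())

print(circuit_probabilities(np.pi / 2))  # essentially all probability on '000'
```

For $\theta = \pi/2$, $R_Y$ followed by $H$ returns each qubit to $|0\rangle$, so the outcome is always `000`. For $\theta = 0$, the $H$ gates create a uniform superposition that the $CNOT$s only permute, so every bitstring has probability $1/8$.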
We wrap this code in a class to make it easier to use:
```python
class SimpleCircuit:
    """
    This class provides a simple interface for interaction
    with the quantum circuit
    """

    def __init__(self, n_qubits, backend, shots, n_input=9):
        # --- Circuit definition ---
        self.n_qubits = n_qubits
        self.n_inputs = n_input
        self.circuit = qiskit.QuantumCircuit(n_qubits)
        self.theta = qiskit.circuit.Parameter('theta')

        all_qubits = [i for i in range(n_qubits)]
        self.circuit.ry(self.theta, all_qubits)
        self.circuit.barrier()
        self.circuit.h(all_qubits)
        self.circuit.barrier()
        for k in range(n_qubits - 1):
            self.circuit.cx(k, k + 1)
        self.circuit.measure_all()
        # ---------------------------

        self.backend = backend
        self.shots = shots

    def forward(self, thetas):
        # One circuit run per input value: each value is bound to theta
        t_qc = transpile(self.circuit, self.backend)
        qobj = assemble(t_qc, shots=self.shots,
                        parameter_binds=[{self.theta: theta.item()} for theta in thetas])
        job = self.backend.run(qobj)
        result = job.result().get_counts()
        # get_counts() returns a single dict instead of a list when only one circuit ran
        if isinstance(result, dict):
            result = [result]

        exp = []
        for dict_ in result:
            counts = np.array(list(dict_.values()))
            # Bitstring keys such as '101' become integers (5)
            states = np.array([int(k, 2) for k in list(dict_.keys())])
            probabilities = counts / self.shots
            expectation = states * probabilities
            # Pad to 2**n_qubits entries so np.asarray gets a rectangular array;
            # the zeros do not change the sum below
            while expectation.shape[0] < 2 ** self.n_qubits:
                expectation = np.append(expectation, 0.00)
            exp.append(expectation)
        # Sum over the 2**n_qubits entries: one expected value per input
        return np.asarray(exp).T.sum(0)

    def plot(self, thetas):
        # The statevector simulator is only used for visualisation
        self.plot_backend = qiskit.Aer.get_backend("statevector_simulator")
        t_qc = transpile(self.circuit, self.plot_backend)
        qobj = assemble(t_qc, shots=self.shots,
                        parameter_binds=[{self.theta: theta.item()} for theta in thetas])
        job = self.plot_backend.run(qobj)
        state = job.result().get_statevector(0)
        display(plot_state_qsphere(state))
        display(plot_state_city(state, figsize=[16, 9]))
        display(self.circuit.draw("mpl"))
        display(plot_histogram(job.result().get_counts(0)))
```
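The content of `__init__` was covered above, so I'll skip it. The `forward` method takes a 1-D tensor of N input values. Each value in turn is bound to the parameter $\theta$ of the $R_Y$ gates through `parameter_binds = [{self.theta: theta.item()} for theta in thetas]`, so the circuit runs once per value. The circuit then computes and measures the output `result`, a list of dictionaries with one dictionary per input value. In each dictionary, the keys are the bitstrings '000' to '111' and the values are the number of shots that landed in each of these 8 states. Dividing by the number of shots turns these counts into probabilities. For each input, `forward` then returns the sum of each state's integer value multiplied by its probability. This is the expected value of the measured bitstring read as an integer, and it serves as the layer's output.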
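The post-processing inside `SimpleCircuit.forward` can be checked without any backend. The sketch below uses made-up counts (not real measurement results) for two input values at 512 shots. `counts_to_expectation` and `batch_expectations` are hypothetical helper names, and the single-dict guard mirrors the one in the class above.

```python
import numpy as np

def counts_to_expectation(counts, shots):
    """Sum over bitstrings of int(bitstring, 2) * p(bitstring), as in SimpleCircuit.forward."""
    states = np.array([int(bits, 2) for bits in counts])  # '101' -> 5
    probabilities = np.array(list(counts.values())) / shots
    return float(np.sum(states * probabilities))

def batch_expectations(results, shots):
    """One expected value per bound theta; a lone dict is treated as a one-element list."""
    if isinstance(results, dict):
        results = [results]
    return np.array([counts_to_expectation(c, shots) for c in results])

# Hypothetical counts for two input values, 512 shots each
fake_results = [{"000": 256, "111": 256}, {"010": 512}]
print(batch_expectations(fake_results, shots=512))
```

With these counts, the first input gives $0 \cdot 0.5 + 7 \cdot 0.5 = 3.5$ and the second gives $2.0$.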
Step 4: after creating the quantum circuit, we can "wrap" it with PyTorch so that its forward and backward passes work just like an ordinary dense layer:
```python
class TorchCircuit(Function):

    @staticmethod
    def forward(ctx, inp, circuit=None, shift=np.pi / 2):
        # Build the Qiskit circuit and keep it on ctx so backward can reuse it
        if not hasattr(ctx, 'QiskitCirc'):
            ctx.QiskitCirc = SimpleCircuit(NUM_QUBITS, SIMULATOR, shots=NUM_SHOTS, n_input=NUM_INP)

        exp_value = ctx.QiskitCirc.forward(inp)
        result = torch.tensor([exp_value], dtype=torch.float32)  # shape (1, N)
        ctx.save_for_backward(result, inp)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        forward_tensor, input_numbers = ctx.saved_tensors
        gradients = []
        # Parameter shift: rerun the circuit with input k shifted by +SHIFT and -SHIFT
        for k in range(NUM_INP):
            shift_right = input_numbers.detach().clone()
            shift_right[k] = shift_right[k] + SHIFT
            shift_left = input_numbers.detach().clone()
            shift_left[k] = shift_left[k] - SHIFT

            expectation_right = ctx.QiskitCirc.forward(shift_right)
            expectation_left = ctx.QiskitCirc.forward(shift_left)
            # Row k: change of every output when input k is shifted
            gradients.append(torch.tensor(expectation_right - expectation_left, dtype=torch.float32))

        jacobian = torch.stack(gradients)  # shape (NUM_INP, N)
        # Vector-Jacobian product: grad_in[k] = sum_j jacobian[k, j] * grad_output[0, j]
        grad_input = jacobian @ grad_output.float().view(-1)
        return grad_input.view_as(input_numbers)
```
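For every input k, the backward pass builds a row holding the change in all circuit outputs when input k is shifted by ±`SHIFT`. It then combines these rows with `grad_output`. The NumPy sketch below reproduces that bookkeeping with a deterministic stand-in (`np.sin`, one output per input, analogous to one circuit run per value) in place of the noisy circuit. `shift_jacobian` and `vector_jacobian_product` are hypothetical names.

```python
import numpy as np

def shift_jacobian(f, x, shift=np.pi / 2):
    """Row k holds f(x + shift * e_k) - f(x - shift * e_k), like the backward loop."""
    rows = []
    for k in range(len(x)):
        right = x.copy()
        right[k] += shift
        left = x.copy()
        left[k] -= shift
        rows.append(f(right) - f(left))
    return np.stack(rows)  # shape (n_inputs, n_outputs)

def vector_jacobian_product(jac, grad_output):
    """Gradient w.r.t. the inputs: grad_in[k] = sum_j jac[k, j] * grad_output[j]."""
    return jac @ grad_output

# Stand-in for the circuit: each output depends only on its own input
f = np.sin
x = np.array([0.1, 0.5, 1.0])
g = np.array([1.0, -2.0, 0.5])
jac = shift_jacobian(f, x)
grad_in = vector_jacobian_product(jac, g)
```

Because each output depends only on its own input, the Jacobian is diagonal. With $s = \pi/2$, $\sin(x+s) - \sin(x-s) = 2\cos x$, which is twice the true derivative. This is the factor of $1/2$ that the gradient formula in this post leaves out.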
We also define a few hyperparameters:
```python
NUM_INP = 10       # number of values fed into the quantum layer
NUM_QUBITS = 3
SIMULATOR = qiskit.Aer.get_backend('qasm_simulator')
NUM_SHOTS = 512
SHIFT = np.pi / 2  # shift s used in the parameter shift rule
n_samples = 150    # samples kept per class (Step 5)
```
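To see why training is slow, we can count circuit executions from these hyperparameters. This is a back-of-the-envelope sketch that assumes one forward and one backward pass of `TorchCircuit` per training sample and ignores transpilation overhead. The training loop itself is not shown in this section.

```python
NUM_INP = 10
NUM_SHOTS = 512

# forward: one job binding NUM_INP parameter values -> NUM_INP circuit runs
forward_circuits = NUM_INP
# backward: for each of the NUM_INP inputs, two jobs (shift right / shift left),
# each binding NUM_INP values -> NUM_INP circuit runs per job
backward_circuits = NUM_INP * 2 * NUM_INP
circuits_per_sample = forward_circuits + backward_circuits
shots_per_sample = circuits_per_sample * NUM_SHOTS
print(circuits_per_sample, shots_per_sample)
```

Under these assumptions, a single sample costs 210 circuit runs, or 107,520 shots.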
Step 5: prepare the data. I use the MNIST dataset, whose 60,000 training samples were collected from about 250 writers. However, I only keep 150 samples per class (`n_samples = 150`), which gives 1,500 training samples across the 10 classes and the same number of test samples. Training takes quite a long time, so I use as little data as possible to speed it up.
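The selection in this step calls `np.append` once per class to collect the first `n_samples` indices of digits 0 through 9, in that order. The same index array can be built with a loop. `first_n_per_class` is a hypothetical helper, and it uses `np.asarray` so that it also accepts the `X_train.targets` tensor, as in `idx = first_n_per_class(X_train.targets, n_samples)`.

```python
import numpy as np

def first_n_per_class(targets, n_samples, classes=range(10)):
    """Indices of the first n_samples examples of each class, grouped class by class."""
    targets = np.asarray(targets)
    return np.concatenate([np.where(targets == c)[0][:n_samples] for c in classes])

# Small synthetic example with three classes
targets = np.array([1, 0, 1, 2, 0, 2, 1])
print(first_n_per_class(targets, 2, classes=range(3)))
```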
```python
import numpy as np
import torch
import torchvision
from torchvision import datasets, transforms

# n_samples: number of images kept per digit (assumed to be defined earlier in the notebook)
X_train = datasets.MNIST(root='./data', train=True, download=True,
                         transform=transforms.Compose([transforms.ToTensor()]))

# Keep only the first n_samples training images of each digit 0-9
idx = np.concatenate([np.where(X_train.targets == d)[0][:n_samples] for d in range(10)])
X_train.data = X_train.data[idx]
X_train.targets = X_train.targets[idx]

train_loader = torch.utils.data.DataLoader(X_train, batch_size=1, shuffle=True, pin_memory=True)
```
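To see what the per-digit index selection does, here is a toy run on a small hand-made label tensor instead of MNIST (the values of `targets` and `n_samples` below are made up for illustration). The same `np.where(...)[0][:n_samples]` pattern picks the first `n_samples` positions of each class, and `torch.bincount` confirms that every class keeps the same number of examples.

```python
import numpy as np
import torch

# Toy stand-in for X_train.targets (the real one holds 60,000 MNIST labels)
targets = torch.tensor([3, 0, 1, 0, 2, 1, 0, 3, 2, 1, 3, 2])
n_samples = 2  # keep at most 2 examples per class

# Same selection as above: first n_samples indices of each class, class by class
idx = np.concatenate([np.where(targets == d)[0][:n_samples] for d in range(4)])

subset = targets[idx]
print(idx)                     # [1 3 2 5 4 8 0 7]
print(torch.bincount(subset))  # tensor([2, 2, 2, 2])
```

Because the indices are grouped class by class, the subset is ordered by label; this is why the loaders use `shuffle=True`.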
```python
X_test = datasets.MNIST(root='./data', train=False, download=True,
                        transform=transforms.Compose([transforms.ToTensor()]))

# Keep only the first n_samples test images of each digit 0-9
idx = np.concatenate([np.where(X_test.targets == d)[0][:n_samples] for d in range(10)])
X_test.data = X_test.data[idx]
X_test.targets = X_test.targets[idx]

test_loader = torch.utils.data.DataLoader(X_test, batch_size=1, shuffle=True)
```
Step 6: Build the model, with the quantum layer as the final layer responsible for prediction
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# NUM_INP (number of circuit inputs) and TorchCircuit (the quantum layer)
# are assumed to be defined in the earlier steps.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, NUM_INP)
        self.qc = TorchCircuit.apply
        self.qcsim = nn.Linear(NUM_INP, NUM_INP)
        self.fc3 = nn.Linear(NUM_INP, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Map the outputs to rotation angles in (-pi, pi)
        x = np.pi * torch.tanh(x)
        # print('params to QC: {}'.format(x))
        MODE = 'QC'  # 'QC' or 'QC_sim'
        if MODE == 'QC':
            x = torch.cat([self.qc(x_i) for x_i in x])  # QUANTUM LAYER
            # print(self.qc(x[0]).shape)  # debug only: runs the circuit an extra time
        else:
            x = self.qcsim(x)
        # x = F.relu(self.fc3(x.float()))
        # x = torch.cat((x, 1-x), -1)
        return x

    def predict(self, x):
        # return the index of the highest score (argmax)
        pred = self.forward(x)
        # print(pred)
        ans = torch.argmax(pred[0]).item()
        return torch.tensor(ans)
```
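`TorchCircuit` itself is defined in an earlier part and is not shown here, but the idea behind its backward pass can be sketched with a hypothetical toy layer. The sketch below replaces the Qiskit circuit with an analytic one-qubit simulation: the expectation of Z after RY(θ)|0⟩ is cos θ. The names `expval_z` and `ShiftRuleLayer` are made up for illustration. The backward pass applies the parameter shift rule [1], evaluating the "circuit" again at θ ± π/2 and taking half the difference, which gives the exact gradient for Pauli rotation gates. The inputs go through the same `np.pi * torch.tanh(x)` mapping as `Net.forward`.

```python
import numpy as np
import torch

def expval_z(theta: torch.Tensor) -> torch.Tensor:
    """<Z> after RY(theta)|0>, computed analytically as cos(theta).
    Stands in for running a real circuit on a simulator or device."""
    return torch.cos(theta)

class ShiftRuleLayer(torch.autograd.Function):
    """Hypothetical toy quantum layer: one single-qubit circuit per input angle."""

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return expval_z(theta)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        shift = np.pi / 2
        # Parameter shift rule: two extra circuit evaluations per parameter
        grad = (expval_z(theta + shift) - expval_z(theta - shift)) / 2
        return grad_output * grad

x = torch.tensor([0.3, -1.2, 2.0], dtype=torch.float64, requires_grad=True)
theta = np.pi * torch.tanh(x)  # same angle mapping as in Net.forward
out = ShiftRuleLayer.apply(theta)
out.sum().backward()

# Compare with the analytic gradient d/dx cos(pi * tanh(x))
expected = -torch.sin(np.pi * torch.tanh(x)) * np.pi * (1 - torch.tanh(x) ** 2)
print(torch.allclose(x.grad, expected.detach()))  # True
```

On real hardware or a shot-based simulator each expectation is estimated from samples, so the shifted evaluations are noisy and each trainable angle costs two extra circuit runs per backward pass, which is part of why the quantum model trains so slowly.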
Step 7: Write the training script and start training
```python
import torch
import torch.nn as nn
from tqdm import tqdm

# network (a Net instance) and optimizer are assumed to be created beforehand
epochs = 10
loss_list = []
loss_func = nn.CrossEntropyLoss()
list_acc = []

for epoch in range(epochs):
    network.train()  # re-enable dropout after the previous epoch's eval()
    total_loss = []
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        # Forward pass
        output = network(data)
        # Calculating loss
        loss = loss_func(output, target)
        # Backward pass
        loss.backward()
        # Optimize the weights
        optimizer.step()
        total_loss.append(loss.item())
    loss_list.append(sum(total_loss) / len(total_loss))
    print('Training [{:.0f}%]\tLoss: {:.4f}'.format(
        100. * (epoch + 1) / epochs, loss_list[-1]))

    # Save a checkpoint for this epoch
    torch.save({
        'epoch': epoch,
        'model': network.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, f"/content/gdrive/MyDrive/Adhoc/model_quantum_{epoch}.pth")

    # Evaluate on the test set
    accuracy = 0
    number = 0
    network.eval()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            number += 1
            output = network(data)
            accuracy += (output.argmax(1) == target).sum().item()
    print("Performance on test data is : {}/{} = {}%".format(
        accuracy, number, 100 * accuracy / number))
    list_acc.append(100 * accuracy / number)
```
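The evaluation part of the loop can also be pulled out into a small function that works for any batch size. This is a sketch under the assumption that the model returns one score per class with shape (batch, classes); the name `evaluate` and the tiny linear model and random data in the usage example are hypothetical stand-ins, not MNIST or the QNN.

```python
import torch
import torch.nn as nn

def evaluate(model: nn.Module, loader) -> float:
    """Return accuracy in percent; assumes model(data) has shape (batch, classes)."""
    model.eval()                # disable dropout
    correct, total = 0, 0
    with torch.no_grad():       # no autograd graph is built during testing
        for data, target in loader:
            output = model(data)
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.size(0)
    model.train()               # restore training mode for the next epoch
    return 100.0 * correct / total

# Usage with a tiny stand-in model and fake data (not MNIST)
torch.manual_seed(0)
model = nn.Linear(4, 3)
data = torch.randn(8, 4)
target = model(data).argmax(dim=1)  # labels the model gets right by construction
loader = [(data[i:i + 2], target[i:i + 2]) for i in range(0, 8, 2)]
print(evaluate(model, loader))      # 100.0
```

Counting with `.sum().item()` keeps the accuracy a plain Python number, so it formats cleanly in the print statement and can be appended to `list_acc` directly.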
And here is the final result:
The whole model can be summarized as follows:
Comparing training times, the conventional CNN takes only about 1 minute for 10 epochs on the data above, while the quantum model takes nearly 2 hours :V. The result is still worth the effort, though: in this experiment the QNN's accuracy is clearly higher than that of the conventional CNN.
Although the goal was achieved, time constraints meant this article could not cover every capability of Quantum Neural Networks: the model has only been run on a classical computer, and training and testing it on a quantum computer have not been done yet. In the future, if I get the chance, I will keep researching and building different kinds of Quantum Neural Networks for more complex problems.
All of the code is available at this link: https://colab.research.google.com/drive/1CUbqJg1cDwiBVWHaOjYgfc5EBdB1TvUV?usp=sharing
Reference
[1] Parameter shift rule. https://arxiv.org/pdf/1905.13311.pdf