Zalo đang tổ chức một cuộc thi về Ai cho toàn thể ACE trong “Ngành”. Một trong ba bài toán đó là bài Motorbike Generator và tất nhiên requirement của nó y hệt như cái bài Dog Generator trên Kaggle, khác mỗi đầu ra là 128×128 còn bài Dog Generator là 64×64 :v. Và mình cũng tham gia góp vui với một tinh thần 3H – Ham học hỏi :v. Bài viết này mình đề cập tới những kinh nghiệm của mình trong việc quan sát dữ liệu, xử lý ảnh, cũng như cách training mô hình… Đây chỉ là kinh nghiệm của mình trong quá trình mình làm và học đc, nếu có chỗ nào sai sót mong mọi người gạch đá nhè nhẹ )
Yêu cầu bài toán
Hiểu đơn giản là sử dụng 10000 mà bên Zalo cho và sinh ra 10000 ảnh với định dạng là PNG từ bộ dữ liệu cho trước , evaluation metric là FID.
Lý thuyết
Để đọc thêm về FID evalutation mời các bạn đọc Tại đây
Mình sẽ sử dụng RaLSGAN cho bài toán này. Vậy RaLSGAN là cái vẹo gì? TL: nó chỉ là mạng GAN thông thường nhưng với một loss function tối ưu hơn
Loss Functions
Một Discriminator đầu ra có thể là hàm kích hoạt sigmoid hoặc linear. Nếu là sigmoid chúng ta có phân bố xác xuất rời rạc của một hình ảnh là thực là Pr(real) . Còn với linear chúng ta có C(x) = logit . Xác suất nằm trong khoảng (0,1). Và logit có thể là bất cứ số nào trong khoảng (0,1) . Số dương đại diện cho ảnh real còn âm đại diện cho ảnh fake.
Simple loss
Ta gọi x_r là ảnh thật, và x_f là ảnh fake thì ta sẽ có D(x) = Pr(Real) hay C(x) = Logit sẽ trở thành 2 ouptput của một Discriminator khi input là một hình ảnh, Loss Function sẽ như sau:
1 2 3 4 | # Take AVG over x_r and x_f in batch disc_loss = (1 - D(x_r)) + (D(x_f) - 0) gen_loss = (1 - D(x_f)) |
Chúng ta muốn Discriminator D(x_r) = 1 và D(x_f) = 0 tương đương với nhãn real và fake, và sau khi huấn luyện xong ta thì Generator càng gần 1 càng tốt. Nói nôm na là ta dùng ảnh thật làm bộ dữ liệu huấn luyện cho vào Discriminator để mạng có thể phân biệt được ảnh real và ảnh fake, và từ đó Discriminator sẽ phản hồi về cho Generator để nó tự hoàn thiện mình dựa vào phản hồi đó đó là lý do D(x_r) = 1 và D(x_f) = 0. Nếu các bạn muốn tìm hiểu sâu hơn về mạng GAN mời đọc Tại đây
DCGAN Loss
Ta thấy rõ Basic GAN và DCGAN sử dụng D(x):
1 2 3 4 | # Take AVG over x_r and x_f in batch disc_loss = -log (D(x_r)) - log (1-D(x_f)) gen_loss = -log (D(x_f)) |
RaLSGAN Loss
RaLSGAN sử dụng C(x) = logit:
1 2 3 4 | # Take AVG over x_r and x_f in batch disc_loss = (C(x_r) - AVG(C(x_f)) - 1)^2 + (C(x_f) - AVG(C(x_r)) + 1)^2 gen_loss = (C(x_r) - AVG(C(x_f)) + 1)^2 + (C(x_f) - AVG(C(x_r)) - 1)^2 |
Code
Ta sẽ sử dụng pytorch để code cho bài toán này, việc đầu tiên là thay đổi hàm kích hoạt từ sigmoid sang logit(có thể là tanh) ở dòng cuối:
1 2 3 | <span class="token comment">#x = torch.sigmoid(self.conv5(x))</span> x <span class="token operator">=</span> self<span class="token punctuation">.</span>conv5<span class="token punctuation">(</span>x<span class="token punctuation">)</span> |
Tiếp theo ta sẽ update loss của G và D:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | <span class="token comment">############################</span> <span class="token comment"># (1) Update D network</span> <span class="token comment">###########################</span> netD<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span> real_images <span class="token operator">=</span> real_images<span class="token punctuation">.</span>to<span class="token punctuation">(</span>device<span class="token punctuation">)</span> batch_size <span class="token operator">=</span> real_images<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">)</span> labels <span class="token operator">=</span> torch<span class="token punctuation">.</span>full<span class="token punctuation">(</span><span class="token punctuation">(</span>batch_size<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> real_label<span class="token punctuation">,</span> device<span class="token operator">=</span>device<span class="token punctuation">)</span> outputR <span class="token operator">=</span> netD<span class="token punctuation">(</span>real_images<span class="token punctuation">)</span> noise <span class="token operator">=</span> torch<span class="token punctuation">.</span>randn<span class="token punctuation">(</span>batch_size<span class="token punctuation">,</span> nz<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> device<span class="token operator">=</span>device<span class="token punctuation">)</span> fake <span class="token operator">=</span> netG<span class="token punctuation">(</span>noise<span class="token punctuation">)</span> outputF <span class="token operator">=</span> netD<span class="token punctuation">(</span>fake<span class="token punctuation">.</span>detach<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span> errD <span class="token operator">=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">(</span>outputR <span class="token operator">-</span> torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span>outputF<span class="token punctuation">)</span> <span class="token operator">-</span> labels<span class="token punctuation">)</span> <span class="token operator">**</span> <span class="token number">2</span><span class="token punctuation">)</span> <span class="token operator">+</span> torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">(</span>outputF <span class="token operator">-</span> torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span>outputR<span class="token punctuation">)</span> <span class="token operator">+</span> labels<span class="token punctuation">)</span> <span class="token operator">**</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token operator">/</span><span class="token number">2</span> errD<span class="token punctuation">.</span>backward<span class="token punctuation">(</span>retain_graph<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span> optimizerD<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment">############################</span> <span class="token comment"># (2) Update G network</span> <span class="token comment">###########################</span> netG<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span> outputF <span class="token operator">=</span> netD<span class="token punctuation">(</span>fake<span class="token punctuation">)</span> errG <span class="token operator">=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">(</span>outputR <span class="token operator">-</span> torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span>outputF<span class="token punctuation">)</span> <span class="token operator">+</span> labels<span class="token punctuation">)</span> <span class="token operator">**</span> <span class="token number">2</span><span class="token punctuation">)</span> <span class="token operator">+</span> torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">(</span>outputF <span class="token operator">-</span> torch<span class="token punctuation">.</span>mean<span class="token punctuation">(</span>outputR<span class="token punctuation">)</span> <span class="token operator">-</span> labels<span class="token punctuation">)</span> <span class="token operator">**</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token operator">/</span><span class="token number">2</span> errG<span class="token punctuation">.</span>backward<span class="token punctuation">(</span><span class="token punctuation">)</span> optimizerG<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span> |
Xử lý dữ liệu
Trong tất cả các bài toán DL thì dữ liệu luôn là thứ quan trọng nhất, và việc đầu tiên ta phải làm là quan sát bộ dữ liệu và tìm ra đặc điểm để xử lý theo yêu cầu. Và như ta thấy bộ dữ liệu của ta gồm 10.000 ảnh gồm:
- Các kích cỡ và định dạng khác nhau
- Nhiều ảnh GIF đặc trưng và lỗi
- Dữ liệu không đồng đều
- Nhiều xe hoặc nhiều vật cản trong một ảnh
- Đa phần là những ảnh có chiều quay là ngang
- Nhiều xe có đặc điểm gị thường và không đồng nhất với bộ dữ liệu
Cách xử lý:
- Loại bỏ những vật thừa trong ảnh và phân loại xe bằng cách sử dụng yolov3 Ở đây.
- sau khi loại bỏ những vật thừa và phân loại xe ta sẽ lọc dữ liệu bằng tay, vì bộ dữ liệu bao gồm rất nhiều xe đa dạng và ảnh không đúng định dạng bị lỗi
- Loại bỏ những xe có chi tiết thừa thãi và có ít trong bộ dữ liệu, chung quy là những xe không đa dạng và có đặc điểm “dị”
- Loại bỏ những ảnh có background quá màu mè
Preprocessing
Sau khi xử lý xong công đoạn trên để đưa ra được tập dữ liệu tốt thì việc chúng ta cần làm tiếp theo là đưa ảnh về kích cỡ 128×128 để đưa vào mạng. Có 2 option cho việc này:
- Padding images: thêm khoảng không gian bên trong ảnh, khoảng không gian này sẽ được cộng dồn thêm vào chiều rộng hoặc chiều cao của ảnh mà không bị biến dạng ảnh
- Resize images: Đưa ảnh về kích thước 128×128 luôn và co giãn ảnh theo chiều rộng hoặc chiều cao
Nhưng vấn đề ở chỗ khi mình training xong và thử cả hai trường hợp và FID evaluation thì thấy padding images cho kết quả đầu ra tốt hơn, tức là mạng học hiệu quả hơn
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | <span class="token keyword">def</span> <span class="token function">padding_image</span><span class="token punctuation">(</span>img<span class="token punctuation">)</span><span class="token punctuation">:</span> im <span class="token operator">=</span> mpimg<span class="token punctuation">.</span>imread<span class="token punctuation">(</span>img<span class="token punctuation">)</span> old_size <span class="token operator">=</span> im<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token number">2</span><span class="token punctuation">]</span> <span class="token comment"># old_size is in (height, width) format</span> ratio <span class="token operator">=</span> <span class="token builtin">float</span><span class="token punctuation">(</span>desired_size<span class="token punctuation">)</span><span class="token operator">/</span><span class="token builtin">max</span><span class="token punctuation">(</span>old_size<span class="token punctuation">)</span> new_size <span class="token operator">=</span> <span class="token builtin">tuple</span><span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token builtin">int</span><span class="token punctuation">(</span>x<span class="token operator">*</span>ratio<span class="token punctuation">)</span> <span class="token keyword">for</span> x <span class="token keyword">in</span> old_size<span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token comment"># new_size should be in (width, height) format</span> im <span class="token operator">=</span> cv2<span class="token punctuation">.</span>resize<span class="token punctuation">(</span>im<span class="token punctuation">,</span> <span class="token punctuation">(</span>new_size<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> new_size<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span> delta_w <span class="token operator">=</span> desired_size <span class="token operator">-</span> new_size<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span> delta_h <span class="token operator">=</span> desired_size <span class="token operator">-</span> new_size<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span> top<span class="token punctuation">,</span> bottom <span class="token operator">=</span> delta_h<span class="token operator">//</span><span class="token number">2</span><span class="token punctuation">,</span> delta_h<span class="token operator">-</span><span class="token punctuation">(</span>delta_h<span class="token operator">//</span><span class="token number">2</span><span class="token punctuation">)</span> left<span class="token punctuation">,</span> right <span class="token operator">=</span> delta_w<span class="token operator">//</span><span class="token number">2</span><span class="token punctuation">,</span> delta_w<span class="token operator">-</span><span class="token punctuation">(</span>delta_w<span class="token operator">//</span><span class="token number">2</span><span class="token punctuation">)</span> color <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">]</span> new_im <span class="token operator">=</span> cv2<span class="token punctuation">.</span>copyMakeBorder<span class="token punctuation">(</span>im<span class="token punctuation">,</span> top<span class="token punctuation">,</span> bottom<span class="token punctuation">,</span> left<span class="token punctuation">,</span> right<span class="token punctuation">,</span> cv2<span class="token punctuation">.</span>BORDER_CONSTANT<span class="token punctuation">,</span> value<span class="token operator">=</span>color<span class="token punctuation">)</span> new_im <span class="token operator">=</span> <span class="token punctuation">(</span>new_im <span class="token operator">-</span> <span class="token number">127.5</span><span class="token punctuation">)</span> <span class="token operator">/</span> <span class="token number">127.5</span> <span class="token keyword">return</span> new_im |
Đọc đường dẫn của ảnh:
1 2 3 4 5 6 7 8 | PATH <span class="token operator">=</span> <span class="token string">'../dataset'</span> OUTPUT_PATH <span class="token operator">=</span> <span class="token string">'../padding_image_last'</span> files <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span> <span class="token comment"># r=root, d=directories, f = files</span> <span class="token keyword">for</span> r<span class="token punctuation">,</span> d<span class="token punctuation">,</span> f <span class="token keyword">in</span> os<span class="token punctuation">.</span>walk<span class="token punctuation">(</span>PATH<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">for</span> <span class="token builtin">file</span> <span class="token keyword">in</span> f<span class="token punctuation">:</span> files<span class="token punctuation">.</span>append<span class="token punctuation">(</span>os<span class="token punctuation">.</span>path<span class="token punctuation">.</span>join<span class="token punctuation">(</span>r<span class="token punctuation">,</span> <span class="token builtin">file</span><span class="token punctuation">)</span><span class="token punctuation">)</span> |
Tạo vòng lặp và gọi hàm để padding Images:
1 2 3 4 | <span class="token keyword">for</span> i <span class="token keyword">in</span> tqdm<span class="token punctuation">(</span><span class="token builtin">range</span><span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span><span class="token builtin">len</span><span class="token punctuation">(</span>files<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">:</span> img <span class="token operator">=</span> padding_image<span class="token punctuation">(</span>files<span class="token punctuation">[</span>i<span class="token punctuation">]</span><span class="token punctuation">)</span> matplotlib<span class="token punctuation">.</span>image<span class="token punctuation">.</span>imsave<span class="token punctuation">(</span>os<span class="token punctuation">.</span>path<span class="token punctuation">.</span>join<span class="token punctuation">(</span>OUTPUT_PATH <span class="token punctuation">,</span> f<span class="token string">'image_{i:05d}.jpg'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> img<span class="token punctuation">)</span> |
Sau khi padding images xong thì chúng ta images augmentation. Mình đã xử dụng các kỹ thuật lất ảnh và quay ảnh, trước đó mình có tăng độ tương phản nhưng kết quả ra khá tệ
1 2 3 4 5 6 7 8 9 10 11 | transform1 <span class="token operator">=</span> transforms<span class="token punctuation">.</span>Compose<span class="token punctuation">(</span><span class="token punctuation">[</span>transforms<span class="token punctuation">.</span>Resize<span class="token punctuation">(</span><span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span><span class="token number">128</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token comment"># Data augmentation and converting to tensors</span> random_transforms <span class="token operator">=</span> <span class="token punctuation">[</span>transforms<span class="token punctuation">.</span>RandomRotation<span class="token punctuation">(</span>degrees<span class="token operator">=</span><span class="token number">5</span><span class="token punctuation">)</span><span class="token punctuation">]</span> transform2 <span class="token operator">=</span> transforms<span class="token punctuation">.</span>Compose<span class="token punctuation">(</span><span class="token punctuation">[</span>transforms<span class="token punctuation">.</span>RandomHorizontalFlip<span class="token punctuation">(</span>p<span class="token operator">=</span><span class="token number">0.5</span><span class="token punctuation">)</span><span class="token punctuation">,</span> transforms<span class="token punctuation">.</span>RandomApply<span class="token punctuation">(</span>random_transforms<span class="token punctuation">,</span> p<span class="token operator">=</span><span class="token number">0.3</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token comment"># transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=(0.9, 1.2), saturation=0.3, hue=0.01)], p=0.5),</span> transforms<span class="token punctuation">.</span>ToTensor<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> transforms<span class="token punctuation">.</span>Normalize<span class="token punctuation">(</span><span class="token punctuation">(</span><span class="token number">0.5</span><span class="token punctuation">,</span> <span class="token number">0.5</span><span class="token punctuation">,</span> <span class="token number">0.5</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">(</span><span class="token number">0.5</span><span class="token punctuation">,</span> <span class="token number">0.5</span><span class="token punctuation">,</span> <span class="token number">0.5</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">]</span><span class="token punctuation">)</span> |
Training mô hình
Như mình đã nói ở trên RaLSGan là một mạng GAN bình thường nó có thể là DCGAN, hay SGAN … nhưng với một hàm loss tốt hơn.
Generator sẽ như sau:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | <span class="token keyword">class</span> <span class="token class-name">Generator</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> nz<span class="token operator">=</span><span class="token number">128</span><span class="token punctuation">,</span> channels<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token builtin">super</span><span class="token punctuation">(</span>Generator<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span> self<span class="token punctuation">.</span>nz <span class="token operator">=</span> nz self<span class="token punctuation">.</span>channels <span class="token operator">=</span> channels <span class="token keyword">def</span> <span class="token function">convlayer</span><span class="token punctuation">(</span>n_input<span class="token punctuation">,</span> n_output<span class="token punctuation">,</span> k_size<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">:</span> block <span class="token operator">=</span> <span class="token punctuation">[</span> nn<span class="token punctuation">.</span>ConvTranspose2d<span class="token punctuation">(</span>n_input<span class="token punctuation">,</span> n_output<span class="token punctuation">,</span> kernel_size<span class="token operator">=</span>k_size<span class="token punctuation">,</span> stride<span class="token operator">=</span>stride<span class="token punctuation">,</span> padding<span class="token operator">=</span>padding<span class="token punctuation">,</span> bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span><span class="token punctuation">,</span> nn<span class="token punctuation">.</span>BatchNorm2d<span class="token punctuation">(</span>n_output<span class="token punctuation">)</span><span class="token punctuation">,</span> nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span>inplace<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">]</span> <span class="token keyword">return</span> block self<span class="token punctuation">.</span>model <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span>self<span class="token punctuation">.</span>nz<span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token comment"># Fully connected layer via convolution.</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">32</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> nn<span class="token punctuation">.</span>ConvTranspose2d<span class="token punctuation">(</span><span class="token number">32</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>channels<span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> nn<span class="token punctuation">.</span>Tanh<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token punctuation">)</span> <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> z<span class="token punctuation">)</span><span class="token punctuation">:</span> z <span class="token operator">=</span> z<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>nz<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> img <span class="token operator">=</span> self<span class="token punctuation">.</span>model<span class="token punctuation">(</span>z<span class="token punctuation">)</span> <span class="token keyword">return</span> img |
Như ta có thể thấy output đầu ra của G là 128×128 với hơn 13 triệu tham số
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | ---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ ConvTranspose2d-1 [-1, 1024, 4, 4] 2,097,152 BatchNorm2d-2 [-1, 1024, 4, 4] 2,048 ReLU-3 [-1, 1024, 4, 4] 0 ConvTranspose2d-4 [-1, 512, 8, 8] 8,388,608 BatchNorm2d-5 [-1, 512, 8, 8] 1,024 ReLU-6 [-1, 512, 8, 8] 0 ConvTranspose2d-7 [-1, 256, 16, 16] 2,097,152 BatchNorm2d-8 [-1, 256, 16, 16] 512 ReLU-9 [-1, 256, 16, 16] 0 ConvTranspose2d-10 [-1, 128, 32, 32] 524,288 BatchNorm2d-11 [-1, 128, 32, 32] 256 ReLU-12 [-1, 128, 32, 32] 0 ConvTranspose2d-13 [-1, 64, 64, 64] 131,072 BatchNorm2d-14 [-1, 64, 64, 64] 128 ReLU-15 [-1, 64, 64, 64] 0 ConvTranspose2d-16 [-1, 32, 128, 128] 32,768 BatchNorm2d-17 [-1, 32, 128, 128] 64 ReLU-18 [-1, 32, 128, 128] 0 ConvTranspose2d-19 [-1, 3, 128, 128] 867 Tanh-20 [-1, 3, 128, 128] 0 ================================================================ Total params: 13,275,939 Trainable params: 13,275,939 Non-trainable params: 0 ---------------------------------------------------------------- |
Discriminator:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | <span class="token keyword">class</span> <span class="token class-name">Discriminator</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> channels<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token builtin">super</span><span class="token punctuation">(</span>Discriminator<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span> self<span class="token punctuation">.</span>channels <span class="token operator">=</span> channels <span class="token keyword">def</span> <span class="token function">convlayer</span><span class="token punctuation">(</span>n_input<span class="token punctuation">,</span> n_output<span class="token punctuation">,</span> k_size<span class="token operator">=</span><span class="token number">4</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">0</span><span class="token punctuation">,</span> bn<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span><span class="token punctuation">:</span> block <span class="token operator">=</span> <span class="token punctuation">[</span>nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span>n_input<span class="token punctuation">,</span> n_output<span class="token punctuation">,</span> kernel_size<span class="token operator">=</span>k_size<span class="token punctuation">,</span> stride<span class="token operator">=</span>stride<span class="token punctuation">,</span> padding<span class="token operator">=</span>padding<span class="token punctuation">,</span> bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span><span class="token punctuation">]</span> <span class="token keyword">if</span> bn<span class="token punctuation">:</span> block<span class="token punctuation">.</span>append<span class="token punctuation">(</span>nn<span class="token punctuation">.</span>BatchNorm2d<span class="token punctuation">(</span>n_output<span class="token punctuation">)</span><span class="token punctuation">)</span> block<span class="token punctuation">.</span>append<span class="token punctuation">(</span>nn<span class="token punctuation">.</span>LeakyReLU<span class="token punctuation">(</span><span class="token number">0.2</span><span class="token punctuation">,</span> inplace<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">return</span> block self<span class="token punctuation">.</span>model <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span>self<span class="token punctuation">.</span>channels<span class="token punctuation">,</span> <span class="token number">32</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">32</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> bn<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> bn<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token operator">*</span>convlayer<span class="token punctuation">(</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> bn<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span> nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span><span class="token number">512</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> bias<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token comment"># FC with Conv.</span> <span class="token punctuation">)</span> <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> imgs<span class="token punctuation">)</span><span class="token punctuation">:</span> out <span class="token operator">=</span> self<span class="token punctuation">.</span>model<span class="token punctuation">(</span>imgs<span class="token punctuation">)</span> <span class="token keyword">return</span> out<span class="token punctuation">.</span>view<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span> |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | ---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 32, 64, 64] 1,536 LeakyReLU-2 [-1, 32, 64, 64] 0 Conv2d-3 [-1, 64, 32, 32] 32,768 LeakyReLU-4 [-1, 64, 32, 32] 0 Conv2d-5 [-1, 128, 16, 16] 131,072 BatchNorm2d-6 [-1, 128, 16, 16] 256 LeakyReLU-7 [-1, 128, 16, 16] 0 Conv2d-8 [-1, 256, 8, 8] 524,288 BatchNorm2d-9 [-1, 256, 8, 8] 512 LeakyReLU-10 [-1, 256, 8, 8] 0 Conv2d-11 [-1, 512, 4, 4] 2,097,152 BatchNorm2d-12 [-1, 512, 4, 4] 1,024 LeakyReLU-13 [-1, 512, 4, 4] 0 Conv2d-14 [-1, 1, 1, 1] 8,192 ================================================================ Total params: 2,796,800 Trainable params: 2,796,800 Non-trainable params: 0 ---------------------------------------------------------------- |
Output
Mình đã thử rất nhiều trường hợp để ra kết quả tốt nhất thì thấy nên đặt trong khoảng 550 – 750 epochs là ra kết quả khá đẹp FID khoảng từ 80 -> 62
1 2 3 4 5 6 7 | batch_size <span class="token operator">=</span> <span class="token number">64</span> LR_G <span class="token operator">=</span> <span class="token number">0.0008</span> LR_D <span class="token operator">=</span> <span class="token number">0.0008</span> epochs <span class="token operator">=</span> <span class="token number">750</span> real_label <span class="token operator">=</span> <span class="token number">0.8</span> fake_label <span class="token operator">=</span> <span class="token number">0</span> |
Một số kết quả:
Nguồn tham khảo
https://www.kaggle.com/c/generative-dog-images/discussion/99485