แค่ 10 training images (few-shot learning) ก็ทำ transfer learning บน GAN ได้ จาก CVPR2021

PUSH TECH
4 min readJul 23, 2021

--

Problems of few-shot image generation

เป็นเปเปอร์น่าสนใจจาก CVPR2021 ผลงานร่วมระหว่าง Adobe Research, UC Davis และ UC Berkeley

งานนี้พูดถึงปัญหาของ generative models ที่ *โดยทั่วไป* เหมือนจะรู้กันว่าถ้าอยาก train-from-scratch ก็ต้องใช้ training images จำนวนไม่น้อย (thousands of images) ถึงจะได้ Generative Adversarial Network (GAN) ดี ๆ สักตัวที่ทำงาน image generation ได้ภาพผลลัพธ์ออกมางาม ๆ … ปัญหาคือคนทั่วไปมักมีทรัพยากร (คอมและเวลา) ไม่มากพอจะ train-from-scratch ตัว GAN ระดับ state-of-the-art (SOTA) ได้ ตัวอย่างเช่น StyleGAN v2 ตัวดัง กรณี train บน FFHQ dataset at 1024x1024 ต้องใช้ GPU ขั้นต่ำ 16GB และใช้เวลา train เกือบ 70 วัน (ถ้าใช้ 1 GPU)

สำหรับโมเดล Deep Learning (DL) หลายตัวที่ไม่ใช่ GAN การทำ transfer learning ก็เป็นทางออกที่ดี ทำให้เราไม่ต้อง train-from-scratch แต่สามารถ jump start ตั้งต้นจาก baseline model ที่มีคน pre-train ไว้ให้ดีแล้วระดับหนึ่ง … จากนั้นหน้าที่เราก็แค่เอามาต่อยอดอีกหน่อย มา train เพิ่มอีกนิดด้วย our custom image dataset ที่มีจำนวน training images น้อยกว่า และใช้ทรัพยากรคำนวณน้อยกว่าเดิมมาก

คำถาม คือ เราสามารถทำ transfer learning อย่างนี้กับ GAN สำหรับงาน image generation ได้ไหม?

คำตอบ คือ ก็มีหลายงานพยายามทำกันอยู่นะ หลักการแสดงในรูปแรกใต้โพสต์ (figure 1) คือ

… เรามี pre-trained GAN สักตัวที่ถูก train มาบน a large dataset (เช่น ข้อมูลหลักล้านรูป) และทำงาน image generation บน domain ของตัวเอง (เช่น สร้างภาพหน้าคนเสมือนจริง) ได้ดีอยู่แล้ว

… วันนึงเราอยากจะเอา pre-trained GAN ตัวที่ว่ามาปรับให้สามารถทำ image generation บน domain ที่ต่างจากเดิมได้ (เช่น สร้างภาพหน้าการ์ตูน หรือ สร้างภาพหน้าเด็ก) โดยที่ขอใช้ training images เพื่อการปรับครั้งนี้แค่สัก N ภาพก็พอ

เทคนิคก่อน ๆ นี้นั้น ค่า N อย่างน้อยก็ต้องสักหลักร้อย เพราะถ้า N น้อยกว่านี้ภาพที่ได้ออกมาจะคุณภาพต่ำ หรือไม่ก็เกิด overfit ทำให้เกิดอาการภาพซ้ำเหมือนที่แสดงในแถวกลางของภาพที่ 2 ใต้โพสต์ (figure 2) … แต่ในเปเปอร์นี้เขาออกมาเกทับว่าถ้าใช้วิธีของเขานะ ขอแค่ 10 training images (N=10) เท่านั้น เขาก็ทำ transfer learning สอน pre-trained GAN ให้สร้างรูปภาพใน different domain ที่ต้องการได้แล้ว

Proposed method

ไอเดียของงานนี้มีส่วนประกอบหลายส่วนแสดงในรูปที่ 3 ใต้โพสต์ (figure 3) รายละเอียดดังนี้

… G_s (คางหมูสีส้ม ล็อกกุญแจ) คือ the source generator ของตัว pre-trained GAN ที่ train มาอย่างดีแล้วบน a large dataset … เช่น ในภาพนี้ G_s ถูก train มาให้รับ latent vector z_i แล้วสร้างภาพหน้าคนแบบเหมือนจริงขึ้นมา … จากเปเปอร์ งานนี้น่าจะใช้ G_s เป็น StyleGAN v2 (pre-trained on FFHQ dataset)

… G_s=>t (คางหมูสีเขียว) คือ the adapted generator ของงานนี้ที่ต้องการ train ด้วยข้อมูลแค่ 10 training images … ในภาพนี้ G_s=>t จะรับ latent vector z_i แล้วสร้างภาพหน้าคนแนว painting style ขึ้นมา … ค่า weight ของ G_s=>t เห็นว่า initialize ตอนแรกเริ่มด้วย weight ชุดเดียวกันกับ G_s นั่นล่ะ

… D_img และ D_patch (คางหมูสีฟ้า) คือ full-image discriminator และ patch-level discriminator ซึ่งเดี๋ยวจะมี condition ให้อีกทีว่ากรณีไหนจะเลือกใช้ discriminator ตัวไหน

Contribution หลักของงานนี้มี 2 ส่วน ได้แก่

1 Cross-domain distance consistency loss (L_dist): actively uses the original source generator to regularize the tuning process

ในเรื่องนี้ไอเดียของเค้า คือ: if the model can preserve the relative similarities and differences between instances in the source domain, then it has the chance to inherit the diversity in the source domain while adapting to the target domain.

จากภาพ figure 3 จะเห็นว่า (inspire from contrastive learning)

… ด้านซ้ายมือ มีการคำนวณหา similarity (s^s) ระหว่าง ‘ผลลัพธ์จาก G_s ที่มีอินพุตคือ z_0’ เทียบกันกับ ‘ผลลัพธ์จาก G_s ที่มีอินพุตคือ z_i อื่น ๆ’

… ด้านขวามือ มีการคำนวณหา similarity (s^t) ระหว่าง ‘ผลลัพธ์จาก G_s=>t ที่มีอินพุตคือ z_0’ เทียบกันกับ ‘ผลลัพธ์จาก G_s=>t ที่มีอินพุตคือ z_i อื่น ๆ’

… ผลลัพธ์ similarity ของแต่ละฝั่ง (ซ้าย s^s และขวา s^t) เอามาทำ softmax เพื่อรวมค่าในฝั่งเดียวกันก่อน แล้วค่อยเอาผลรวม softmax จากแต่ละฝั่งมาคำนวณ KL divergence ร่วมกันอีกที

2 Relaxed discriminator: encourages different levels of realism over different regions in the latent space

เพื่อป้องกัน overfit จากข้อมูล training images ที่มีน้อย (N=10) งานนี้เสนออีกหนึ่งมาตรการ คือ relaxed discriminator หรือการที่เรามี two discriminators (D_img และ D_patch) นั่นเอง … ส่วนนี้ใช้ adversarial loss คือ L’_adv = L_adv(G, D_img) + L_adv(G, D_patch)

สำหรับการเลือกว่า ณ ขณะหนึ่งเราจะใช้ D_img หรือ D_patch นั้น อาศัยการดูว่าค่า latent ที่ใช้อยู่ใน “anchor regions” หรือเปล่า ถ้าอยู่ก็ใช้ D_img แต่ถ้าไม่อยู่ก็ใช้ D_patch

สำหรับวิธีการกำหนด anchor regions ก็ตามนี้ >> Z_anch forms a subset of the entire latent space. To define the anchor space, we select k random points, corresponding to the number of training images, and save them. We sample from these fixed points, with a small added Gaussian noise (σ = .05). We use shared weights between the two discriminators by defining D_patch as a subset of the larger D_img network.

Results

ผลการทดลองส่วนนึงที่น่าสนใจ คือ

… ถ้าไม่ใช้ L_dist เลยจะทำให้ความ diversity ของภาพที่ generator สร้างขึ้นมาลดน้อยลง เช่น ทุกภาพมีศีรษะและทรงผมคล้ายกันหมด

… ถ้าใช้ L_dist ปกติ แต่เลือกใช้เฉพาะ D_img (ไม่ใช้ D_patch เลย) ผลคือ >> mode collapse at the part level (same blue hat appears in multiple generations) and the phenomenon where some results are only slight modifications of the same mode (same girl with and without the blue hat)

… ถ้าใช้ L_dist ปกติ แต่เลือกใช้เฉพาะ D_patch (ไม่ใช้ D_img เลย) ผลคือ >> more diversity, but poorer quality (เมื่อเทียบกับตอนที่ใช้แต่ D_img)

… ผลสรุปงานนี้เลยเลือกใช้หลายวิธีปน ๆ กันเพื่อให้ได้ภาพผลลัพธ์ที่ทั้ง diverse และทั้ง realistic ไม่ว่าจะในระดับ part-level หรือ image-level ก็ตาม

ถ้าใครยังคาใจจุดไหนไปอ่านต่อกันได้ในลิงก์ข้างล่างนี้ … เตือนไว้ก่อนว่าโค้ดเค้า(แทบ)ไม่มีคำอธิบายโค้ดกำกับเลย จะไล่โค้ดกันยากนิดนึง

อ้างอิงข้อมูลและรูปภาพจากลิงก์ต่อไปนี้

… เปเปอร์ arXiv 13APR2021 (15 pages): https://arxiv.org/abs/2104.06820

… Github (PyTorch): https://github.com/utkarshojha/few-shot-gan-adaptation

… Website: https://utkarshojha.github.io/few-shot-gan-adaptation/

--

--

No responses yet