Meta Pseudo Label: ImageNet Top1 Accuracy 90% 達成!

Kai-Wen Zhao
19 min readJun 9, 2021

--

前言

2021 年 1 月 Google Brain 的 Principal Scientist — Quoc Le 在個人 Twitter 帳號上宣佈首次在 ImageNet Benchmark 上達到了 Top 1 Accuracy 90% 的成績,也讓 Meta Pseudo Label 這篇命運多舛的 Paper 搖身一變成為 State-of-the-Art 的作品。

Quoc Le 的 Twitter 發文: https://twitter.com/quocleix/status/1349443438698143744

Meta Pseudo Label (簡稱 MPL) 這篇 Paper 於一年前 2020/3 第一次出現在 ArXiv 。雖然身為 Semi-Supervised Learning 主題的文章 (SSL 當時正熱門呢),MPL 一開始並沒有受到太多關注,尚有一些理論與實務上瑕疵的地方。例如直接利用 Validation Loss 來訓練模型,或是硬是用了 Meta Learning 框架來解釋 Gradient 的走向 (但筆者認為 MPL 的核心精神更加貼近 Reinforcement Learning 中的 Actor-Critic) 。

隨著 Paper 的不斷改寫,在最新的 v4 (arXiv:2003.10580v4) 中,Paper 作者已有詳細推導 Student (Actor) 與 Teacher (Critic) Model 是如何交互作用,以嚴謹的數學說明 MPL 是如何拿下 ImageNet SOTA 的成績,而非以往認知 Semi-Supervised Learning 靠著龐大的資料與混合的 Regularization 取勝。

本篇文章會仔細介紹 MPL 的數學推導,說明 Teacher-Student 之間是如何相輔相成,希望能讓大家都欣賞到 MPL 精妙的設計。

Semi-Supervised Learning (SSL) 成為主流

攤開近年 ImageNet Benchmark,可以明顯看到 SSL 屠榜的趨勢。

EfficientNet V2 在 2021 年被發明了,之後的進步依舊可以期待!

在 2018 年之前,主要的進步來自於架構的改善; 在 EfficientNet 被發明 (或者說發現) 後架構就沒有大幅進步了,研究則轉向由 Data-Driven,也就是 SSL 演算法的改進。

First-Tier of ImageNet Benchmark (Recorded 2021/3); EfficientNet-L2 不需修改能達到 85.5% 的驚人成績,但在 MPL 的助力之下,還能提升接近 5% 的準確率!

究竟 SSL 的表現有多強呢?

讓我們觀察 Second-tier 的榜單 (84%~85%) 。這個區間多為 Network Architecture 的改善 (藍色標記),但很明顯的,改架構的效益極為有限 (<1%); 但伴隨的 Model Size (#params) 與 FLOPS 影響很巨大。相較之下,利用 Noisy-Student SSL 演算法與大量 Unlabeled Data 訓練的 EfficientNet-B3 (橘色標記) ,能以輕量的架構達到相同水準。例如與 ResNet-RS-350 相差近兩個數量級,相當不講武德。

Second-Tier of ImageNet Benchmark (Recorded 2021/3)
不講武德的表現

另外 SSL 也我們看到 Neural Network 的 Capcaity 其實相當高,就像經典的 ResNet-50 到現在還沒被資料 “塞滿”,依舊能不斷進步。

“There’s plenty of room for ResNet-50”

Self-Training Framework

首先,我們比較 Noisy-Student 與 MPL 的架構。

Noisy-Student 演算法:

  1. 在 Labeled Data 上訓練 Teacher Model
  2. 由 Teacher 在 Unlabeled Data 上標記出 Pseudo Label 作為 Student 的訓練資料
  3. 因為 Pseudo Label 很髒 ,Student 訓練時需要加上各種 Regularization 方法避免 Confirmation Bias
  4. Student Model 作為下一個 Iteration 的 Teacher Model

以上演算法構成了一個 Open Loop 的 Self-Training 框架,Teacher 並不知道 Student 的學習狀況,Student 也僅是靠著 RandAug, Stochastic Depth 種種高強度的 Regularization 避免學壞。

MPL 的目標是解決 Teacher-Student 缺乏互動的問題

Open Loop & Closed Loop Self-Training Framework

MPL 演算法構建了一個 Closed Loop (或稱 Feedback Loop) 訓練框架,也就是說,Teacher Model 會根據 Student 的學習狀況做調整。

  1. 由 Teacher 在 Unlabeled Data 上標記出 Pseudo Label 作為 Student 的訓練資料
  2. Teacher 根據 Student 的學習狀況來更新自己,讓下一個 Iteration 能標記出更適合 Student 學習的 Pseudo Label

以上就是 MPL 演算法的概念,也就是 Update Teacher 時同時考慮 Student。接下來我們將逐步構建出他的 Training Algorithm。

延伸思考

  1. MPL 可以類比 Reinforcement Learning 中 Agent 與 Environment 的關係,action 與 reward 類比為 pseudo label 與 student signal
  2. 在 Noisy-Student 中 Teacher 與 Student Model 的地位其實是等價的,因為一個 Iteration 後會發生取代; 在 MPL 中 Teacher 與 Student Model 其實不等價,Teacher 只作為 Data 的容器 (data distribution) ,因此在 Paper 附錄中衍伸出 Economical version of MPL — 用一個 multi-layer perceptron 透過學習 Identity map 來充當 Teacher Model 。此部分本文不展開討論,留給讀者思考玩味。

MPL 演算法推導 — Part I

由於 MPL 是數學上嚴謹的演算法,這意味著枯燥與乏味 (笑),我們分為 3 個部分,從不同面向切入,讓讀者能循序漸進的了解 MPL。

Cross Entropy Loss

畢竟 MPL 還是 Image Classification 問題,使用的 Loss 是經典的 Cross Entropy Loss 。本文要特別強調的是 Target Distribution ,在不同的 Training Scheme 會選用不同的表示,例如 One-Hot 用於 Supervised Learning; Soft-Label 常用於 SSL 等等。Target Distribution 的選擇對 MPL 帶來深遠的影響。

Cross Entropy Loss Is Our Old Friend

另外,讀者可以留意本文中 Cross Entropy 用了兩種表示 L(theta) 與 CE(target, predict),前者強調 Model 後者強調 Prediction。另外下標 l 表示 對 labeled data 計算 loss; 下標 u 代表 unlabeled data 。

Pseudo Label as Bi-Level Optimization Problem

要建構演算法,就要定義 Teacher 與 Student 是如何溝通的。我們將之前的概念轉換為以下兩條 Update Rules

Rule 1 — Student Update Rule: 用 Teacher 給出 Pseudo Label 來訓練 Student

Train Student with Pseudo Label on Unlabeled Data

Rule 2 — Teacher-Student Update Rule: 用 Student 在 Labeled Data 上的表現作為回饋,來修正 Teacher 的教導能力

Update Teacher-Student with Student Performance on Labeled Data

把兩條式子放在一起,就是我們 MPL Training Problem

Bi-Level Optimization Problem

但這個 Training Problem 實務上很困難的,有兩個地方需要簡化。

簡化一:在 Student 的 Update Rule 中,Teacher 以 Soft-Label 的形式直接出現在 Loss Function ,意味著整個 Teacher Model 會出現在 Back-propagation Graph 中,會造成訓練非常緩慢 (想像一下 Teacher 也是 EfficientNet-L2) 。

解決的方法是將 Soft-Label 轉為 Hard-Label ,由於 ArgMax Operator 不可微分,能直接 Break End-to-end Path。

簡化二:在 Teacher-Student Update Rule 中,Teacher 直接與 Student Model 耦合在一起,成為很複雜的 Optimization Problem,實務上需要把耦合去除。

因此 Paper 中利用了簡化的技巧 (可以說 One-Step Update 或者 First-order Taylor Expansion) ,將原本 Student 受 Teacher 影響而耦合的式子展開,成為簡單的 Gradient Update Rule ,並且可以利用前一步 Student Update Rule 的結果來替換。

將耦合項用 One-step Student Update 取代; 正好這項就是 bi-level 的第一條

經過實務上的簡化,我們能得到較詳細的關係圖如下:

Feedback Loop Diagram

Teacher 靠著 Pseudo Label 把資訊傳給 Student,同時做 Harden Pseudo Label 降低運算量;透過 Reused Student’s Update,更新 Teacher 使他更能因材施教。

Practical Update Rules in Paper

以上就是 MPL 的演算法,Paper 的正文也是描述到這邊。基本上交代了一個自洽也合理的流程,讓大家都明白訓練目標與做法。但同時也留下一些困惑與實務上的含糊,而本文的目的就是說明這些細節。

問題

  • 簡化二 宣稱 Decouple 了原本的 Teacher-Student Update,說好的師生互動是怎麼樣實踐的?
  • Reuse Student’s Update 是什麼意思?是直接用 Student’s Weights 取代 Teacher’s Weights 嗎?
  • 把 Soft Pseudo Label 轉成 Hard Pseudo Label 造成什麼影響?
  • Teacher 負責標記 Unlabeled Data,但 Teacher 有從中獲得資訊來修正自己嗎? 還是全靠 Labeled Data?
  • Student 出現在 Rule 1 與 2 ,到底 Student Inference 幾次 Update 幾次?

現在,讓我們進入 Part II 來看看 MPL 真實的面貌。

MPL 演算法推導 — Part II

MPL 最困難的在於如何 Rule2 — Update Teacher-Student (當一個顧慮到學生的老師真難),本節直接推導 Update Teacher-Student 的 Gradient 並將在 Part III 以演算法的形式再次說明。

推導過程

為了維持閱讀愉快,推導完整過程位於本文末,更詳細過程也可以參考原始 Paper 附錄。

推導結果

我們直接上結果。以數學表達相當漂亮,Teacher-Student Update Rule 如下:

Solution to Meta Pseudo Label Optimization Problem

我們逐項來說明:

  • 左式是 Gradient of Teacher-Student Update on Labeled Data,意義為根據 Student 在 Labeled Data 上面的表現,修正 Teacher 的教育方式
  • Gradient of Teacher-Student Update 其實與 Gradient of Teacher (on Unlabeled Data) 向量平行:意味著 Teacher 本身能不能在 Unlabeled Data 上標出好的結果直接影響 Teacher-Student 的關係 (Teacher 多會教,直接影響師生的關係)
  • Teacher-Student Gradient 跟 Student Learning Rate 成正比
  • Teacher-Student Gradient 與一個常量 h 成正比 — 這個常量 (scalar) 就是 Student Model 是否有進步的指標 (Student Feedback Signal)
  • h (Student Feedback Signal) 為兩個 Gradient 內積:分別是原本 Student 在 Unlabeled Data 上的表現與 Update 過後的 Student 在 Labeled Data 讓的表現內積

MPL 演算法推導 — Part III

最後,脫離複雜的數學公式,我們從直觀的演算法與流程圖再次理解 MPL 。

演算法

從 Part II 詳細地推導 (於文末) ,我們列出各項 Gradient 的來源,並且逐步在演算法實作中取得他們。(為了方便,我們用顏色區分需計算的每項,一直到文末顏色都保持一致的)

Composite of Teacher-Student Update

首先,我們寫下演算法的 Pseudo Code 如下:

接著,我們用圖示來說明以上演算法動態時發生什麼事情。

Gradient Flow

規則:虛線為 Forward; 實線為 Backprop 取得 Gradient; 若有相同顏色線條形成 Loop 代表 Model 被 Update。

Step 1:Teacher Model Inference Unlabeled Data,產生一組 Pseudo Label,並且轉成 Hard-Label 。

Sample Pseudo Label and Harden them

Step 2: Student Model Forward & Backward on Unlabeled Data,並且 Update 。也就是 Rule1 — Student Update。

Update Student with Pseudo Labels

Step 3:Student Inference Labeled Data 計算 並且 Backward 取得 Gradient ,注意這裡沒有形成 Loop 因此 Student Model 並不會因為這次 Backprop 而更新。此外 Reuse Student’s Update 算出 Student Feedback Signal h 。

Compute Student’s Feedback signal

Step 4 -1:Teacher Model 再次 Inference Unlabeled Data 並且計算 Gradient。注意,這裡容易會產生盲點,因為 Unlabeled Data 是 Teacher 自己標記出來的,為什麼能算出 Non-Zero Loss 並且 Backprop 呢?原因正是因為 Step 1 已經將 Pseudo Label 轉為 Hard-Label,計算 Cross Entropy 時還是能算出結果。算出紅色 Gradient 後與 h 相乘。

Compute Teacher’s Gradient from Student’s Feedback

Step 4–2:Teacher Model Inference Labeled Data 並 Backward 計算 Gradient。

Compute Teacher’s Gradient on Labeled Data

Step 4–3 :最後把 Teacher Model 的 Loop Close 起來,一次 Update Teacher Model。完成一個 MPL Step 。

Update Teacher Model

延伸討論

回答先前的問題

  • 師生互動體現在微分後的 Jacobian Matrix,最後化成藍色與紅色的 Gradient
  • Reuse Student’s Update 指的就是藍色 Gradient 後來計算 h 可以重複使用
  • Teacher 會從中自己標出的 Hard Label 與 Soft Label 間產生 Loss 而且合併 Student Feedback 來 Update
  • Student 其實 Inference 了兩次,但第二次只取 Gradient 並不會 Update Model (有趣的是 Student 並不直接透過 Labeled Data Training)
  • 如果使用 Soft Pseudo Label,基本上就完全不適用這裡推導的所有公式,而整個問題應該可以 End-to-end Training

那,為什麼非得做 Hard Pseudo Label 呢?用 Soft Pseudo Label 會有問題嗎?

除了免除 End-to-end 很龐大以外,筆者認為當計算 CE(q, p) Prediction 與 Target 皆為 Soft-Label 時,Cross Entropy 並沒有辦法真正拉近兩個 Distribution 的距離,因為 Divergence 公式表明 KL = CE(q, p) - H(q),如果單純用 CE 可能無法防止 Teacher Model 產生 High Entropy 模稜兩可的 Pseudo Label; 相較之下,Hard-Label 則強迫 Teacher 做出明確的決定。

結論

以上,就是 MPL 的詳細運作過程。從實驗中我們可以看到他的效果非常顯著; 此外,在數學推導上本篇也是有相當嚴謹,並提出合理的訓練框架,可以說 SOTA 的結果是靠著演算法的進步所拿下的。

ResNet 50 Experiment
State-of-the-Art Performance on ImageNet

附錄

資料: Labeled Data 與 Unlabeled Data

在本文與實驗中提到的 Labeled Data 均為 ImageNet; Unlabeled Data 為 JFT-300M。讀者可以注意數量級即可,Unlabeled Data 圖片數量為 Labeled Data 的 100 倍 (JFT 300M 中取 130M 來用),約有 20% 是帶有 noise 的。

Teacher-Student Update Rule 推導過程

首先,是一個數學上的建模。引入一個 Estimator 作為 Updated Student 的數學模型 (因為 Student Update 以後變成什麼樣子,並不能完全知道,所以我們用他的平均行為來描述)。接著就是暴力微分:

根據 Chain Rule 與變數變換,我們能得到一個 Jacobian Matrix ,它也就是最重要的一項 (藍色框框),描述了 Student Model 如何因為 Teacher Model 而改變。

接這是一些變數的合併與整理,讓我們比較好看。

整理過後,一個對 Estimator 中 Distribution 微分的形式出現了,而他與 Reinforcement Learning 中常用到的 Policy Gradient 如出一轍。因此我們可以用這個公式替換,最後整理出 Cross Entropy 的形式。

由於我們 Harden Pseudo Label — Break End-to-End Path,當要微分時跟 RL 中要對 Expected Reward 計算 Policy Gradient 是一樣的道理。

最後我們把得到的各項放在一起,可以求得最後的 Teacher-Student Gradient。

References

--

--

No responses yet