遷移學習:互信息的變分上下界
1 導引
在機器學習,尤其是涉及異構數據的遷移學習/聯邦學習中,我們常常會涉及互信息相關的優化項,我研一下期的處女作(發在SDM'24上)也是致力于此(ArXiv論文鏈接:FedDCSR,GitHub源碼鏈接:FedDCSR)。其思想雖然簡單,但其具體的估計與優化手段而言卻大有門道,我們今天來好好總結一下,也算是對我研一的一個收尾。
我們知道,隨機變量\(X\)和\(Y\)的互信息定義為其聯合分布(joint)\(p(x, y)\)和其邊緣分布(marginal)的乘積\(p(x)p(y)\)之間的KL散度(相對熵)[1]:
直觀地理解,互信息表示一個隨機變量包含另一個隨機變量信息量(即統計依賴性)的度量;同時,互信息也是在給定另一隨機變量知識的條件下,原隨機變量不確定度的縮減量,即\(I(X; Y) = H(X) - H(X \mid Y) = H(Y) - H(Y\mid X)\)。當\(X\)和\(Y\)一一對應時,\(I(X; Y) = H(X) = H(Y)\);當\(X\)和\(Y\)相互獨立時\(I(X; Y)=0\)。
在機器學習的情境下,聯合分布\(p(x, y)\)一般是未知的,因此我們需要用貝葉斯公式將其繼續轉換為如下形式:
那么轉換為這種形式之后,我們是否就可以開始對其進行估計了呢?答案是否定的。我們假設現在是深度表征學習場景,\(X\)是數據,\(Y\)是數據的隨機表征,則對于第\((1)\)種形式來說,條件概率分布\(p(x|y)=\frac{p (y|x)p(x)}{\int p(y|x)p(x)dx}\)是難解(intractable)的(由于\(p(x)\)未知);而對于第\((2)\)種形式而言,邊緣分布\(p(y)\)也需要通過積分\(p(y)=\int p(y \mid x)p(x)d x\)來進行計算,而這也是難解的(由于\(p(x)\)未知)。為了解決互信息估計的的難解性,我們的方法是不直接對互信息進行估計,而是采用變分近似的手段,來得出互信息的下界/上界做為近似,轉而對互信息的下界/上界進行最大化/最小化[2]。
2 互信息的變分下界(對應最大化)
我們先來看互信息的變分下界。我們常常通過最大化互信息的下界來近似地對其進行最大化。具體而言,按照是否需要解碼器,我們可以將互信息的下界分為兩類,分別對應變分信息瓶頸(解碼項)[3][4]和Deep InfoMax[5][6]這兩種方法。
2.1 數據VS表征:變分信息瓶頸(解碼項)
對于互信息的第\((1)\)種表示法即\(I(X ; Y){=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(x \mid y)}{p(x)}\right]\),我們已經知道條件分布\(p(x|y)\)是難解的,那么我們就采用變分分布\(q(x|y)\)將其轉變為可解(tractable)的優化問題。這樣就可以導出互信息的Barber & Agakov下界(由于KL散度的非負性):
這里\(H(X)\)是\(X\)的微分熵,BA是論文[7]兩位作者名字的縮寫。當\(q(x|y)=p(x|y)\)時,該下界是緊的,此時上式的第一項就等于條件熵\(H(X|Y)\)。
上式可不可解取決于微分熵\(H(X)\)是否已知。幸運的是,限定在 \(X\)是數據,\(Y\)是表征 的場景下,\(H(X)=\mathbb{E}_{x\sim p(x)} \log p(x)\)僅涉及數據生成過程,和模型無關。這意味著我們只需要最大化\(I_{\text{BA}}\)的第一項,而這可以理解為最小化VAE中的重構誤差(失真,distortion)。此時,\(I_{\text{BA}}\)的梯度就與“編碼器”\(p(y|x)\)和變分“解碼器”\(q(x|y)\)相關,而這是易于計算的。因此,我們就可以使用該目標函數來學習一個最大化\(I(X; Y)\)的編碼器\(p(y|x)\),這就是大名鼎鼎的變分信息瓶頸(variational information bottleneck) 的思想(對應其中的解碼項部分)。
2.2 表征VS表征:Deep Infomax
我們在 2.1 中介紹的方法雖然簡單好用,但是需要構建一個易于計算的解碼器\(q(x|y)\),這在\(X\)是數據,\(Y\)是表征的時候非常容易,然而當 \(X\)和\(Y\)都是表征 的時候就直接寄了,首先是因為解碼器\(q(x|y)\)是難以計算的,其次微分熵\(H(X)\)也是未知的。為了導出不需要解碼器的可解下界,我們轉向去思考\(q(x|y)\)變分族的的非標準化分布(unnormalized distributions)。
我們選擇一個基于能量的變分族,它使用一個判別函數/網絡(critic)\(f(x, y): \mathcal{X} \times \mathcal{Y}\rightarrow \mathbb{R}\),并經由數據密度\(p(x)\)縮放:
我們將該分布代入公式\((3)\)中的\(I_{\text{BA}}\)中,就導出了另一個互信息的下界,我們將其稱為UBA下界(記作\(I_{\text{UBA}}\)),可視為Barber & Agakov下界的非標準化版本(Unnormalized version):
當\(f(x, y)=\log p(y|x) + c(y)\)時,該上界是緊的,這里\(c(y)\)僅僅是關于\(y\)的函數(而非\(x\))。注意在代入過程中難解的微分熵\(H(X)\)被消掉了,但我們仍然剩下一個難解的\(\log\)配分函數\(\log Z(y)\),它妨礙了我們計算梯度與評估。如果我們對\(\mathbb{E}_{p(y)}[\log Z(y)]\)這個整體應用Jensen不等式(\(\log\)為凹函數),我們能進一步導出式\((5)\)的下界,即大名鼎鼎的Donsker & Varadhan下界[7]:
然而,該目標函數仍然是難解的。接下來我們換個角度,我們不對\(\mathbb{E}_{p(y)}[\log Z(y)]\)這個整體應用Jensen不等式,而考慮對里面的\(\log Z(y)\)應用Jensen不等式即\(\log Z(y)=\log \mathbb{E}_{p(x)}\left[e^{f(x, y)}\right]\geq\mathbb{E}_{p(x)}\left[\log e^{f(x, y)}\right]=\mathbb{E}_{p(x)}\left[f(x, y)\right]\),那么我們就可以導出式\((5)\)的上界來對其進行近似:
然而式\((5)\)本身做為互信息的下界而存在,因此\(I_{\text{MINE}}\)嚴格意義上講既不是互信息的上界也不是互信息的下界。不過這種方法可視為采用期望的蒙特卡洛近似來評估\(I_{\text{DV}}\),也就是作為互信息下界的無偏估計。已經有工作證明了這種嵌套蒙特卡洛估計器的收斂性和漸進一致性,但并沒有給出在有限樣本下的成立的界[8][9]。
在\(I_{\text{MINE}}\)思想的基礎之上,論文Deep Infomax[6]又向前推進了一步,認為我們無需死抱著信息的KL散度形式不放,可以大膽采用非KL散度的形式。事實上,我們主要感興趣的是最大化互信息,而不關心它的精確值,于是采用非KL散度形式可以為我們提供有利的trade-off。比如我們就可以基于\(p(x, y)\)與\(p(x)p(y)\)的Jensen-Shannon散度(JSD),來定義如下的JS互信息估計器:
這里\(x\)是輸入樣本,\(x\prime\)是采自\(p(x^{\prime}) = p(x)\)的負樣本,\(\text{sp}(z) = \log (1+e^x)\)是\(\text{softplus}\)函數。這里判別網絡\(f\)被優化來能夠區分來自聯合分布的樣本對(正樣本對)和來自邊緣乘積分布的樣本對(負樣本對)。
此外,噪聲對比估計(NCE)[10]做為最先被采用的互信息下界(被稱為“InfoNCE”),也可以用于互信息最大化:
對于Deep Infomax而言,\(I_{\text{JSD}}\)和\(I_{\text{InfoNCE}}\)形式的之間差別在于負樣本分布\(p(x^{\prime})\)的期望是套在正樣本分布\(p(x, y)\)期望的里面還是外面,而這個差別就意味著對于\(\text{DV}\)和\(\text{JSD}\)而言一個正樣本只需要一個負樣本,但對于\(\text{InfoNCE}\)而言就是一個正樣本就需要\(N\)個負樣本(\(N\)為batch size)。此外,也有論文[6]分析證明了\(I_{\text{JSD}}\)對負樣本的數量不敏感,而\(I_{\text{InfoNCE}}\)的表現會隨著負樣本的減少而下降。
3 互信息的變分上界(對應最小化)
我們接下來來看互信息的變分上界。我們常常通過最小化互信息的上界來近似地對互信息進行最小化。具體而言,按照是否需要編碼器,我們可以將互信息的下界分為兩類,而這兩個類別分別就對應了變分信息瓶頸的編碼項[4]和解耦表征學習[11]。
3.1 數據VS表征:變分信息瓶頸(編碼項)
對于互信息的第\((2)\)種表示法即\(I(X ; Y){=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(y \mid x)}{p(y)}\right]\),我們已經知道邊緣分布\(p(y)=\int p(y \mid x)p(x)d x\)是難解的。但是限定在 \(X\)是數據,\(Y\)是表征 的場景下,我們能夠通過引入一個變分近似\(q(y)\)來構建一個可解的變分上界:
注意上面的\((1)\)是分子分母同時乘以\(q(y)\);\((2)\)是單獨配湊出KL散度;\((3)\)是利用KL散度的非負性(證明變分上下界的常用技巧)。最后得到的這個上界我們在生成模型在常常被稱為Rate[12](也就是率失真理論里的那個率),故這里記為\(R\)。當\(q(y)=p(y)\)時該上界是緊的,且該上界要求\(\log q(y)\)是易于計算的。該變分上界經常在深度生成模型(如VAE)[13][14] 被用來限制隨機表征的容量。在變分信息瓶頸[4]這篇論文中,該上界被用于防止表征攜帶更多與輸入有關,但卻和下游分類任務無關的信息(即對應其中的編碼項部分)。
3.2 表征VS表征:解耦表征學習
上面介紹的方法需要構建一個易于計算的編碼器\(p(y|x)\),但應用場景也僅限于在\(X\)是數據,\(Y\)是表征的情況下,當 \(X\)和\(Y\)都是表征 的時候(即對應解耦表征學習的場景)也會遇到我們在2.2中所面臨的問題,從而不能夠使用了。那么我們能不能效仿2.2中的做法,對導出的\(I_{\text{JSD}}\)和\(I_{\text{InfoNCE}}\)加個負號,從而將互信息最大化轉換為互信息最小化呢?當然可以但是效果不會太好。因為對于兩個分布而言,拉近它們距離的結果是確定可控的,但直接推遠它們距離的結果就是不可控的了——我們無法掌控這兩個分布推遠之后的具體形態,導致任務的整體表現受到負面影響。那么有沒有更好的辦法呢?
我們退一步思考:最小化互信息\(I(X; Y)\)的難點在于\(X\)和\(Y\)都是隨機表征,那么我們可以嘗試引入數據隨機變量\(D\),使得互信息\(I(X; Y)\)可以進一步拆分為\(D\)和\(X\)、\(Y\)之間的互信息(如\(I(D; X)\)以及\(I(D; Y)\)。已知三個隨機變量的互信息(稱之為Interation information[1])的定義如下:
聯立上述的等式\((1)\)和等式\((2)\),我們有:
對解耦表征學習而言,在概率圖模型中的結構化假設(V型結構)中\(X\)和\(Y\)共同為\(D\)的潛在因子,而\(X\)和\(Y\)互不影響(詳情可參見論文[11])。\(X\),\(Y\)和\(D\)對應的結構化概率關系如下圖所示:
反映在后驗概率上即表征后驗分布\(q\)滿足\(q\left(X \mid D\right)=q\left(X \mid D, Y\right)\),因此上述等式的最后一項就消失了:
這樣我們就有:
上述的\((1)\)是由于\(I(X; Y \mid D)=0\),\((2)\)是由于互信息的鏈式法則即\(I(D; X, Y)=I(D; Y) + I(D; X \mid Y)\)。
對\(I(X; Y)\)等價變換至此,真相已經逐漸浮出水面:我們可以可以通過最小化\(I\left(D ; X\right)\)、\(I\left(D ; Y\right)\),最大化\(I\left(D ; X, Y\right)\)來完成對\(I(X; Y)\)的最小化。其直觀的物理意義也就是懲罰表征\(X\)和\(Y\)中涵蓋的總信息,并使得\(X\)和\(Y\)共同和數據\(D\)相關聯。
基于我們在\(3.1\)、\(2.1\)中所推導的\(I(D; X)\)、\(I(D; Y)\)的變分上界與\(I(D; X, Y)\)的變分下界,我們就得到了\(I(X; Y)\)的變分上界:
直觀地看,上式地物理意義為使后驗\(q(x\mid D)\)、\(q(y\mid D)\)都趨近于各自的先驗分布(一般取高斯分布),并減小\(X\)和\(Y\)對\(D\)的重構誤差,直覺上確實符合表征解耦的目標。
4 總結
總結起來,互信息的所有上下界可以表示為下圖[2](包括我們前面提到的\(I_{\text{BA}}\)、\(I_{\text{UBA}}\)、\(I_{\text{DV}}\)、\(I_{\text{MINE}}\)、\(I_{\text{InfoNCE}}\)等):
圖中節點的代表了它們估計與優化的易處理性:綠色的界表示易估計也易于優化,黃色的界表示易于優化但不易于估計,紅色的界表示既不易于優化也不易于估計。孩子節點通過引入新的近似或假設來從父親節點導出。
參考
- [1] Cover T M. Elements of information theory[M]. John Wiley & Sons, 1999.
- [2] Poole B, Ozair S, Van Den Oord A, et al. On variational bounds of mutual information[C]//International Conference on Machine Learning. PMLR, 2019: 5171-5180.
- [3] Tishby N, Pereira F C, Bialek W. The information bottleneck method[J]. arXiv preprint physics/0004057, 2000.
- [4] Alemi A A, Fischer I, Dillon J V, et al. Deep variational information bottleneck[J]. arXiv preprint arXiv:1612.00410, 2016.
- [5] Belghazi M I, Baratin A, Rajeshwar S, et al. Mutual information neural estimation[C]//International conference on machine learning. PMLR, 2018: 531-540.
- [6] Hjelm R D, Fedorov A, Lavoie-Marchildon S, et al. Learning deep representations by mutual information estimation and maximization[J]. arXiv preprint arXiv:1808.06670, 2018.
- [7] Barber D, Agakov F. The im algorithm: a variational approach to information maximization[J]. Advances in neural information processing systems, 2004, 16(320): 201.
- [8] Rainforth T, Cornish R, Yang H, et al. On nesting monte carlo estimators[C]//International Conference on Machine Learning. PMLR, 2018: 4267-4276.
- [9] Mathieu E, Rainforth T, Siddharth N, et al. Disentangling disentanglement in variational autoencoders[C]//International conference on machine learning. PMLR, 2019: 4402-4412.
- [10] Oord A, Li Y, Vinyals O. Representation learning with contrastive predictive coding[J]. arXiv preprint arXiv:1807.03748, 2018.
- [11] Hwang H J, Kim G H, Hong S, et al. Variational interaction information maximization for cross-domain disentanglement[J]. Advances in Neural Information Processing Systems, 2020, 33: 22479-22491.
- [12] Alemi A, Poole B, Fischer I, et al. Fixing a broken ELBO[C]//International conference on machine learning. PMLR, 2018: 159-168.
- [13] Rezende D J, Mohamed S, Wierstra D. Stochastic backpropagation and approximate inference in deep generative models[C]//International conference on machine learning. PMLR, 2014: 1278-1286.
- [14] Kingma D P, Welling M. Auto-encoding variational bayes[J]. arXiv preprint arXiv:1312.6114, 2013.

在機器學習,尤其是涉及異構數據的遷移學習/聯邦學習中,我們常常會涉及互信息相關的優化項,我上半年的第一份工作也是致力于此。其思想雖然簡單,但其具體的估計與優化手段而言卻大有門道,我們今天來好好總結一下,也算是對我研一下學期一個收尾。為了解決互信息估計的的難解性,我們的方法是不直接對互信息進行估計,而是采用變分近似的手段,來得出互信息的下界/上界做為近似,轉而對互信息的下界/上界進行最大化/最小化。
浙公網安備 33010602011771號