簡介
神經常微分方程模型是一類新的深度神經網絡模型,不同于VGG、ResNet等這些有有限個離散的隱藏層構成的神經網絡模型。
例如殘差網絡、循環神經網絡解碼器、歸一化流等模型建立起復雜的變換,是通過一個變換(相對簡單的變換,比如ReLU變換)序列實現的。公式化表示為$$\mathbf h_{l+1} = \mathbf h_l + f_{l}(\mathbf h_l, \theta)$$ 論文中是使用公式$$\mathbf h_{t+1} = \mathbf h_t + f(\mathbf h_t, \theta_t)$$ 使用這種表示方式,便于擴展到常微分方程。
因為常微分方程$$\frac{d\mathbf h(t)}{dt} = f(\mathbf h(t), t, \theta)$$, 給定一個初始解,按照歐拉方法$$\mathbf h_{t+1} = \mathbf h_t + f(\mathbf h_t, t, \theta)$$可以迭代式求出變量\(\mathbf h\)的軌跡。可以看出常微分方程的歐拉迭代公式與離散神經網絡模型的迭代過程 是一致的。除了每步迭代更新公式不一樣。
| 類型 |
連續性 |
深度 |
| 常規神經網絡模型 |
離散 |
有限個層(N層) |
| 神經常微分方程模型 |
連續 |
無限,任意迭代次數 |
模仿常微分方程,用神經網絡來建模方程右側的函數,這個函數是關于要研究變量、模型參數的函數。
前向計算
采用常微分方程建模變量隨時間的變化,那么就可以采用常微分數值解法,比如一階歐拉法、四階龍格-庫塔法,有很多軟件集成了這些方法,統一用常微分方程求解器簡寫成ODESolver。
反向計算
關于參數\(\mathbb z\)梯度
通過對前向操作進行微分,這樣會變成與常規神經網絡模型一樣,需要計算關于參數和中間狀態的導數,這樣會增加內存的使用。因此引入一種新的方法,該方法將ODESlover看作黑盒,使用伴隨敏感度方法——adjoint sensitivity method 計算梯度。該方法引入一個新的變量\(\mathbf a(t)\),通過直接計算該變量的微分方程解來實現對變量\(\mathbf z(t)\)的梯度求解。
現在考慮優化一個標量損失函數\(\mathbf L(\dots)\), 其輸入是神經常微分方程模型最終結果\(\mathbf z_1\),也是ODESolver的結果?,F在假設已知\(t_0\)時刻變量的狀態值即\(\mathbf z_0 = \mathbf z(t_0)\),常微分方程參數\(\theta\), 損失函數\(L()\)也提前確定了,現在求解損失函數\(L\)關于狀態變量\(\mathbf z(t)\)和模型參數\(\theta\)的梯度。
伴隨敏感度方法,就是把損失函數\(L\)關于狀態變量\(\mathbf z(t)\)的梯度,稱為伴隨量,這個伴隨量命名為\(\mathbf a(t)\), 即\(\mathbf a(t) = \frac{\partial L}{\partial \mathbf z(t)}\)。這個量本身也是隨著時間\(t\)變化的。 這個伴隨量的動力學方程如下所示
\[\frac{d\mathbf a(t)}{dt} = -\mathbf a(t)^T \frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z}
\]
給定初始值\(\mathbf a(t_1)\),由ODESolver就可以求出\(\mathbf a(t)\) 在各個時間點上的狀態值。使用ODESolver的前提條件是\(\frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z}\) 是已知。
既然\(\mathbf a(t)\) 在各個時間點上的狀態值都已經求解出來了,那么就依次獲得了\(\frac{\partial L}{\partial z(t)}\) 在各個時間點\(t_1, t_{\sum_{1}^{N-1} \delta_{i}}, t_{\delta_{1}}, t_0\) 的狀態值(這里面的\(\delta_{i}\)表示第i個時間步的長度)。這個順序是逆序的,對應著反向傳播的梯度。
現在上公式來推導得出上述微分方程(為避免損耗腦細胞,可以跳過不看):
推導依賴1: 根據鏈式法則 \(\frac{\partial L}{\partial \mathbf z(t)} = \frac{\partial L}{\partial \mathbf z(t+1)} \frac{\partial z(t+1)}{\partial \mathbf z(t)}\), 因為求關于變量\(\mathbf z(t)\)的梯度,按照時間步,是從最大的時間開始求解再依次求較小的時間。因此這里是這樣的鏈式形式。 此外由于是求解連續梯度,因此將上面改寫成 \(\frac{\partial L}{\partial \mathbf z(t)} = \frac{\partial L}{\partial \mathbf z(t+\epsilon)} \frac{\partial \mathbf z(t+\epsilon)}{\partial \mathbf z(t)}\). 注意上面第一個公式t+1更多表達時間步索引,而第二個公式就表示t時間本身。
下面假設向量\(\mathbf z(t), \mathbf z(t+\epsilon)\) 的維度是d; 且向量對向量求導,采用分子布局方式。
可以按照下面來理解
\[L \leftarrow z(t+\epsilon) \leftarrow z(t)
\]
令\(y=z(t+\epsilon)\),此處\(z(t+\epsilon)\), \(z(t)\)都假設為行向量\(\mathbb {R}^{1 \times d}\),那么
\[\frac{\partial L}{\partial z} = (\frac{dL}{dz_1}, \frac{dL}{dz_2}, \cdots, \frac{dL}{dz_d})
\]
\[\frac{\partial L}{\partial y} = (\frac{dL}{dy_1}, \frac{dL}{dy_2}, \cdots, \frac{dL}{dy_d})
\]
\[\frac{\partial y}{\partial z} =
\left[
\begin{array}{}
\frac{dy_1}{dz_1} & \frac{dy_1}{dz_2} & \cdots & \frac{dy_1}{dz_d} \\
\frac{dy_2}{dz_1} & \frac{dy_2}{dz_2} & \cdots & \frac{dy_2}{dz_d} \\
\cdots && \\
\frac{dy_d}{dz_1} & \frac{dy_d}{dz_2} & \cdots & \frac{dy_d}{dz_d} \\
\end{array}
\right]
\]
行對應\(y^T\)的分量,列對應著\(x\)的分量,可以看出
\[\frac{\partial L}{\partial y}\frac{\partial y}{\partial z}
= (\frac{dL}{dy_1}, \frac{dL}{dy_2}, \cdots, \frac{dL}{dy_d})
\left[
\begin{array}{}
\frac{dy_1}{dz_1} & \frac{dy_1}{dz_2} & \cdots & \frac{dy_1}{dz_d} \\
\frac{dy_2}{dz_1} & \frac{dy_2}{dz_2} & \cdots & \frac{dy_2}{dz_d} \\
\cdots && \\
\frac{dy_d}{dz_1} & \frac{dy_d}{dz_2} & \cdots & \frac{dy_d}{dz_d} \\
\end{array}
\right] = (\sum\frac{dL}{dy_i}\frac{dy_i}{dz_1}, \sum\frac{dL}{dy_i}\frac{dy_i}{dz_2}, \cdots, \sum\frac{dL}{dy_i}\frac{dy_i}{dz_d})
\]
因此
\[\frac{\partial L}{\partial z} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial z}
\]
推導依賴2: 泰勒公式
既然知道了變量\(\mathbf z(t)\) 的一階導數,那么很自然由泰勒公式就直接得出 \(\mathbf z(t+\epsilon) = \mathbf z(t) + \epsilon \frac{d\mathbf z(t)}{dt} + o(\epsilon^2)\)
推理依賴3: 導數的定義
\[\begin{aligned}
\frac{\mathbf a(t+\epsilon) - \mathbf a(t)}{\epsilon}
&= \frac{\frac{\partial L}{\partial \mathbf z(t+\epsilon)} - \frac{\partial L}{\partial \mathbf{z}(t)}}{\epsilon} \\
&= \frac{\frac{\partial L}{\partial \mathbf z(t+\epsilon)} - \frac{\partial L}{\partial \mathbf z(t+\epsilon)} \frac{\partial \mathbf z(t+\epsilon)}{\partial \mathbf z(t)} }{\epsilon} \\
&= \frac{\frac{\partial L}{\partial \mathbf z(t+\epsilon)}}{\epsilon}[I - \frac{\partial \mathbf z(t+\epsilon)}{\partial \mathbf z(t)}] \\
&= \frac{\frac{\partial L}{\partial \mathbf z(t+\epsilon)}}{\epsilon}[I - \frac{\partial }{\partial \mathbf z^(t)}(\mathbf z(t) + f(\mathbf z(t), t, \theta)\epsilon + o(\epsilon^2))] \\
&= \frac{\frac{\partial L}{\partial \mathbf z(t+\epsilon)}}{\epsilon}[I - (I + \frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z(t)}\epsilon)] \\
&= \frac{\frac{\partial L}{\partial \mathbf z(t+\epsilon)}}{\epsilon}[- \frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z(t)}\epsilon ] \\
&= -\frac{\partial L}{\partial \mathbf z(t+\epsilon)}\frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z(t)}
\end{aligned}
\]
兩邊同時取關于\(\epsilon\)趨近于0的極限,便可以得到\(\frac{d \mathbf a(t)}{dt} = -\mathbf a(t)\frac{\partial f(\mathbf z(t), t, \theta)}{\partial \mathbf z(t)}\)
設有時間點\(t_1, t_2, \cdots, t_{N-1}, t_{N}\) , 知道初始時間點\(t_N\)上損失值關于變量\(\mathbf z(t)\)的梯度,即\(\mathbf a(t_N) = \frac{\partial L}{d\mathbf z(t_N)}\)。那么由ODESolver可以,依次求解出\(\mathbf a(t_{N-1}), \mathbf a(t_{N-2}), \cdots, \mathbf a(t_1)\).
到這里,關于變量\(\mathbf z(t)\)的神經微分方程網絡,其各個時間點前向狀態值,及其反向梯度值都完全求解出來了。
關于參數\(\theta\)和\(t\)的梯度
將參數\(\theta, t\)關于時間t的微分方程寫作為
\[\frac{\partial \theta(t)}{\partial t}=\mathbb 0, \frac{\partial t(t)}{\partial t}= 1
\]
現在將狀態變量\(\mathbb z(t)\)進行擴充,添加參數\(\theta, t\),那么其對應的微分方程為
\[\frac{d[\mathbb z, \theta, t]}{dt} = f_{aug}[\mathbb z, \theta, t] = [f(\mathbb z,\theta,t), \mathbb 0, 1]
\]
它們對應的伴隨量記為
\[a_{aug} = [a, a_{\theta}, a_{t}],
a_{\theta}(t) = \frac{dL}{d\theta(t)},
a_{t}(t) = \frac{dL}{d t(t)}
\]
對擴充的\(\mathbb z\), 其伴隨量a_{aug}的導數可以推導得到為
\[\frac{da_{aug}(t)}{dt} = -a_{aug}(t) \frac{\partial f_{aug}}{\partial [\mathbb z, \theta, t]}(t)=-[a\frac{\partial f}{\partial \mathbb z}, a\frac{\partial f}{\partial \theta}, a\frac{\partial f}{\partial t}]
\]
這個可以下面公式推導得出
\[\frac{\partial f_{aug}}{\partial [\mathbb z, \theta, t]}(t)=
\left [
\begin{array}{}
\frac{\partial f}{\partial \mathbb z}, \frac{\partial f}{\partial \theta}, \frac{\partial f}{\partial t} \\
0, 0, 0 \\
0, 0, 0 \\
\end{array}
\right ]
\]
向量求導布局
矩陣向量求導的本質就是多元函數求導,僅僅是把因變量分量對自變量分量的求導結果排列成了向量矩陣的形式,為未來方便表達與計算而已。但是在矩陣向量求導中,其求導結果會因向量和矩陣的形式而導致結果不唯一,為此引入了求導布局的概念。
比如參考1中,提到m維列向量\(\mathbf y\)對n維列向量\(\mathbf x\)求導。這兩個向量求導,一共有m個標量對n個標量分別求導,共計mn個求導。求導的結果就可以排列為一個矩陣。
- 如果是分子布局,則矩陣的第一個維度以分子的維度為準,那就是m行n列的矩陣即
\[\left[
\begin{array}{}
\frac{dy_1}{dx_1} & \frac{dy_1}{dx_2} & \cdots & \frac{dy_1}{dx_n} \\
\frac{dy_2}{dx_1} & \frac{dy_2}{dx_2} & \cdots & \frac{dy_2}{dx_n} \\
\cdots && \\
\frac{dy_m}{dx_1} & \frac{dy_m}{dx_2} & \cdots & \frac{dy_m}{dx_n} \\
\end{array}
\right]
\]
- 如果是分母布局,則求導的結果矩陣的第一個維度以分母的維度為準,那就是n行m列的矩陣即
\[\left[
\begin{array}{}
\frac{dy_1}{dx_1} & \frac{dy_2}{dx_1} & \cdots & \frac{dy_m}{dx_1} \\
\frac{dy_1}{dx_2} & \frac{dy_2}{dx_2} & \cdots & \frac{dy_m}{dx_2} \\
\cdots && \\
\frac{dy_1}{dx_n} & \frac{dy_2}{dx_n} & \cdots & \frac{dy_m}{dx_n} \\
\end{array}
\right]
\]
可以看出按分子布局和按分母布局,得到的求導矩陣是互為轉置的。
比如參考2中,假設某函數從\(\mathbf f: \mathbb R^n \rightarrow \mathbb R^m\),從n維向量\(\mathbf x \in \mathbb R^n\)映射到m維向量\(\mathbf f(\mathbf x) \in \mathbb R^m\), 求導數為
\[\left[
\begin{array}{}
\frac{df_1}{dx_1} & \frac{df_1}{dx_2} & \cdots & \frac{df_1}{dx_n} \\
\frac{df_2}{dx_1} & \frac{df_2}{dx_2} & \cdots & \frac{df_2}{dx_n} \\
\cdots && \\
\frac{df_m}{dx_1} & \frac{df_m}{dx_2} & \cdots & \frac{df_m}{dx_n} \\
\end{array}
\right]
\]
參考