吳恩達深度學習課程一:神經網絡和深度學習 第二周:神經網絡基礎(三)梯度下降法
此分類用于記錄吳恩達深度學習課程的學習筆記。
課程相關信息鏈接如下:
- 原課程視頻鏈接:[雙語字幕]吳恩達深度學習deeplearning.ai
- github課程資料,含課件與筆記:吳恩達深度學習教學資料
- 課程配套練習(中英)與答案:吳恩達深度學習課后習題與答案
本篇為第一課第二周,是2.4到2.6部分的筆記內容。
可能會發現跳過了幾節,實際上是因為課程中的順序為講解結構后再分別講解基礎的順序,筆記為了便于理解,便改為先講解基礎后可以順暢的理解結構的順序。跳過的節數會在之后再講到。
本周的課程以邏輯回歸為例詳細介紹了神經網絡的運行,傳播等過程,其中涉及大量機器學習的基礎知識和部分數學原理,如沒有一定的相關基礎,理解會較為困難。
因為,筆記并不直接復述視頻原理,而是從基礎開始,盡可能地創造一個較為絲滑的理解過程。
首先,經過之前的第二部分內容學習,我們了解了機器學習中的分類是什么以及邏輯回歸算法,而本篇延續邏輯,進行傳播部分的講解。這一部分則涉及較多的數學基礎,同樣會在筆記中進行補充,本篇更偏向于補充數學基礎。
依舊以一個問題來引入,現在我們知道了邏輯回歸如何得到分類結果,但想要預測結果更準確,擬合效果更好,我們便要尋找使結果最優的權重 \(w\) 和偏置 \(b\) ,那如何得到這兩個數呢?
我們以此開始本篇筆記的內容。
1.導數
要尋找使結果最優的權重 \(w\) 和偏置 \(b\),便要以一種合適的規則來對二者不斷地更新,得到最優的結果,什么樣的更新方法最好最合適,這便涉及到效率問題,而當我們想提高效率時,總是離不開數學的。 導數便是這一部分的基礎。
1.1 什么是導數?
老樣子先擺一個導數的概念:
導數是一個基本的數學概念,主要用于分析函數的變化率。簡單來說,導數表示一個函數在某一點的瞬時變化率,或者說是該點切線的斜率。
斜率的概念我們不陌生,兩點確定一條直線,用兩點縱橫坐標的差作除法,就得到這了這條直線的斜率。
為了便于理解的同時不涉及太多的數學內容,我們先不擺導數的公式,用課程中的例子來進行說明:

如圖,這是一個函數圖像,我們標記了當 \(a_1=2和a_2=2.001\) 的兩點,而這兩點對應的函數值\(f(a_1)=6和f(a_2)=6.003\) , 連接兩點和其對應的移動距離便構成圖中的小三角形,我們不難得到這樣的結論:
當我們把 \(a\) 從2向上移動了0.001時,其對應的函數值增加了0.003,即這條直線的斜率為3。
而剛剛在概念里提到,導數是一個點切線的斜率,問題來了,當我們只知道切點的時候,如何才能的得到這個點切線的斜率?
這便涉及微積分中極限的概念,我們現在再來看導數的公式:
簡單解釋一下:
- \(f'(a)\) 就代表 \(f(a)\) 的一階導數。
- \(lim_{h \to 0}\) 代表 \(h\) 是一個無限趨近于0,非常非常小的增量,比例里的0.001還要小的多,無法度量。
只有在這樣的情況下,我們連接 \(f(a+h) 和 f(a)\) 得到的才是 \(a\) 點的切線,如圖所示(例):

知道了如何得到導數后,現在我們回到剛剛的兩個 \(a\) 點上,結合剛剛的概念,我們便得到如下結論:
假設0.001可以當作 \(h\) 時,函數 \(f(a)=3a\) 在\(a=2\) 這一點的(一階)導數值為3。
可以發現我們強調了導數值是針對一個具體的點所說的,而斜率往往針對一條線,這便是二者在概念上的其中一點不同。
繼續看\(f(a)=3a\) 這個函數,我們會發現,其實這個函數在任何一點上的導數值都是3,因為它的任何一點的切線都是函數圖像本身,再說斜率,在這個函數里,無論取相隔多遠的兩個點,其產生的函數值的變化都是都是3倍。
對這個簡單的函數,我們無需多想,便可以得出 $$f'(a)=3$$
這代表 \(f(a)\) 在任何一點的導數值都等于3,不同于剛剛只針對一個點的說法,\(f'(a)\) 描述了原始函數在每一點的瞬時變化率,我們便稱之為導函數。
1.2 如何得到導函數?
剛剛通過一個導函數恒定的例子來解釋導數,但實際上想得到導函數,找到函數變化的規律不會這么簡單,我們來看一個更復雜的情況:

不再重復贅述過程。
經過計算,我們可以發現,函數 \(f(a)=a^2\) 在\(a=2\) 這一點的(一階)導數值為4,在\(a=5\) 這一點的(一階)導數值為10。
這說明導函數和函數一樣,也需要變量。
先擺結論:
這是查表得到的,不多描述原理,簡單來說,各種基本函數有其規律可以總結出其導函數,而復雜函數也就是基本函數的組合,同樣可以通過運算得到其導函數。
在這里,我們只需要知道,我們用導函數得到函數在各個點的瞬時變化率,而我們只要查表得到導函數即可。
1.3 如何得到多變量函數的導函數?
我們已經知道如何得到單變量的導函數了,現在再擴展一下,當變量從單個變為兩個甚至多個的時候呢?
實際上,多變量的函數才是更符合神經網絡的,因為我們大多情況下都不會只輸入一種特征,而往往是多種特征來作為輸入。
現在我們來看這樣一個函數:
這個函數有兩個變量,而且這兩個變量互不相關,如何求這樣同時有兩個量變化的函數的導函數?
現在擺一個偏導數的概念:
偏導數是多變量函數最常見的求導方式,適用于多變量函數中,求某一特定變量對函數的影響。它計算的是在其他變量固定不變時,函數相對于某一變量的變化率。
計算偏導數的步驟與普通導數類似,只是要記住在求某個變量的偏導數時,其他變量視為常數。
我們來進行一下 \(f(x,y)=3x+2y^2\) 分別對變量 \(x和y\) 的偏導數。
看一下偏導數的表示方式及其對應的結果:
現在,我們得到了 \(f(x,y)對x和y的偏導數\) ,這兩個偏導數便代表了兩個變量分別對函數的影響,其變化率。
這一節其實全部是在補充數學方面的基礎,在了解了導數方面的基礎后,這些知識到底怎么幫助我們訓練權重\(w\)和偏置\(b\) 呢?我們展開下一部分。
2.梯度下降法
2.1 什么是梯度?
我們知道了怎么求多變量的導函數,現在再看一下\(f(x,y)=3x+2y^2\)的圖像:

我們已經通過偏導數知道了在函數任意一點上,該點在 \(x\) 方向和 \(y\) 的變化率,簡單舉例解釋,在\((2,2)\) 這個點上,函數在 \(x\) 方向和 \(y\) 的變化率為\(3和8\) 。
于是我們便知道了在這一點上,函數在 \(y\) 方向變化的比在 \(x\) 方向上的變化更快。
通俗來說,把圖像看作山坡,在這一點的位置上, \(y\) 方向比 \(x\) 方向更“陡”。
可問題又來了,不同于單變量函數只能一條線上移動,當變量增加為兩個,圖像也由線擴展為面,同樣以山坡為例,我們不僅可以單獨改變 \(x或者y\) ,也就是“直著走”,我們也可以同時改變\(x和y\),來“斜著走”。
在這樣的情境下,函數圖像就是真正的山坡,而我們站在某一點的位置上,只要不“穿模”(走到圖像外),便可以向山坡的任何方向移動。
由此,我們引入梯度的概念:
梯度是由函數的偏導數組成的向量,每個分量表示該函數在對應變量方向的變化速率。梯度的方向是函數值變化最快的方向,而梯度的大小則表示在該方向上的變化率的大小。
依舊以\(f(x,y)=3x+2y^2\) 來說明這個概念,已知梯度向量是由函數的偏導數組成的向量,那么本題中的梯度向量就是:
我們現在站在\((2,2)\) 上,即:
我們再通過敘述理解一下這兩個式子:
首先,我們將函數對兩個變量的偏導數組合成立一個二維向量,這個向量的每個元素代表都代表一個變量在該方向的變化率。
而當我們代入具體位置時,便得到對應的梯度向量,這兩個變化率組合成一個二維向量,這個向量指向的便是變化最快的方向。
我個人的理解來說,針對兩個變量的函數,它在x,y的方向的偏導就像直角三角形的兩條直角邊,得到的斜邊最長,其指向的方向就是變化最快的方向,而換一個方向,這個角就不一定是直角,經過映射就會讓直角邊變短
總而言之,梯度指向的是局部變化速率最大的位置。
如下圖所示,\((3,8)\)的方向即是在\((2,2)\)時,移動時讓函數值變化最大的方向。也就是“最陡的方向”。

這里我們通過例子說明的梯度是什么,實際上如果還想更深入的理解,可以再了解一下方向導數,如何理解梯度(方向導數的最大值) 這為up主的視頻可以很好的講解。
但是涉及到向量計算,我們只在這里理解梯度的含義,便于服務神經網絡即可。
2.2 什么是梯度下降法
回到神經網絡里來,我們之前在第一周的課程里提過,監督學習就是學生可以在考試-對答案-考試的重復中不斷提高自己的成績。
而要不斷提升模型效果,我們就要讓每一次對答案時,我們的作答和答案的差別越來越小,只有這樣,我們才能達到訓練的效果。
也就是說,訓練權重\(w\)和偏置\(b\),可以看做一個輸出和標簽的差別最小化問題。
我們把計算所有輸出和標簽的總差別的函數叫做成本函數,我們在了解完梯度下降法后再聊它們。
而怎么最快的找到最佳的參數,讓成本函數的結果最小,其實從梯度下降法的名字就可以看出來了。
我們通過梯度得知了函數值變化最快的方向,而又根據導數本身的定義我們又知道了梯度的方向其實是上升的方向。
那么,當我們想要讓函數值最小時,只要順著梯度的反方向走,即梯度下降的方向走,不就能最快的達到最小值,從山坡下到平地了嗎?
現在給出梯度下降法的概念:
梯度下降法是一種優化算法,常用于在多維空間中找到一個函數的最小值。簡而言之,它就是通過沿著梯度的反方向(即最快下降的方向)一步步更新參數,直到找到函數的最低點。
至此,我們知道了訓練權重\(w\)和偏置\(b\) 的方法,實際上,梯度下降法還有一些可能導致的問題,它在實際的傳播過程中又如何應用,我們便在下一篇進行邏輯回歸傳播過程的具體講解。

浙公網安備 33010602011771號