2.4小時教你入門機器學習算法
現在很流行什么24小時精通xxx,我覺得24小時太久,不如試試2.4小時。
我會嘗試用簡單的說法解釋復雜的事物。
人類一直在探索宇宙的真理,或者說探索宇宙的公式。周文王做《周易》也無非就是通過各種卦象推算出預言。再說直白一點,就是尋找y=f(x1,x2,x3,…)這樣的公式。通過x1,x2,x3…去推導出結果y。
那么我私人定義一下機器學習的定義,通俗的說,就是讓機器自己通過大量的樣本(x1,y1)(x2,y2)…(xn,yn) (俗稱大數據big data),推導出這個f(x1,x2,x3)的公式。注意,機器學習只會根據樣本推導出最接近真相的公式,但是不能回答為什么,不能解釋原理。
即使你從未接觸過機器學習,你肯定也知道這是一門非常復雜的學科。那么入門顯然要從最簡單的講起,這篇文章不會教你調用什么API,這樣就變成了調包俠,我們從最簡單的原理講起。
作為一個民間科學家,首先打開百度百科,搜索一下“方差公式”,可以看到以下信息:

我們的故事就要從方差公式說起,即使上學期間是混日子的人都知道這個公式。事實上,我不能解釋這個公式為什么這么定義,就當是大家公認的吧。這篇文章的例子都會遵循這個方差公式來定義穩定性。事實上,我們也可以自己定義自己的方差公式來定義自己系統的穩定性,在機器學習中,它的名字叫代價函數。
y=f(x1,x2,x3,…)是復雜的多維因子函數,作為入門教程,我們把它簡化成只有一個輸入因子的函數y=f(x)。我們的目的就是根據樣本求出這個f()。
作為一個民間科學家,再打開百度百科,搜索一下“泰勒級數”

這里我來更加通俗的解釋一下泰勒級數,我知道入門教程閱讀者不會去學習太深的數學基礎,所以我只會很直白的解釋,說白了,泰勒就是想把任意的y=f(x)轉換成好計算的多項式,本質就是把一條任意規律的線解釋成多項式子的和。因為一條線想等同于另外一條線,只要他們在任意x點的導數相同,導數的導數相同,導數的導數的導數相同,導數的導數的導數的導數相同….那么他們就相同,因此我們可以把任意的y=f(x)分解成a*x^n + b*x^(n-1) + ……+ c*x^(1) + d 這樣的多項式,問題就可以進一步簡化。
作為入門教程,我們把難度收斂到多項式中最簡單的一項:c*x^(1),什么?你說d是最簡單的一項?那好吧,那就收斂到第二簡單的一項。也就是y=ax這樣的一條直線。
也就是說任何復雜的公式,最終都是由若干的y= θx或者y= θx+b組成。從幾何上來說,這就是一條直接,因此最簡單的機器學習就是基于線性的回歸。
干貨:
假設我們擁有若干y=θx的樣本,目的就是求出θ。那么我們就應該力求最一個θ使得樣本的方差公式的結果最小,這樣就最穩定,最符合樣本的真相。注意,這很重要,這就是貫穿機器學習的核心方法論,讓代價函數的值最小。
為了求出θ,我們定義一個關于θ的函數: z= J(θ).由此可知,這個函數應該是這樣的:

也就是方差公式。得到合適的θ值讓z最小則成為了我們新的目標。
從習慣方便理解的方向,轉換成對θ的公式。我們把這個公式轉換習慣的y = f(x)公式。我們可以把x和y替換成a和b,把θ替換成x。想象成新的公式就應該是:

格式化寫上面這個公式,費了我很多力氣,所以接下來,還是讓靈魂畫手上場吧。

這個函數基本上是就是J(θ)的圖形,嗯,一個U形(當變量變成2維之后,你可以想象就是一個立體的碗的形狀,如果是3維就是,emo,我也想象不出來了),那么我們的目的就是要求出一個θ值讓曲線的值最小最低。
有一個公式,可以讓θ逐漸逼近最低點,這個過程,在機器學習中稱為梯度下降法。假設我們初始設置θ值為0,然后讓θ值變成θ值減去一個偏移,直到這個J(θ)的導數成為0,那么就找到了最低點。

我們從下圖來理解一下:

假設θ在最低值的左側,對這個點求導數,也就是切線,可以得到一個很小的?z和?θ,事實上他們的比就是這個點的導數,那么這是一個負數,因此
就會讓θ變大,然后右移,同理,我們取值在右側,就會讓θ變小,然后左移,當α比較小的時候,就會逐漸逼近最低點,直到導數為0 或者無限接近0的時候,就是最低點了。
上面我們得出
那么也就是
talk is cheap, show you the code.
首先我們定義一些點的用例case
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Case {
private double x;
private double y;
public static Case of(double x, double y) {
return new Case(x, y);
}
}
我們假定有一個y=f(x)的最簡單的線性函數是y=θx,這里θ我們設置成一個隨便設置的變量,可以得到原始函數
private static final double REAL_SITA = 12.52D;
/**
* 原始函數
*
* @param x x值
* @return y值
*/
public double orgFunc(double x) {
return REAL_SITA * x;
}
然后構造1000個樣本用例,并且假設用例不是很標準,有5%的誤差,這樣更加真實
private static final int CASE_COUNT = 1000;
private final List<Case> CASES = new ArrayList<>();
/**
* mock樣本
*/
private void makeCases() {
List<Case> result = new ArrayList<>();
for (int i = 0; i < CASE_COUNT; i++) {
double x = ThreadLocalRandom.current().nextDouble(-100D, 100D);
double y = orgFunc(x);
boolean add = ThreadLocalRandom.current().nextBoolean();
/* 為了仿真,y進行5%以內的抖動 */
int percent = ThreadLocalRandom.current().nextInt(0, 5);
double f = y * percent / 100;
if (add) {
y += f;
} else {
y -= f;
}
Case c = Case.of(x, y);
result.add(c);
}
/* sort排序 */
result.sort((o1, o2) -> {
double v = o1.getX() - o2.getX();
if (v < 0) {
return -1;
} else if (v == 0) {
return 0;
} else {
return 1;
}
});
this.CASES.clear();
this.CASES.addAll(result);
}
由于之前我們得到了求θ的導數公式,因此我們可以這樣計算導數
/**
* J(θ)的導數
*
* @param sita θ
* @return 導數值
*/
public double derivativeOfJ (double sita) {
double count = 0.0D;
for (Case c : CASES) {
double v = sita * c.getX() * c.getX() - c.getX() * c.getY();
count += v;
}
return count / CASE_COUNT;
}
最后我們設置一個小一點的α,然后假設 θ初始值為0,讓 θ自己不停的去修正自己,得到最后的 θ值
/**
* J(θ)的導數
*
* @param sita θ
* @return 導數值
*/
public double derivativeOfJ (double sita) {
double count = 0.0D;
for (Case c : CASES) {
double v = sita * c.getX() * c.getX() - c.getX() * c.getY();
count += v;
}
return count / CASE_COUNT;
}
/**
* 梯度下降
*/
public double stepDownToGetSita() {
double alpha = 0.0001D;
/* 假設θ從0開始遞增 */
double sita = 0D;
while (true) {
double der = derivativeOfJ(sita);
/* 由于計算機double有精度丟失,當導數der無限趨于0,則認為等于0 */
if (Math.abs(der) < 0.000001) {
return sita;
}
/* 不然就修正θ */
sita -= alpha * der;
}
}
最后我們寫一個junit來簡單測試一下,看看能不能模糊計算出我們預先設置的θ值,和REAL_SITA比精度能到多少
@Test
public void test() {
LineFunction lineFunction = new LineFunction();
lineFunction.makeCases();
double x = lineFunction.stepDownToGetSita();
log.debug("x:{}", x);
}
最后輸出結果

可見,預期值是12.52,我們計算出來是12.520841
什么?你說結果不是很精確?嗯,第一,這只是一個POC,第二,樣本數量太少,第三,樣本我進行了5%的模糊化,導致和本身的函數確實有誤差。
嗯,機器學習的入門就是這樣了,往后算子的復雜度會越來越高,但是原理就是這樣。
浙公網安備 33010602011771號