核函數:讓支持向量機從“青銅”變“王者”
在機器學習領域,支持向量機(SVM)是一種強大的分類算法,而核函數則是其核心組件之一。
核函數的本質是一個「空間映射工具」。
當原始數據在低維空間中線性不可分時(如環形、月牙形數據),核函數能將數據隱式地映射到更高維的特征空間,使得在高維空間中數據變得線性可分,同時避免直接計算高維特征的爆炸性計算量(即"核技巧")。
本文將介紹核函數的作用、常用類型,并通過 scikit-learn 的實際案例展示其效果。
1. 核心作用
在許多實際場景中,數據往往不是線性可分的,直接應用線性分類器效果不佳。
核函數通過巧妙的數學變換,將數據映射到一個更高維度的空間,在這個空間中,原本線性不可分的數據可能變得線性可分。
這樣,SVM 就可以在這個高維空間中找到一個最優的超平面來進行分類。
核函數的核心優勢在于它避免了顯式地計算高維空間中的坐標,而是通過核技巧直接計算映射后的內積。
這種方法不僅提高了計算效率,還減少了內存消耗,使得 SVM 能夠高效地處理大規模數據集。
總的來說,核函數主要為了解決下面幾個核心問題:
- 避免直接計算高維空間中的內積(維度爆炸問題,核技巧通過數學變換簡化計算)
- 為非線性數據提供高效的分類解決方案
- 通過不同核函數的選擇,適應多樣化的數據分布特征
2. 常用核函數
一般我們在訓練SVM模型時,常用的核函數主要有4種:
2.1. 線性核函數
線性核函數(Linear Kernel)是最簡單的核函數,其公式為:$ K(x,y)=x^Ty $。
它適用于線性可分的數據集。
如果數據在原始空間中已經可以通過一個線性超平面進行分類,那么線性核函數是一個高效且簡單的選擇。
2.2. 多項式核函數
多項式核函數(Linear Kernel)的公式為:\(K(x,y)=(x^Ty+c)^d\) 。
其中, $ c $是一個常數項, $ d $是多項式的度數。
多項式核函數可以通過調整 $ d $ 和 $ c $ 的值來增加模型的復雜度,從而更好地擬合非線性數據。
它適用于數據具有多項式關系的場景。
2.3. 徑向基核函數(RBF)
RBF 核函數(Radial Basis Function)是SVM中最常用的核函數之一,其公式為:\(K(x,y)=exp(-\frac{||x-y||^2}{2\sigma^2})\) 。
其中, $ \sigma $是控制高斯分布寬度的參數。
RBF 核函數能夠將數據映射到無窮維空間,具有很強的靈活性,適用于大多數非線性問題。
它對數據的局部變化非常敏感,能夠很好地捕捉數據的復雜結構。
2.4. sigmoid 核函數
Sigmoid 核函數的公式為:$ K(x,y)=\tanh(ax^Ty+b) $。
其中, $ a $ 和$ b $ 是參數。
Sigmoid 核函數類似于神經網絡中的激活函數,它在某些特定的非線性問題中表現良好,但使用時需要謹慎調整參數,以避免過擬合或欠擬合。
3. 核函數實踐
為了展示核函數的作用,我們使用scikit-learn庫構造一個測試數據集,并比較不同核函數的效果。
首先,使用make_moons函數生成一個非線性可分的數據集。
這個數據集包含兩個半月形的類別,用線性分類器很難進行區分。
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
# 生成非線性可分的數據集
X, y = make_moons(n_samples=200, noise=0.1, random_state=42)
# 繪制數據集
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis')
plt.title("非線性可分數據集")
plt.show()

接下來,我們使用scikit-learn的SVC(支持向量分類器)分別應用線性核函數、多項式核函數、RBF 核函數和 Sigmoid 核函數,并比較它們的效果。
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 劃分訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 線性核函數
svm_linear = SVC(kernel='linear')
svm_linear.fit(X_train, y_train)
y_pred_linear = svm_linear.predict(X_test)
print("線性核函數的準確率:", accuracy_score(y_test, y_pred_linear))
# 多項式核函數
svm_poly = SVC(kernel='poly', degree=3)
svm_poly.fit(X_train, y_train)
y_pred_poly = svm_poly.predict(X_test)
print("多項式核函數的準確率:", accuracy_score(y_test, y_pred_poly))
# RBF 核函數
svm_rbf = SVC(kernel='rbf', gamma='scale')
svm_rbf.fit(X_train, y_train)
y_pred_rbf = svm_rbf.predict(X_test)
print("RBF 核函數的準確率:", accuracy_score(y_test, y_pred_rbf))
# Sigmoid 核函數
svm_sigmoid = SVC(kernel='sigmoid')
svm_sigmoid.fit(X_train, y_train)
y_pred_sigmoid = svm_sigmoid.predict(X_test)
print("Sigmoid 核函數的準確率:", accuracy_score(y_test, y_pred_sigmoid))
運行結果:
線性核函數的準確率: 0.8833333333333333
多項式核函數的準確率: 0.95
RBF 核函數的準確率: 0.9833333333333333
Sigmoid 核函數的準確率: 0.6666666666666666
從結果來看,線性核函數在這種非線性數據集上的表現一般,而 RBF 核函數和多項式核函數取得了較好的效果,
Sigmoid 核函數表現最差,它僅在特定場景(如模擬神經網絡)可能有用,其他場景下效果普遍較差。
為了更加直觀,我們可以把四種核函數的分類結果繪制出來:
plt.figure(figsize=(20, 10))
models = [svm_linear, svm_poly, svm_rbf, svm_sigmoid]
for i, model in enumerate(models, 1):
plt.subplot(2, 2, i)
h = 0.02 # 網格間隔
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, alpha=0.8, cmap="viridis")
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors="k", cmap="viridis")
plt.tight_layout()
plt.show()

4. 總結
核函數是支持向量機中不可或缺的組成部分,它通過將數據映射到高維空間,解決了線性不可分問題,使 SVM 能夠處理復雜的非線性分類任務。
在實際應用中,選擇合適的核函數至關重要。
線性核函數適用于線性可分數據,多項式核函數和 RBF 核函數則更適合處理非線性問題。
通過基于scikit-learn的實驗,我們直觀地看到了核函數在不同數據集上的效果差異。
在實際項目中,建議根據數據的特點和需求選擇合適的核函數,并通過交叉驗證等方法調整參數,以達到最佳的分類效果。

浙公網安備 33010602011771號