高斯混合模型 GMM計算方法
高斯混合模型(Gaussian Mixture Model, GMM)是一種常用的概率模型,用于聚類和密度估計。它假設數據是由多個高斯分布混合生成的。GMM 的計算通常使用 期望最大化(Expectation-Maximization, EM)算法 來求解。
一、問題設定
我們有:
數據集:\(X = \{x_1, x_2, ..., x_N\}\),其中每個$ x_i \in \mathbb{R}^D$
假設有 \(K\) 個高斯分布
每個樣本由一個隱變量 \(z_i \in \{1,2,...,K\}\) 表示其來自哪個高斯分布
我們的目標是:
-
推斷出所有高斯分布的參數:均值 \(\mu_k\) 、協方差矩陣 \(\Sigma_k\) 、權重 \(\alpha_k\)
-
同時推斷每個樣本屬于哪一個高斯分布的概率
二、GMM 的聯合分布形式
GMM 的聯合分布為:
其中:
-
\(\alpha_k\) 是第\(k\) 個高斯分布的權重,滿足 \(\sum_{k=1}^K \alpha_k = 1\)
-
\(\mathcal{N}(x | \mu_k, \Sigma_k)\) 是多元高斯分布:
三、EM 算法流程(E 步 + M 步)
EM 算法是一個迭代優化算法,用于在存在隱變量的情況下進行最大似然估計。
? 初始化參數(Iteration 0)
隨機初始化以下參數:
- 高斯權重:\(\alpha_k^{(0)}\) ,滿足 \(\sum \alpha_k = 1\)
- 均值向量:\(\mu_k^{(0)} \in \mathbb{R}^D\)
- 協方差矩陣:\(\Sigma_k^{(0)} \in \mathbb{R}^{D \times D}\) ,正定對稱矩陣(可初始化為單位矩陣)
四、E 步:計算后驗概率(責任函數)
對于每個樣本 \(x_i\) ,計算其屬于第$ k$ 個高斯分布的后驗概率:
這個值表示:第 \(i\) 個樣本“屬于”第 \(k\) 個高斯分布的“責任”。
五、M 步:更新參數
定義:
- \(N_k = \sum_{i=1}^N \gamma(z_{ik})\):第 \(k\) 個高斯分布的“有效樣本數”
更新公式如下:
- 更新權重:
- 更新均值:
- 更新協方差矩陣:
六、迭代與終止條件
重復 E 步 和 M 步,直到滿足以下任意一種情況:
- 參數變化小于某個閾值(例如:\(||\theta^{(t+1)} - \theta^{(t)}|| < \epsilon\) )
- 對數似然變化很?。ㄈ纾?span id="w0obha2h00" class="math inline">\(\log p(X|\theta^{(t+1)}) - \log p(X|\theta^{(t)}) < \epsilon\) )
- 達到預設的最大迭代次數
七、對數似然函數(可用于評估收斂)
每一步可以計算當前參數下的對數似然:
這個值應隨著迭代逐漸上升并趨于穩定。
八、最終輸出
經過多次迭代后,我們得到一組最優參數估計:
- 最終的權重:\(\alpha_1, \alpha_2, ..., \alpha_K\)
- 最終的均值:\(\mu_1, \mu_2, ..., \mu_K\)
- 最終的協方差矩陣:\(\Sigma_1, \Sigma_2, ..., \Sigma_K\)
同時我們可以將每個樣本分配給最可能的類別(即選擇 $ \arg\max_k \gamma(z_{ik})$。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from matplotlib.patches import Ellipse
# 高斯分布概率密度函數
def gaussian_pdf(X, mu, Sigma):
return multivariate_normal.pdf(X, mean=mu, cov=Sigma)
# E-step
def e_step(X, alpha, mu, Sigma):
N, K = X.shape[0], len(alpha)
gamma = np.zeros((N, K))
for i in range(N):
total = 0.0
for k in range(K):
gamma[i, k] = alpha[k] * gaussian_pdf(X[i], mu[k], Sigma[k])
total += gamma[i, k]
gamma[i, :] /= total
return gamma
# M-step
def m_step(X, gamma):
N, D = X.shape
K = gamma.shape[1]
alpha_new = np.zeros(K)
mu_new = np.zeros((K, D))
Sigma_new = [np.zeros((D, D)) for _ in range(K)]
for k in range(K):
Nk = gamma[:, k].sum()
alpha_new[k] = Nk / N
mu_new[k] = np.dot(gamma[:, k], X) / Nk
diff = X - mu_new[k]
Sigma_new[k] = np.dot(gamma[:, k] * diff.T, diff) / Nk
return alpha_new, mu_new, Sigma_new
# 繪圖函數
def plot_gmm(X, mu, Sigma, labels=None, title=""):
plt.figure(figsize=(8, 6))
if labels is not None:
colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
for k in range(len(mu)):
idx = (labels == k)
plt.scatter(X[idx, 0], X[idx, 1], c=colors[k], label=f"Cluster {k+1}", s=50)
else:
plt.scatter(X[:, 0], X[:, 1], c='gray', s=30)
for k in range(len(mu)):
plot_cov_ellipse(Sigma[k], mu[k], nstd=2, alpha=0.3)
plt.title(title)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.grid(True)
plt.pause(0.5)
def plot_cov_ellipse(cov, pos, nstd=2, color='blue', **kwargs):
eigvals, eigvecs = np.linalg.eigh(cov)
order = eigvals.argsort()[::-1]
eigvals, eigvecs = eigvals[order], eigvecs[:, order]
theta = np.degrees(np.arctan2(*eigvecs[:, 0][::-1]))
width, height = 2 * nstd * np.sqrt(eigvals)
ellip = Ellipse(xy=pos, width=width, height=height, angle=theta, color=color, **kwargs)
plt.gca().add_patch(ellip)
def main():
# 數據集
X = np.array([
[1.0, 2.0],
[1.5, 1.8],
[5.0, 8.0],
[8.0, 8.0],
[1.0, 0.6],
[9.0, 11.0]
])
N, D = X.shape
K = 2
# 初始化參數
np.random.seed(42)
alpha = np.random.rand(K)
alpha /= alpha.sum()
indices = np.random.choice(N, K, replace=False)
mu = X[indices]
Sigma = [np.eye(D) for _ in range(K)]
# 主訓練循環 + 可視化
max_iter = 20
tolerance = 1e-4
log_likelihood_prev = -np.inf
for it in range(max_iter):
gamma = e_step(X, alpha, mu, Sigma)
alpha, mu, Sigma = m_step(X, gamma)
log_likelihood = 0.0
for i in range(N):
ll = 0.0
for k in range(K):
ll += alpha[k] * gaussian_pdf(X[i], mu[k], Sigma[k])
log_likelihood += np.log(ll)
print(f"Iteration {it+1}, Log Likelihood: {log_likelihood:.4f}")
labels = np.argmax(gamma, axis=1)
plot_gmm(X, mu, Sigma, labels=labels, title=f"Iteration {it+1}")
if abs(log_likelihood - log_likelihood_prev) < tolerance:
print("Converged.")
break
log_likelihood_prev = log_likelihood
plt.show()
if __name__ == '__main__':
main()

浙公網安備 33010602011771號