机器学习基础算法29-EM实践
文章目录:1. EM算法的实现;2. EM算法估算GMM的参数;3. GMM调参:covariance_type;4. EM算法无监督分类鸢尾花数据;5. GMM/DPGMM(贝叶斯高斯分布)比较
·
1.EM算法的实现
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin
# Select the SimHei font for sans-serif text; the plot titles below are Chinese,
# so presumably this is what lets them render instead of showing empty boxes.
mpl.rcParams['font.sans-serif'] = [u'SimHei']
# Draw minus signs as the ASCII hyphen instead of the Unicode minus (U+2212).
mpl.rcParams['axes.unicode_minus'] = False
def em_two_gaussians(data, num_iter=100):
    """Fit a two-component Gaussian mixture to ``data`` with a plain EM loop.

    Parameters
    ----------
    data : ndarray of shape (n, d)
        Samples, one per row.
    num_iter : int
        Number of EM iterations; the loop always runs this many times
        (there is no convergence test).

    Returns
    -------
    tuple
        ``(pi, mu1, mu2, sigma1, sigma2)`` where ``pi`` is the mixing weight
        of the first component (the second has weight ``1 - pi``), ``mu*`` are
        the component means and ``sigma*`` the component covariance matrices.
    """
    n, d = data.shape
    # Deterministic initialisation: the two means start at the per-feature
    # minimum and maximum of the data, covariances at identity, equal weights.
    mu1 = data.min(axis=0)
    mu2 = data.max(axis=0)
    sigma1 = np.identity(d)
    sigma2 = np.identity(d)
    pi = 0.5
    for _ in range(num_iter):
        # E step: prior * density for each component, then the responsibility
        # of the first component for every sample.
        norm1 = multivariate_normal(mu1, sigma1)
        norm2 = multivariate_normal(mu2, sigma2)
        tau1 = pi * norm1.pdf(data)
        tau2 = (1 - pi) * norm2.pdf(data)
        gamma = tau1 / (tau1 + tau2)
        # M step: responsibility-weighted means, covariances and weight.
        # The covariances use the freshly updated means.
        mu1 = np.dot(gamma, data) / np.sum(gamma)
        mu2 = np.dot((1 - gamma), data) / np.sum((1 - gamma))
        sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma)
        sigma2 = np.dot((1 - gamma) * (data - mu2).T, data - mu2) / np.sum(1 - gamma)
        pi = np.sum(gamma) / n
    return pi, mu1, mu2, sigma1, sigma2


if __name__ == '__main__':
    # 'sklearn' fits with sklearn's GaussianMixture; any other value uses the
    # hand-written EM implementation above.
    style = 'myself'

    np.random.seed(0)
    # Ground-truth parameters of the first Gaussian (diagonal covariance).
    mu1_fact = (0, 0, 0)
    cov1_fact = np.diag((1, 2, 3))
    data1 = np.random.multivariate_normal(mu1_fact, cov1_fact, 400)
    # Ground-truth parameters of the second Gaussian.
    # NOTE(review): this matrix is not symmetric, so it is not a valid
    # covariance matrix; kept as-is because changing it changes the demo data.
    mu2_fact = (2, 2, 1)
    cov2_fact = np.array(((1, 1, 3), (1, 2, 1), (0, 0, 1)))
    data2 = np.random.multivariate_normal(mu2_fact, cov2_fact, 100)
    # Stack both samples; y marks which rows came from the first Gaussian.
    data = np.vstack((data1, data2))
    y = np.array([True] * 400 + [False] * 100)

    if style == 'sklearn':
        # covariance_type='full': every component has its own unconstrained
        # covariance matrix. tol is the convergence threshold and max_iter
        # the iteration cap.
        g = GaussianMixture(n_components=2, covariance_type='full', tol=1e-6, max_iter=1000)
        g.fit(data)
        print('类别概率:\t', g.weights_[1])
        print('均值:\n', g.means_, '\n')
        print('方差:\n', g.covariances_, '\n')
        mu1, mu2 = g.means_
        sigma1, sigma2 = g.covariances_
    else:
        pi, mu1, mu2, sigma1, sigma2 = em_two_gaussians(data, num_iter=100)
        print('类别概率:\t', pi)
        print('均值:\t', mu1, mu2)
        print('方差:\n', sigma1, '\n\n', sigma2, '\n')

    # Classify every sample by comparing the two fitted densities.
    # Note that the mixing weights are not applied here.
    norm1 = multivariate_normal(mu1, sigma1)
    norm2 = multivariate_normal(mu2, sigma2)
    tau1 = norm1.pdf(data)
    tau2 = norm2.pdf(data)

    fig = plt.figure(figsize=(13, 7), facecolor='w')
    # Left panel: the raw data.
    ax = fig.add_subplot(121, projection='3d')
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='b', s=30, marker='o', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(u'原始数据', fontsize=18)

    # Right panel: the EM classification.
    ax = fig.add_subplot(122, projection='3d')
    # For each true mean, the index of the nearest estimated mean. [0, 1]
    # means the estimated components are in the same order as the true ones;
    # [1, 0] means they are swapped.
    order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='euclidean')
    print(order)
    if order[0] == 0:
        c1 = tau1 > tau2
    else:
        c1 = tau1 < tau2
    c2 = ~c1
    acc = np.mean(y == c1)
    print(u'准确率:%.2f%%' % (100*acc))
    ax.scatter(data[c1, 0], data[c1, 1], data[c1, 2], c='r', s=30, marker='o', depthshade=True)
    ax.scatter(data[c2, 0], data[c2, 1], data[c2, 2], c='g', s=30, marker='^', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(u'EM算法分类', fontsize=18)
    plt.suptitle(u'EM算法的实现', fontsize=21)
    # tight_layout recomputes the subplot geometry, so the space reserved for
    # the suptitle must be set afterwards or it is discarded.
    plt.tight_layout()
    plt.subplots_adjust(top=0.90)
    plt.show()
类别概率: 0.7650337783291882
均值: [-0.123994 -0.02138048 -0.06003756] [1.9076683 1.79622192 1.11752474]
方差:
[[ 0.82563399 -0.10180706 -0.0414597 ]
[-0.10180706 2.15816316 -0.16360603]
[-0.0414597 -0.16360603 2.79283956]]
[[0.69690051 0.90370392 0.73552321]
[0.90370392 1.8856117 0.76747618]
[0.73552321 0.76747618 2.94819132]]
[0 1]
准确率:89.80%
2.EM算法估算GMM的参数
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
# Font configuration: SimHei for sans-serif text (the plot labels are Chinese),
# and ASCII hyphens for minus signs instead of the Unicode minus (U+2212).
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
# Disabled alternative kept from the original: load a font file explicitly and
# pass it to individual text calls through the fontproperties argument.
# from matplotlib.font_manager import FontProperties
# font_set = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=15)
# fontproperties=font_set
def expand(a, b, rate=0.05):
    """Widen the interval ``[a, b]`` by a fraction of its length on each side.

    Parameters
    ----------
    a, b : number
        Lower and upper bound of the interval.
    rate : float
        Fraction of ``b - a`` added on each side. Defaults to 0.05, the value
        that was previously hard-coded, so existing two-argument calls behave
        the same.

    Returns
    -------
    tuple
        ``(a - d, b + d)`` with ``d = (b - a) * rate``.
    """
    d = (b - a) * rate
    return a - d, b + d
if __name__ == '__main__':
    # Load the numeric table, skipping the header row. The builtin float is
    # used because the np.float alias was removed in NumPy 1.24.
    data = np.loadtxt('HeightWeight.csv', dtype=float, delimiter=',', skiprows=1)
    print(data.shape)
    # Column 0 is the label, the remaining columns are the features.
    y, x = np.split(data, [1, ], axis=1)
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.6, random_state=0)

    # Two-component Gaussian mixture fitted on the training features only.
    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)
    y_hat = gmm.predict(x)
    y_test_hat = gmm.predict(x_test)

    # Component ids are arbitrary. If component 0 has the larger mean in the
    # first feature, flip the 0/1 labels so they can be compared against y
    # (presumably label 0 is the class with the smaller first feature --
    # confirm against the data file).
    change = (gmm.means_[0][0] > gmm.means_[1][0])
    if change:
        y_hat = 1 - y_hat
        y_test_hat = 1 - y_test_hat
    acc = np.mean(y_hat.ravel() == y.ravel())
    acc_test = np.mean(y_test_hat.ravel() == y_test.ravel())
    acc_str = u'训练集准确率:%.2f%%' % (acc * 100)
    acc_test_str = u'测试集准确率:%.2f%%' % (acc_test * 100)
    print(acc_str)
    print(acc_test_str)

    # Light colours for the decision-region background, dark for the samples.
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])

    # Plot range: the training-data range widened a little on each side.
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    # 500 x 500 grid of points covering the plot range.
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    if change:
        grid_hat = 1 - grid_hat

    plt.figure(figsize=(9, 7), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    # np.squeeze turns the (n, 1) label column into a 1-D array; passing the
    # 2-D column as c= triggers an "RGBA sequence" error.
    plt.scatter(x[:, 0], x[:, 1], s=50, c=np.squeeze(y), marker='o', cmap=cm_dark, edgecolors='k')
    plt.scatter(x_test[:, 0], x_test[:, 1], s=60, c=np.squeeze(y_test), marker='^', cmap=cm_dark, edgecolors='k')

    # Posterior probability of each component at every grid point.
    p = gmm.predict_proba(grid_test)
    # Print floats in fixed-point rather than scientific notation.
    np.set_printoptions(suppress=True)
    print(p)
    # Contours of the posterior of component 0 (the model's own numbering,
    # not affected by the label flip above).
    p = p[:, 0].reshape(x1.shape)
    CS = plt.contour(x1, x2, p, levels=(0.1, 0.5, 0.8), colors=list('rgb'), linewidths=2)
    plt.clabel(CS, fontsize=15, fmt='%.1f', inline=True)

    # Place the two accuracy strings near the top-left corner of the axes.
    ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
    xx = 0.9*ax1_min + 0.1*ax1_max
    yy = 0.1*ax2_min + 0.9*ax2_max
    plt.text(xx, yy, acc_str, fontsize=18)
    yy = 0.15*ax2_min + 0.85*ax2_max
    plt.text(xx, yy, acc_test_str, fontsize=18)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.xlabel(u'身高(cm)', fontsize='large')
    plt.ylabel(u'体重(kg)', fontsize='large')
    plt.title(u'EM算法估算GMM的参数', fontsize=20)
    plt.grid()
    plt.show()
(114, 3)
均值 =
[[160.13983374 55.93370575]
[173.50243688 65.03359308]]
方差 =
[[[ 18.82128194 12.30370549]
[ 12.30370549 31.23596113]]
[[ 23.22794989 28.48688647]
[ 28.48688647 105.81824734]]]
训练集准确率:77.94%
测试集准确率:82.61%
[[0.99999775 0.00000225]
[0.99999784 0.00000216]
[0.99999792 0.00000208]
...
[0. 1. ]
[0. 1. ]
[0. 1. ]]
3.GMM调参:covariance_type
import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
# Select the SimHei font for sans-serif text (the plot titles are Chinese).
mpl.rcParams['font.sans-serif'] = [u'SimHei']
# Draw minus signs as the ASCII hyphen instead of the Unicode minus (U+2212).
mpl.rcParams['axes.unicode_minus'] = False
def expand(a, b, rate=0.05):
    """Return ``[a, b]`` widened on both sides by ``rate`` times its length.

    The result is the pair ``(a - margin, b + margin)`` where
    ``margin = (b - a) * rate``.
    """
    margin = (b - a) * rate
    lower = a - margin
    upper = b + margin
    return lower, upper
def accuracy_rate(y1, y2):
    """Return the agreement rate between two label arrays, ignoring label order.

    The fraction of positions where ``y1 == y2`` is computed; if it is not
    above 0.5 its complement is returned instead, i.e. the labels are treated
    as swapped.
    """
    match_ratio = np.mean(y1 == y2)
    if match_ratio > 0.5:
        return match_ratio
    return 1 - match_ratio
if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    print(cov1)
    N1 = 500
    N2 = 300
    # First cluster: Gaussian samples pushed through the linear map m, which
    # gives them a correlated (non axis-aligned) covariance.
    x1 = np.random.multivariate_normal(mean=(1, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    # Second cluster: plain axis-aligned Gaussian.
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)

    # Fit one model per covariance type and record its error rate and BIC.
    types = ('spherical', 'diag', 'tied', 'full')
    err = np.empty(len(types))
    bic = np.empty(len(types))
    for i, cov_type in enumerate(types):
        gmm = GaussianMixture(n_components=2, covariance_type=cov_type, random_state=0)
        gmm.fit(x)
        err[i] = 1 - accuracy_rate(gmm.predict(x), y)
        bic[i] = gmm.bic(x)
    print('错误率:', err.ravel())
    print('BIC:', bic.ravel())

    # Grouped bar chart: error rate on the left axis, BIC on a twin right axis.
    xpos = np.arange(len(types))
    plt.figure(facecolor='w')
    ax = plt.axes()
    b1 = ax.bar(xpos-0.3, err, width=0.3, color='#77E0A0')
    b2 = ax.twinx().bar(xpos, bic, width=0.3, color='#FF8080')
    plt.grid(True)
    bic_min, bic_max = expand(bic.min(), bic.max())
    plt.ylim((bic_min, bic_max))
    plt.xticks(xpos, types)
    plt.legend([b1[0], b2[0]], (u'错误率', u'BIC'))
    plt.title(u'不同方差类型的误差率和BIC', fontsize=18)
    plt.show()

    # Refit with the covariance type that achieved the lowest BIC.
    optimal = bic.argmin()
    gmm = GaussianMixture(n_components=2, covariance_type=types[optimal], random_state=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)

    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])
    # Plot range: the data range widened a little, sampled on a 500 x 500 grid.
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    # Component ids are arbitrary: flip the 0/1 region labels when component 0
    # has the larger mean in the first feature, so the background colours line
    # up with the true-label colours of the scatter points.
    if gmm.means_[0][0] > gmm.means_[1][0]:
        grid_hat = 1 - grid_hat

    plt.figure(figsize=(9, 7), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title(u'GMM调参:covariance_type=%s' % types[optimal], fontsize=20)
    plt.grid()
    plt.show()
[[1 0]
[0 2]]
错误率: [0.385 0.315 0.3 0.00125]
BIC: [7990.71460065 7855.56050855 8006.49834359 6845.79374805]
均值 =
[[ 2.88444448 6.69484552]
[-0.97642254 10.06927801]]
方差 =
[[[ 2.87015473 6.64421303]
[ 6.64421303 18.00318872]]
[[ 0.91302546 -0.04298504]
[-0.04298504 1.9603531 ]]]
4.EM算法无监督分类鸢尾花数据
import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin
# Select the SimHei font for sans-serif text (the plot labels are Chinese).
mpl.rcParams['font.sans-serif'] = [u'SimHei']
# Draw minus signs as the ASCII hyphen instead of the Unicode minus (U+2212).
mpl.rcParams['axes.unicode_minus'] = False
# Axis labels for the four feature columns, indexed by column number:
# sepal length, sepal width, petal length, petal width.
iris_feature = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'
def expand(a, b, rate=0.05):
    """Pad the interval ``[a, b]`` on each side by ``rate`` of its length.

    Returns the tuple ``(a - pad, b + pad)`` with ``pad = (b - a) * rate``.
    """
    pad = (b - a) * rate
    return (a - pad, b + pad)
def remap_labels(labels, order):
    """Return a copy of ``labels`` in which cluster id ``order[i]`` becomes ``i``.

    All masks are computed from the original labels before anything is
    written, so an already remapped entry is never remapped a second time.
    Labels that do not occur in ``order`` are left unchanged.
    """
    masks = [labels == cluster for cluster in order]
    remapped = labels.copy()
    for target, mask in enumerate(masks):
        remapped[mask] = target
    return remapped


if __name__ == '__main__':
    path = 'iris.data'
    data = pd.read_csv(path, header=None)
    # Columns 0-3 are the features, column 4 the class name.
    x_prime, y = data[np.arange(4)], data[4]
    # Encode the class names as integer codes.
    y = pd.Categorical(y).codes
    n_components = 3
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    # Light colours for the decision regions, dark for the samples.
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])
    plt.figure(figsize=(10, 9), facecolor='#FFFFFF')
    for k, pair in enumerate(feature_pairs):
        x = x_prime[pair]
        # Per-class feature means computed from the true labels.
        m = np.array([np.mean(x[y == i], axis=0) for i in range(n_components)])
        print('实际均值 = \n', m)

        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
        gmm.fit(x)
        print('预测均值 = \n', gmm.means_)
        print('预测方差 = \n', gmm.covariances_)
        y_hat = gmm.predict(x)
        # order[i] is the index of the fitted component whose mean lies
        # closest (Euclidean distance) to the true mean of class i.
        order = pairwise_distances_argmin(m, gmm.means_, axis=1, metric='euclidean')
        print('顺序:\t', order)

        # Rename the predicted cluster ids to the matching class ids.
        y_hat = remap_labels(y_hat, order)
        acc = u'准确率:%.2f%%' % (100*np.mean(y_hat == y))
        print(acc)

        # Plot range: the data range widened a little, sampled on a 500 x 500 grid.
        x1_min, x2_min = x.min()
        x1_max, x2_max = x.max()
        x1_min, x1_max = expand(x1_min, x1_max)
        x2_min, x2_max = expand(x2_min, x2_max)
        x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
        grid_test = np.stack((x1.flat, x2.flat), axis=1)
        grid_hat = remap_labels(gmm.predict(grid_test), order)
        grid_hat = grid_hat.reshape(x1.shape)

        plt.subplot(3, 2, k+1)
        plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
        plt.scatter(x[pair[0]], x[pair[1]], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')
        # Accuracy text near the top-left corner of the panel.
        xx = 0.95 * x1_min + 0.05 * x1_max
        yy = 0.1 * x2_min + 0.9 * x2_max
        plt.text(xx, yy, acc, fontsize=14)
        plt.xlim((x1_min, x1_max))
        plt.ylim((x2_min, x2_max))
        plt.xlabel(iris_feature[pair[0]], fontsize=14)
        plt.ylabel(iris_feature[pair[1]], fontsize=14)
        plt.grid()
    # pad is passed by keyword because newer Matplotlib no longer accepts it
    # positionally.
    plt.tight_layout(pad=2)
    plt.suptitle(u'EM算法无监督分类鸢尾花数据', fontsize=20)
    plt.subplots_adjust(top=0.92)
    plt.show()
实际均值 =
[[5.006 3.418]
[5.936 2.77 ]
[6.588 2.974]]
预测均值 =
[[5.01493896 3.4404862 ]
[6.6814044 3.0285628 ]
[5.90114537 2.74385294]]
预测方差 =
[[[0.1194876 0.08969867]
[0.08969867 0.12147459]]
[[0.36087007 0.05158991]
[0.05158991 0.08923683]]
[[0.27544608 0.08866062]
[0.08866062 0.09382524]]]
顺序: [0 2 1]
准确率:79.33%
实际均值 =
[[5.006 1.464]
[5.936 4.26 ]
[6.588 5.552]]
预测均值 =
[[5.0060006 1.46399865]
[6.58888904 5.63329718]
[6.04240777 4.41742864]]
预测方差 =
[[[0.12176525 0.01581631]
[0.01581631 0.0295045 ]]
[[0.48521779 0.36602418]
[0.36602418 0.32601109]]
[[0.28119672 0.23746926]
[0.23746926 0.31503012]]]
顺序: [0 2 1]
准确率:91.33%
实际均值 =
[[5.006 0.244]
[5.936 1.326]
[6.588 2.026]]
预测均值 =
[[5.00607264 0.23754806]
[6.56291563 2.02408174]
[5.94928821 1.32089151]]
预测方差 =
[[[0.1239802 0.01055412]
[0.01055412 0.00914172]]
[[0.41146765 0.0558402 ]
[0.0558402 0.07689828]]
[[0.29249903 0.07999787]
[0.07999787 0.0509237 ]]]
顺序: [0 2 1]
准确率:96.00%
实际均值 =
[[3.418 1.464]
[2.77 4.26 ]
[2.974 5.552]]
预测均值 =
[[3.41800009 1.46400001]
[2.80062882 4.43004172]
[2.97017899 5.56073357]]
预测方差 =
[[[0.14227691 0.01144799]
[0.01144799 0.029505 ]]
[[0.09376548 0.10702236]
[0.10702236 0.34454954]]
[[0.11477629 0.07760424]
[0.07760424 0.38871245]]]
顺序: [0 1 2]
准确率:92.67%
实际均值 =
[[3.418 0.244]
[2.77 1.326]
[2.974 2.026]]
预测均值 =
[[3.41800003 0.244 ]
[2.93629236 1.98607968]
[2.79657869 1.31224583]]
预测方差 =
[[[0.14227697 0.011208 ]
[0.011208 0.011265 ]]
[[0.11263095 0.06192916]
[0.06192916 0.08966439]]
[[0.09554395 0.04869984]
[0.04869984 0.03787478]]]
顺序: [0 2 1]
准确率:93.33%
实际均值 =
[[1.464 0.244]
[4.26 1.326]
[5.552 2.026]]
预测均值 =
[[1.46399926 0.24399973]
[4.32760641 1.36230588]
[5.60395923 2.0545418 ]]
预测方差 =
[[[0.02950475 0.00558391]
[0.00558391 0.01126496]]
[[0.25410036 0.09152045]
[0.09152045 0.05088321]]
[[0.29156635 0.03719346]
[0.03719346 0.07073705]]]
顺序: [0 1 2]
准确率:97.33%
5.GMM/DPGMM(贝叶斯高斯分布)比较
import numpy as np
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
def expand(a, b, rate=0.05):
    """Stretch the interval ``[a, b]`` outward by ``rate`` of its span per side.

    Returns ``(a - delta, b + delta)`` where ``delta = (b - a) * rate``.
    """
    span = b - a
    delta = span * rate
    return a - delta, b + delta
# Select the SimHei font for sans-serif text (the plot titles are Chinese).
matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
# Draw minus signs as the ASCII hyphen instead of the Unicode minus (U+2212).
matplotlib.rcParams['axes.unicode_minus'] = False
def cov_ellipse_params(cov):
    """Return ``(width, height, angle)`` of an ellipse describing a 2x2 covariance.

    ``width`` and ``height`` are the two eigenvalues of ``cov`` in ascending
    order, and ``angle`` is the direction, in degrees within [0, 180), of the
    eigenvector belonging to ``width``. ``eigh`` returns eigenvectors as the
    COLUMNS of its second result, so column 0 is used; taking row 0 instead
    can mirror the ellipse. ``arctan2`` also copes with axis-aligned
    covariances, where a plain ``v[1] / v[0]`` would divide by zero.

    NOTE: the axis lengths are the raw eigenvalues (variances), matching the
    scale the demo has always drawn; for a k-sigma contour they would be
    ``2 * k * sqrt(eigenvalue)``.
    """
    values, vectors = np.linalg.eigh(cov)
    width, height = values[0], values[1]
    direction = vectors[:, 0]
    angle = np.degrees(np.arctan2(direction[1], direction[0])) % 180.0
    return width, height, angle


if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    N1 = 500
    N2 = 300
    # First cluster: Gaussian samples pushed through the linear map m, which
    # gives them a correlated covariance. Second cluster: plain Gaussian.
    x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)
    # Deliberately one more component than the two true clusters.
    n_components = 3

    # Colours and the prediction grid shared by both panels.
    colors = '#A0FFA0', '#2090E0', '#FF8080'
    cm = mpl.colors.ListedColormap(colors)
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()
    x1_min, x1_max = expand(x1_min, x1_max)
    x2_min, x2_max = expand(x2_min, x2_max)
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)

    plt.figure(figsize=(9, 9), facecolor='w')
    plt.suptitle(u'GMM/DPGMM比较', fontsize=23)

    # Top panel: ordinary GMM.
    ax = plt.subplot(211)
    gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
    gmm.fit(x)
    centers = gmm.means_
    covs = gmm.covariances_
    print('GMM均值 = \n', centers)
    print('GMM方差 = \n', covs)
    grid_hat = gmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o')

    # One ellipse per fitted component.
    clrs = list('rgbmy')
    for i, (center, cov) in enumerate(zip(centers, covs)):
        width, height, angle = cov_ellipse_params(cov)
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color=clrs[i], alpha=0.5, clip_box=ax.bbox)
        ax.add_artist(e)

    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title(u'GMM', fontsize=20)
    plt.grid(True)

    # Bottom panel: Bayesian GMM with a Dirichlet-process weight prior.
    dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full', max_iter=1000, n_init=5,
                                    weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=0.1)
    dpgmm.fit(x)
    centers = dpgmm.means_
    covs = dpgmm.covariances_
    print('DPGMM均值 = \n', centers)
    print('DPGMM方差 = \n', covs)
    y_hat = dpgmm.predict(x)
    print(y_hat)

    ax = plt.subplot(212)
    grid_hat = dpgmm.predict(grid_test)
    grid_hat = grid_hat.reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o')

    for i, (center, cov) in enumerate(zip(centers, covs)):
        # Skip components to which no training sample was assigned.
        if i not in y_hat:
            continue
        width, height, angle = cov_ellipse_params(cov)
        e = Ellipse(xy=center, width=width, height=height,
                    angle=angle, color='m', alpha=0.5, clip_box=ax.bbox)
        ax.add_artist(e)

    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title('DPGMM', fontsize=20)
    plt.grid(True)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()
GMM均值 =
[[ 3.77430768 5.86579463]
[ 6.0239399 11.61448122]
[-0.98543679 10.0756839 ]]
GMM方差 =
[[[ 1.5383593 3.21210121]
[ 3.21210121 9.04107582]]
[[ 1.6667472 3.58655076]
[ 3.58655076 10.40673433]]
[[ 0.89079177 -0.02572518]
[-0.02572518 1.95106592]]]
DPGMM均值 =
[[ 4.87807808 8.69858646]
[-0.97320511 10.07279749]
[ 2.68191465 9.21436833]]
DPGMM方差 =
[[[ 2.88507577 6.60477348]
[ 6.60477348 17.92760296]]
[[ 0.9632641 -0.02865042]
[-0.02865042 1.98179578]]
[[ 5.104264 1.12039777]
[ 1.12039777 6.21286898]]]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
更多推荐
所有评论(0)