1.EM算法的实现

import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin


# Font setup: SimHei so the Chinese titles/labels render, and an ASCII hyphen
# for minus signs (presumably because SimHei lacks the U+2212 glyph).
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})


if __name__ == '__main__':
    # 'sklearn' fits the mixture with GaussianMixture; any other value runs the
    # hand-written EM loop below.
    style = 'myself'

    np.random.seed(0)
    # True means and covariances of the two components
    mu1_fact = (0, 0, 0)
    # np.diag() builds a diagonal matrix from the given diagonal entries
    cov1_fact = np.diag((1, 2, 3))
    # Draw 400 samples from N(mu1_fact, cov1_fact)
    data1 = np.random.multivariate_normal(mu1_fact, cov1_fact, 400)

    mu2_fact = (2, 2, 1)
    # A covariance matrix must be symmetric positive semi-definite. The former
    # value ((1, 1, 3), (1, 2, 1), (0, 0, 1)) was neither, so numpy warned and
    # the samples were undefined. This keeps its symmetric leading 2x2 block
    # and its last diagonal entry.
    cov2_fact = np.array(((1, 1, 0), (1, 2, 0), (0, 0, 1)))
    data2 = np.random.multivariate_normal(mu2_fact, cov2_fact, 100)
    # Stack both samples row-wise into one data set
    data = np.vstack((data1, data2))
    # Ground truth: True for rows drawn from the first component
    y = np.array([True] * 400 + [False] * 100)

    if style == 'sklearn':
        # n_components: number of mixture components
        # covariance_type='full': every component has its own unconstrained covariance
        # tol: convergence threshold; iteration stops when the gain falls below it
        # max_iter: maximum number of EM iterations
        g = GaussianMixture(n_components=2, covariance_type='full', tol=1e-6, max_iter=1000)
        g.fit(data)
        # weights_: (n_components,) mixing weights
        # means_: (n_components, n_features) component means
        # covariances_: per-component covariances (shape depends on covariance_type)
        # Print the weight of the FIRST component, matching what the custom
        # EM branch prints (pi is the weight of component 1 there).
        print('类别概率:\t', g.weights_[0])
        print('均值:\n', g.means_, '\n')
        print('方差:\n', g.covariances_, '\n')
        mu1, mu2 = g.means_
        sigma1, sigma2 = g.covariances_
        w1, w2 = g.weights_

    # Hand-written EM for a two-component Gaussian mixture
    else:
        num_iter = 100
        n, d = data.shape
        # Initialisation: means at opposite corners of the data's bounding box,
        # identity covariances, equal mixing weights
        mu1 = data.min(axis=0)
        mu2 = data.max(axis=0)
        sigma1 = np.identity(d)
        sigma2 = np.identity(d)
        pi = 0.5
        for _ in range(num_iter):
            # E step: gamma[i] is the probability that sample i was generated
            # by component 1 (prior * density, normalised over both components)
            norm1 = multivariate_normal(mu1, sigma1)
            norm2 = multivariate_normal(mu2, sigma2)
            tau1 = pi * norm1.pdf(data)
            tau2 = (1 - pi) * norm2.pdf(data)
            gamma = tau1 / (tau1 + tau2)

            # M step: responsibility-weighted means, covariances and mixing weight
            mu1 = np.dot(gamma, data) / np.sum(gamma)
            mu2 = np.dot((1 - gamma), data) / np.sum((1 - gamma))
            sigma1 = np.dot(gamma * (data - mu1).T, data - mu1) / np.sum(gamma)
            sigma2 = np.dot((1 - gamma) * (data - mu2).T, data - mu2) / np.sum(1 - gamma)
            pi = np.sum(gamma) / n
        print('类别概率:\t', pi)
        print('均值:\t', mu1, mu2)
        print('方差:\n', sigma1, '\n\n', sigma2, '\n')
        w1, w2 = pi, 1 - pi

    # Prediction: weight each component density by its mixing weight so that
    # comparing tau1 with tau2 compares posterior probabilities (the MAP rule
    # that GaussianMixture.predict applies). Raw densities would ignore the
    # 400:100 imbalance between the two components.
    norm1 = multivariate_normal(mu1, sigma1)
    norm2 = multivariate_normal(mu2, sigma2)
    tau1 = w1 * norm1.pdf(data)
    tau2 = w2 * norm2.pdf(data)

    fig = plt.figure(figsize=(13, 7), facecolor='w')
    # Left panel: the raw data in 3-D
    ax = fig.add_subplot(121, projection='3d')
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c='b', s=30, marker='o', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(u'原始数据', fontsize=18)
    # Right panel: the EM assignment
    ax = fig.add_subplot(122, projection='3d')
    # Component indices are arbitrary. order[j] is the index of the estimated
    # mean closest (Euclidean) to true mean j: [0 1] means the estimated
    # components are already in the true order, [1 0] means they are swapped.
    order = pairwise_distances_argmin([mu1_fact, mu2_fact], [mu1, mu2], metric='euclidean')
    print(order)

    # c1 marks the points assigned to the component that matches mu1_fact
    if order[0] == 0:
        c1 = tau1 > tau2
    else:
        c1 = tau1 < tau2
    c2 = ~c1
    acc = np.mean(y == c1)
    print(u'准确率:%.2f%%' % (100*acc))

    ax.scatter(data[c1, 0], data[c1, 1], data[c1, 2], c='r', s=30, marker='o', depthshade=True)
    ax.scatter(data[c2, 0], data[c2, 1], data[c2, 2], c='g', s=30, marker='^', depthshade=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(u'EM算法分类', fontsize=18)
    plt.suptitle(u'EM算法的实现', fontsize=21)
    plt.subplots_adjust(top=0.90)
    plt.tight_layout()
    plt.show()

类别概率:	 0.7650337783291882
均值:	 [-0.123994   -0.02138048 -0.06003756] [1.9076683  1.79622192 1.11752474]
方差:
 [[ 0.82563399 -0.10180706 -0.0414597 ]
 [-0.10180706  2.15816316 -0.16360603]
 [-0.0414597  -0.16360603  2.79283956]] 

 [[0.69690051 0.90370392 0.73552321]
 [0.90370392 1.8856117  0.76747618]
 [0.73552321 0.76747618 2.94819132]] 

[0 1]
准确率:89.80%

在这里插入图片描述

2.EM算法估算GMM的参数

import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt

# 指定字体
# Font setup: SimHei for the Chinese labels, ASCII hyphen for minus signs
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})
# Alternative (unused): load a specific font file and pass it per call, e.g.
#   from matplotlib.font_manager import FontProperties
#   font_set = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=15)
#   ... fontproperties=font_set


def expand(a, b):
    """Return the interval (a, b) padded by 5% of its width on both ends.

    Used to leave a small margin around the data range in plots.
    """
    pad = 0.05 * (b - a)
    return a - pad, b + pad


if __name__ == '__main__':

    # np.float was only an alias of the builtin float and was removed in
    # NumPy 1.24, where it raises AttributeError.
    data = np.loadtxt('HeightWeight.csv', dtype=float, delimiter=',', skiprows=1)
    print(data.shape)
    # Column 0 is the label; the remaining columns are the features
    y, x = np.split(data, [1, ], axis=1)
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.6, random_state=0)
    # Two-component Gaussian mixture with full covariance matrices
    gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)
    y_hat = gmm.predict(x)
    y_test_hat = gmm.predict(x_test)
    # GMM component indices are arbitrary. The code assumes label 0 belongs to
    # the component with the smaller mean on the first feature and flips the
    # 0/1 predictions otherwise.
    change = (gmm.means_[0][0] > gmm.means_[1][0])
    if change:
        y_hat = 1 - y_hat
        y_test_hat = 1 - y_test_hat

    acc = np.mean(y_hat.ravel() == y.ravel())
    acc_test = np.mean(y_test_hat.ravel() == y_test.ravel())
    acc_str = u'训练集准确率:%.2f%%' % (acc * 100)
    acc_test_str = u'测试集准确率:%.2f%%' % (acc_test * 100)
    print(acc_str)
    print(acc_test_str)

    # Light colours for the predicted regions (background)
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    # Dark colours for the sample points
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])

    # Plot range: the training data's min/max on each feature, widened by 5%
    x1_min, x1_max = expand(x[:, 0].min(), x[:, 0].max())
    x2_min, x2_max = expand(x[:, 1].min(), x[:, 1].max())

    # np.mgrid builds a 500 x 500 evenly spaced grid over the plot range
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    # One (x1, x2) row per grid point
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test).reshape(x1.shape)
    if change:
        grid_hat = 1 - grid_hat

    plt.figure(figsize=(9, 7), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    # np.squeeze turns the (n, 1) label column into a 1-D array; passing the
    # 2-D column produced "RGBA sequence should have length 3 or 4"
    plt.scatter(x[:, 0], x[:, 1], s=50, c=np.squeeze(y), marker='o', cmap=cm_dark, edgecolors='k')
    plt.scatter(x_test[:, 0], x_test[:, 1], s=60, c=np.squeeze(y_test), marker='^', cmap=cm_dark, edgecolors='k')

    # predict_proba gives each component's posterior probability per grid point
    p = gmm.predict_proba(grid_test)
    # Print without scientific notation
    np.set_printoptions(suppress=True)
    print(p)
    p = p[:, 0].reshape(x1.shape)

    # Contour lines of component 0's posterior probability
    CS = plt.contour(x1, x2, p, levels=(0.1, 0.5, 0.8), colors=list('rgb'), linewidths=2)
    plt.clabel(CS, fontsize=15, fmt='%.1f', inline=True)
    # Place the accuracy texts near the top-left corner of the axes
    ax1_min, ax1_max, ax2_min, ax2_max = plt.axis()
    xx = 0.9*ax1_min + 0.1*ax1_max
    yy = 0.1*ax2_min + 0.9*ax2_max
    plt.text(xx, yy, acc_str, fontsize=18)
    yy = 0.15*ax2_min + 0.85*ax2_max
    plt.text(xx, yy, acc_test_str, fontsize=18)
    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.xlabel(u'身高(cm)', fontsize='large')
    plt.ylabel(u'体重(kg)', fontsize='large')
    plt.title(u'EM算法估算GMM的参数', fontsize=20)
    plt.grid()
    plt.show()
(114, 3)
均值 = 
 [[160.13983374  55.93370575]
 [173.50243688  65.03359308]]
方差 = 
 [[[ 18.82128194  12.30370549]
  [ 12.30370549  31.23596113]]

 [[ 23.22794989  28.48688647]
  [ 28.48688647 105.81824734]]]
训练集准确率:77.94%
测试集准确率:82.61%
[[0.99999775 0.00000225]
 [0.99999784 0.00000216]
 [0.99999792 0.00000208]
 ...
 [0.         1.        ]
 [0.         1.        ]
 [0.         1.        ]]

在这里插入图片描述

3.GMM调参:covariance_type

import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt

# SimHei for the Chinese labels; ASCII hyphen instead of U+2212 for minus signs
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})


def expand(a, b, rate=0.05):
    """Widen the interval [a, b] by ``rate`` times its width on each side.

    Returns the pair (lower, upper); used to pad plot limits.
    """
    margin = (b - a) * rate
    lower = a - margin
    upper = b + margin
    return lower, upper


def accuracy_rate(y1, y2):
    """Fraction of positions where two binary labelings agree, up to a label swap.

    A clustering may name its two clusters in either order, so an agreement
    of ``a`` is equivalent to ``1 - a``; the larger of the two is returned
    (an exact 0.5 is returned unchanged).
    """
    agreement = np.mean(y1 == y2)
    if agreement > 0.5:
        return agreement
    return 1 - agreement


if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    print(cov1)
    N1 = 500
    N2 = 300
    # Two 2-D Gaussian clusters. The first is mapped through m, which moves
    # its mean from (1, 2) to (3, 7) and correlates its two features.
    x1 = np.random.multivariate_normal(mean=(1, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)

    # Fit one GMM per covariance type; record its error rate and BIC.
    # (cov_type instead of `type` so the builtin is not shadowed.)
    types = ('spherical', 'diag', 'tied', 'full')
    err = np.empty(len(types))
    bic = np.empty(len(types))
    for i, cov_type in enumerate(types):
        gmm = GaussianMixture(n_components=2, covariance_type=cov_type, random_state=0)
        gmm.fit(x)
        err[i] = 1 - accuracy_rate(gmm.predict(x), y)
        bic[i] = gmm.bic(x)
    print('错误率:', err.ravel())
    print('BIC:', bic.ravel())
    xpos = np.arange(len(types))
    plt.figure(facecolor='w')
    ax = plt.axes()
    # Error rate on the left axis, BIC on a twin right axis
    b1 = ax.bar(xpos-0.3, err, width=0.3, color='#77E0A0')
    b2 = ax.twinx().bar(xpos, bic, width=0.3, color='#FF8080')
    plt.grid(True)
    # Zoom the BIC axis onto the observed BIC range
    bic_min, bic_max = expand(bic.min(), bic.max())
    plt.ylim((bic_min, bic_max))
    plt.xticks(xpos, types)
    plt.legend([b1[0], b2[0]], (u'错误率', u'BIC'))
    plt.title(u'不同方差类型的误差率和BIC', fontsize=18)
    plt.show()

    # Refit with the covariance type that has the lowest BIC
    optimal = bic.argmin()
    gmm = GaussianMixture(n_components=2, covariance_type=types[optimal], random_state=0)
    gmm.fit(x)
    print('均值 = \n', gmm.means_)
    print('方差 = \n', gmm.covariances_)
    y_hat = gmm.predict(x)

    # Light colours for the predicted regions, dark colours for the points
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g'])
    x1_min, x1_max = expand(x[:, 0].min(), x[:, 0].max())
    x2_min, x2_max = expand(x[:, 1].min(), x[:, 1].max())
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)
    grid_hat = gmm.predict(grid_test).reshape(x1.shape)
    # GMM component indices are arbitrary. Flip the region labels when the
    # predictions disagree with the true labels on most points (the same rule
    # accuracy_rate uses), so each region's colour matches its points' colour.
    # The former test means_[0][0] > means_[1][0] assumed class 0 had the
    # smaller x mean, but here class 0 is the cluster around x = 3, so it
    # inverted the colours.
    if np.mean(y_hat == y) < 0.5:
        grid_hat = 1 - grid_hat
    plt.figure(figsize=(9, 7), facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')

    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title(u'GMM调参:covariance_type=%s' % types[optimal], fontsize=20)
    plt.grid()
    plt.show()
[[1 0]
 [0 2]]
错误率: [0.385   0.315   0.3     0.00125]
BIC: [7990.71460065 7855.56050855 8006.49834359 6845.79374805]
均值 = 
 [[ 2.88444448  6.69484552]
 [-0.97642254 10.06927801]]
方差 = 
 [[[ 2.87015473  6.64421303]
  [ 6.64421303 18.00318872]]

 [[ 0.91302546 -0.04298504]
  [-0.04298504  1.9603531 ]]]

在这里插入图片描述
在这里插入图片描述

4.EM算法无监督分类鸢尾花数据

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin

# SimHei for the Chinese labels; ASCII hyphen instead of U+2212 for minus signs
mpl.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})

# Axis labels for the four iris features, indexed by column number
# (sepal length, sepal width, petal length, petal width)
iris_feature = (u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度')


def expand(a, b, rate=0.05):
    """Pad both ends of [a, b] by ``rate`` times the interval's width.

    Returns a (lower, upper) tuple suitable for plot limits.
    """
    delta = rate * (b - a)
    return (a - delta, b + delta)


if __name__ == '__main__':
    path = 'iris.data'
    data = pd.read_csv(path, header=None)
    # Columns 0-3 hold the features, column 4 the class name
    x_prime, y = data[np.arange(4)], data[4]
    # Encode the class names as integer codes
    y = pd.Categorical(y).codes

    n_components = 3
    n_types = 3  # number of true classes
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    # Light colours for the predicted regions, dark colours for the points
    cm_light = mpl.colors.ListedColormap(['#FF8080', '#77E0A0', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['r', 'g', '#6060FF'])
    plt.figure(figsize=(10, 9), facecolor='#FFFFFF')
    for k, pair in enumerate(feature_pairs):
        # Work on a plain ndarray so fit() and the later predict() on the
        # ndarray grid see the same input type (fitting on a DataFrame makes
        # scikit-learn warn about missing feature names at predict time).
        x = x_prime[pair].values
        m = np.array([np.mean(x[y == i], axis=0) for i in range(n_types)])  # true class means
        print('实际均值 = \n', m)

        gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
        gmm.fit(x)
        print('预测均值 = \n', gmm.means_)
        print('预测方差 = \n', gmm.covariances_)
        # order[i] is the index of the fitted component whose mean is nearest
        # (Euclidean) to the true mean of class i
        order = pairwise_distances_argmin(m, gmm.means_, axis=1, metric='euclidean')
        print('顺序:\t', order)

        # Lookup table component index -> class label (component order[i] -> class i).
        # Assumes order is a permutation; any unmatched component keeps its index.
        relabel = np.arange(n_components)
        relabel[order] = np.arange(n_types)
        y_hat = relabel[gmm.predict(x)]
        acc = u'准确率:%.2f%%' % (100*np.mean(y_hat == y))
        print(acc)

        x1_min, x2_min = x.min(axis=0)
        x1_max, x2_max = x.max(axis=0)
        x1_min, x1_max = expand(x1_min, x1_max)
        x2_min, x2_max = expand(x2_min, x2_max)
        x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
        grid_test = np.stack((x1.flat, x2.flat), axis=1)
        grid_hat = relabel[gmm.predict(grid_test)].reshape(x1.shape)

        plt.subplot(3, 2, k+1)
        plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)
        plt.scatter(x[:, 0], x[:, 1], s=30, c=y, marker='o', cmap=cm_dark, edgecolors='k')
        # Put the accuracy text near the top-left corner of the panel
        xx = 0.95 * x1_min + 0.05 * x1_max
        yy = 0.1 * x2_min + 0.9 * x2_max
        plt.text(xx, yy, acc, fontsize=14)
        plt.xlim((x1_min, x1_max))
        plt.ylim((x2_min, x2_max))
        plt.xlabel(iris_feature[pair[0]], fontsize=14)
        plt.ylabel(iris_feature[pair[1]], fontsize=14)
        plt.grid()
    # Passing pad positionally was deprecated in Matplotlib 3.3 and is
    # rejected by later versions; pass it by keyword.
    plt.tight_layout(pad=2)
    plt.suptitle(u'EM算法无监督分类鸢尾花数据', fontsize=20)
    plt.subplots_adjust(top=0.92)
    plt.show()

实际均值 = 
 [[5.006 3.418]
 [5.936 2.77 ]
 [6.588 2.974]]
预测均值 = 
 [[5.01493896 3.4404862 ]
 [6.6814044  3.0285628 ]
 [5.90114537 2.74385294]]
预测方差 = 
 [[[0.1194876  0.08969867]
  [0.08969867 0.12147459]]

 [[0.36087007 0.05158991]
  [0.05158991 0.08923683]]

 [[0.27544608 0.08866062]
  [0.08866062 0.09382524]]]
顺序:	 [0 2 1]
准确率:79.33%
实际均值 = 
 [[5.006 1.464]
 [5.936 4.26 ]
 [6.588 5.552]]
预测均值 = 
 [[5.0060006  1.46399865]
 [6.58888904 5.63329718]
 [6.04240777 4.41742864]]
预测方差 = 
 [[[0.12176525 0.01581631]
  [0.01581631 0.0295045 ]]

 [[0.48521779 0.36602418]
  [0.36602418 0.32601109]]

 [[0.28119672 0.23746926]
  [0.23746926 0.31503012]]]
顺序:	 [0 2 1]
准确率:91.33%
实际均值 = 
 [[5.006 0.244]
 [5.936 1.326]
 [6.588 2.026]]
预测均值 = 
 [[5.00607264 0.23754806]
 [6.56291563 2.02408174]
 [5.94928821 1.32089151]]
预测方差 = 
 [[[0.1239802  0.01055412]
  [0.01055412 0.00914172]]

 [[0.41146765 0.0558402 ]
  [0.0558402  0.07689828]]

 [[0.29249903 0.07999787]
  [0.07999787 0.0509237 ]]]
顺序:	 [0 2 1]
准确率:96.00%
实际均值 = 
 [[3.418 1.464]
 [2.77  4.26 ]
 [2.974 5.552]]
预测均值 = 
 [[3.41800009 1.46400001]
 [2.80062882 4.43004172]
 [2.97017899 5.56073357]]
预测方差 = 
 [[[0.14227691 0.01144799]
  [0.01144799 0.029505  ]]

 [[0.09376548 0.10702236]
  [0.10702236 0.34454954]]

 [[0.11477629 0.07760424]
  [0.07760424 0.38871245]]]
顺序:	 [0 1 2]
准确率:92.67%
实际均值 = 
 [[3.418 0.244]
 [2.77  1.326]
 [2.974 2.026]]
预测均值 = 
 [[3.41800003 0.244     ]
 [2.93629236 1.98607968]
 [2.79657869 1.31224583]]
预测方差 = 
 [[[0.14227697 0.011208  ]
  [0.011208   0.011265  ]]

 [[0.11263095 0.06192916]
  [0.06192916 0.08966439]]

 [[0.09554395 0.04869984]
  [0.04869984 0.03787478]]]
顺序:	 [0 2 1]
准确率:93.33%
实际均值 = 
 [[1.464 0.244]
 [4.26  1.326]
 [5.552 2.026]]
预测均值 = 
 [[1.46399926 0.24399973]
 [4.32760641 1.36230588]
 [5.60395923 2.0545418 ]]
预测方差 = 
 [[[0.02950475 0.00558391]
  [0.00558391 0.01126496]]

 [[0.25410036 0.09152045]
  [0.09152045 0.05088321]]

 [[0.29156635 0.03719346]
  [0.03719346 0.07073705]]]
顺序:	 [0 1 2]
准确率:97.33%

在这里插入图片描述

5.GMM/DPGMM(基于狄利克雷过程的贝叶斯高斯混合模型)比较

import numpy as np
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
import scipy as sp
import matplotlib as mpl
import matplotlib.colors
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse


def expand(a, b, rate=0.05):
    """Return [a, b] grown outward by ``rate`` of its width at each end."""
    width = b - a
    shift = width * rate
    return a - shift, b + shift


# SimHei for the Chinese titles; ASCII hyphen instead of U+2212 for minus signs
matplotlib.rcParams.update({
    'font.sans-serif': [u'SimHei'],
    'axes.unicode_minus': False,
})


def _draw_covariance_ellipse(ax, center, cov, color, n_std=2.0):
    """Draw the n_std-standard-deviation ellipse of a 2-D Gaussian on ``ax``.

    np.linalg.eigh returns the eigenvalues in ascending order and the matching
    eigenvectors as the COLUMNS of its second result. The eigenvalues are
    variances, so each full axis length is 2 * n_std * sqrt(eigenvalue).
    """
    value, vector = np.linalg.eigh(cov)
    # Clip tiny negative round-off before taking the square root
    width, height = 2 * n_std * np.sqrt(np.maximum(value, 0))
    # Ellipse's width axis is rotated by `angle`; it must follow the
    # eigenvector of value[0], i.e. column 0 (not row 0)
    v = vector[:, 0]
    # arctan2 also handles a vertical eigenvector (v[0] == 0)
    angle = np.degrees(np.arctan2(v[1], v[0]))
    e = Ellipse(xy=center, width=width, height=height,
                angle=angle, color=color, alpha=0.5, clip_box=ax.bbox)
    ax.add_artist(e)


if __name__ == '__main__':
    np.random.seed(0)
    cov1 = np.diag((1, 2))
    N1 = 500
    N2 = 300
    # Two true clusters; the first is mapped through m to correlate its features
    x1 = np.random.multivariate_normal(mean=(3, 2), cov=cov1, size=N1)
    m = np.array(((1, 1), (1, 3)))
    x1 = x1.dot(m)
    x2 = np.random.multivariate_normal(mean=(-1, 10), cov=cov1, size=N2)
    x = np.vstack((x1, x2))
    y = np.array([0]*N1 + [1]*N2)
    # One more component than there are true clusters, so the plots show how
    # each model handles the surplus component
    n_components = 3

    # Plotting helpers: colours and a 500 x 500 grid over the padded data range
    colors = '#A0FFA0', '#2090E0', '#FF8080'
    cm = mpl.colors.ListedColormap(colors)
    x1_min, x1_max = expand(x[:, 0].min(), x[:, 0].max())
    x2_min, x2_max = expand(x[:, 1].min(), x[:, 1].max())
    x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
    grid_test = np.stack((x1.flat, x2.flat), axis=1)

    plt.figure(figsize=(9, 9), facecolor='w')
    plt.suptitle(u'GMM/DPGMM比较', fontsize=23)

    # Top panel: ordinary GMM (EM, maximum likelihood)
    ax = plt.subplot(211)
    gmm = GaussianMixture(n_components=n_components, covariance_type='full', random_state=0)
    gmm.fit(x)
    centers = gmm.means_
    covs = gmm.covariances_
    print('GMM均值 = \n', centers)
    print('GMM方差 = \n', covs)

    grid_hat = gmm.predict(grid_test).reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o')

    clrs = list('rgbmy')
    for i, (center, cov) in enumerate(zip(centers, covs)):
        _draw_covariance_ellipse(ax, center, cov, clrs[i])

    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title(u'GMM', fontsize=20)
    plt.grid(True)

    # Bottom panel: Dirichlet-process Bayesian GMM; the small concentration
    # prior favours putting weight on few components
    dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full', max_iter=1000, n_init=5,
                                    weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=0.1)
    dpgmm.fit(x)
    centers = dpgmm.means_
    covs = dpgmm.covariances_
    print('DPGMM均值 = \n', centers)
    print('DPGMM方差 = \n', covs)
    y_hat = dpgmm.predict(x)
    print(y_hat)

    ax = plt.subplot(212)
    grid_hat = dpgmm.predict(grid_test).reshape(x1.shape)
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm)
    plt.scatter(x[:, 0], x[:, 1], s=30, c=y, cmap=cm, marker='o')

    for i, (center, cov) in enumerate(zip(centers, covs)):
        # Only draw components that at least one sample was assigned to
        if i not in y_hat:
            continue
        _draw_covariance_ellipse(ax, center, cov, 'm')

    plt.xlim((x1_min, x1_max))
    plt.ylim((x2_min, x2_max))
    plt.title('DPGMM', fontsize=20)
    plt.grid(True)

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.show()
GMM均值 = 
 [[ 3.77430768  5.86579463]
 [ 6.0239399  11.61448122]
 [-0.98543679 10.0756839 ]]
GMM方差 = 
 [[[ 1.5383593   3.21210121]
  [ 3.21210121  9.04107582]]

 [[ 1.6667472   3.58655076]
  [ 3.58655076 10.40673433]]

 [[ 0.89079177 -0.02572518]
  [-0.02572518  1.95106592]]]
DPGMM均值 = 
 [[ 4.87807808  8.69858646]
 [-0.97320511 10.07279749]
 [ 2.68191465  9.21436833]]
DPGMM方差 = 
 [[[ 2.88507577  6.60477348]
  [ 6.60477348 17.92760296]]

 [[ 0.9632641  -0.02865042]
  [-0.02865042  1.98179578]]

 [[ 5.104264    1.12039777]
  [ 1.12039777  6.21286898]]]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

在这里插入图片描述

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐