import datetime
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def _tsne_to_numpy(x):
    """Return *x* as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if hasattr(x, "detach"):
        # torch tensor: drop autograd history and leave the accelerator before conversion.
        x = x.detach().cpu()
    return np.asarray(x)


def TSne(out, test_samples_gt, *, n_jobs=24, random_state=None):
    """Embed per-sample features in 2-D with t-SNE and save a label-coloured scatter plot.

    The image is written to ``TSNE_<MM_DD_HH_MM>.png`` in the current working
    directory at 600 dpi. Two calls within the same minute overwrite each other.

    Args:
        out: Per-sample features, a torch tensor or array-like whose first axis
            indexes samples. Trailing axes are flattened, so both (N, D) and,
            e.g., (N, C, H, W) inputs are accepted.
        test_samples_gt: Label for each sample (tensor or array-like). It is
            flattened, so it must contain exactly N elements.
        n_jobs: Number of parallel jobs passed to ``TSNE`` (default 24, as before).
        random_state: Seed passed to ``TSNE``; ``None`` leaves the run unseeded.

    Returns:
        The file name of the saved PNG.

    Raises:
        ValueError: If there are no samples, or the number of feature rows
            differs from the number of labels.
    """
    day_str = datetime.datetime.now().strftime('%m_%d_%H_%M')

    features = _tsne_to_numpy(out)
    labels = _tsne_to_numpy(test_samples_gt).reshape(-1)
    print("labels.shape", labels.shape)

    if labels.size == 0:
        raise ValueError("TSne needs at least one labelled sample")
    n_rows = features.shape[0] if features.ndim else 0
    if n_rows != labels.shape[0]:
        raise ValueError(
            f"got {n_rows} feature rows but {labels.shape[0]} labels"
        )
    # TSNE expects a 2-D (n_samples, n_features) matrix; flatten any extra axes.
    features = features.reshape(n_rows, -1)

    unique_labels = np.unique(labels)
    print("Unique Labels:", unique_labels)
    # np.unique returns sorted values, so its ends give the label range.
    lo, hi = int(unique_labels[0]), int(unique_labels[-1])
    # One discrete colour per integer label in [lo, hi].
    n_colors = hi - lo + 1

    embeddings = TSNE(
        n_components=2, n_jobs=n_jobs, random_state=random_state
    ).fit_transform(features)

    # Plot appearance; adjust to taste.
    alpha = 1            # marker opacity
    marker_size = 5
    edgecolors = 'k'     # black marker edges for better separation
    linewidths = 0.1

    file_name = "TSNE_" + day_str + ".png"
    # Use a dedicated figure so we neither draw over nor leak the caller's current figure.
    fig, ax = plt.subplots()
    try:
        points = ax.scatter(
            embeddings[:, 0], embeddings[:, 1],
            c=labels,
            # plt.get_cmap works on all matplotlib versions; plt.cm.get_cmap was removed in 3.9.
            cmap=plt.get_cmap("jet", n_colors),
            # Centre each integer label in its own colour bin. The old fixed
            # clim(0.5, ...) pushed label 0 below the range, so it shared label 1's colour.
            vmin=lo - 0.5,
            vmax=hi + 0.5,
            marker='.', alpha=alpha, s=marker_size,
            edgecolors=edgecolors, linewidths=linewidths,
        )
        cbar = fig.colorbar(points, ax=ax, ticks=unique_labels)
        cbar.ax.tick_params(labelsize=5)  # small tick labels on the colorbar
        ax.axis('off')
        fig.savefig(file_name, dpi=600, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return file_name

其中,TSne 是一个函数(不是类):参数 out 表示要可视化的特征,每个样本占一行(例如网络对每个样本的输出);test_samples_gt 表示与这些样本一一对应的标签。

注意:调用 TSne 函数之前,要先把各个批次的特征和对应的标签分别拼接(torch.cat)成完整的张量,如下面的示例所示。

        # Validation pass that also collects every batch's network outputs and
        # labels, then plots them with TSne. Note that `outputs` are the network's
        # final per-class scores (argmax'd below for accuracy), so the plot shows
        # that output space rather than an intermediate embedding.
        net.eval()
        acc = 0.0  # number of correct predictions this epoch
        feature = []
        all_label = []
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # Store a CPU copy per batch. Keeping every batch on the device
                # until the loop ends can exhaust GPU memory on large validation sets.
                feature.append(outputs.cpu())
                all_label.append(val_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)
        # Concatenate the per-batch tensors so TSne receives one row/label per sample.
        features = torch.cat(feature)
        all_labels = torch.cat(all_label)
        print('features',features.shape)
        print("label",all_labels.shape)
        TSne(features,all_labels)

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐