t-SNE 可视化代码
t-SNE 可视化代码及使用方法
import datetime
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
def _to_numpy(x):
    """Return *x* as a NumPy array; torch tensors are detached and moved to the CPU first."""
    if hasattr(x, "detach"):
        x = x.detach()
    if hasattr(x, "cpu"):
        x = x.cpu()
    return np.asarray(x)


def TSne(out, test_samples_gt, *, ignore_label=None, n_jobs=24, random_state=None,
         save_path=None):
    """Embed features with t-SNE and save a 2-D scatter plot coloured by label.

    Args:
        out: Features to visualise, one row per sample (e.g. concatenated network
            outputs). Torch tensors (any device, with or without grad) and array-likes
            are accepted; extra dimensions are flattened per sample.
        test_samples_gt: Label per sample; flattened with ``reshape(-1)``. Assumed to be
            integer class ids.
        ignore_label: If given, samples carrying this label (e.g. 0 for background)
            are dropped before embedding. ``None`` keeps every sample.
        n_jobs: Parallel jobs passed to ``sklearn.manifold.TSNE``.
        random_state: Seed passed to ``TSNE``; ``None`` leaves the embedding unseeded.
        save_path: Output PNG path. Defaults to ``TSNE_<MM>_<DD>_<HH>_<MM>.png`` in
            the current working directory.

    Returns:
        The path the figure was written to.

    Raises:
        ValueError: If features and labels have different sample counts, or no
            samples remain after filtering.
    """
    features = _to_numpy(out)
    labels = _to_numpy(test_samples_gt).reshape(-1)
    print("labels.shape", labels.shape)

    # TSNE needs a 2-D (n_samples, n_features) matrix.
    features = features.reshape(features.shape[0], -1)
    if features.shape[0] != labels.shape[0]:
        raise ValueError(
            f"got {features.shape[0]} feature rows but {labels.shape[0]} labels"
        )

    if ignore_label is not None:
        keep = labels != ignore_label
        features = features[keep]
        labels = labels[keep]
    if labels.size == 0:
        raise ValueError("no samples left to visualise")

    unique_labels = np.unique(labels)
    print("Unique Labels:", unique_labels)

    embeddings = TSNE(
        n_components=2, n_jobs=n_jobs, random_state=random_state
    ).fit_transform(features)

    # One discrete colour per integer between the smallest and largest label, with
    # colour limits centred on the integers so every class (including label 0) gets
    # its own colour and the colorbar ticks line up with the classes actually present.
    lo, hi = unique_labels.min(), unique_labels.max()
    n_colors = int(hi - lo) + 1

    if save_path is None:
        day_str = datetime.datetime.now().strftime('%m_%d_%H_%M')
        save_path = "TSNE_" + day_str + ".png"

    # Draw on a private figure and always close it, so repeated calls neither leak
    # figures nor scribble on whatever pyplot figure happens to be current.
    fig, ax = plt.subplots()
    try:
        scatter = ax.scatter(
            embeddings[:, 0], embeddings[:, 1], c=labels,
            cmap=plt.get_cmap("jet", n_colors),  # plt.cm.get_cmap was removed in mpl 3.9
            vmin=lo - 0.5, vmax=hi + 0.5,
            marker='.', alpha=1, s=5,
            edgecolors='k', linewidths=0.1,  # thin black edges separate dense points
        )
        cbar = fig.colorbar(scatter, ax=ax, ticks=unique_labels.tolist())
        cbar.ax.tick_params(labelsize=5)
        ax.axis('off')
        fig.savefig(save_path, dpi=600, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return save_path
其中,TSne 函数的参数中,out 表示要可视化的特征(例如网络对每个样本的输出,每行对应一个样本),test_samples_gt 表示这些样本对应的标签。
注意:调用 TSne 函数之前,要先用 torch.cat 把各个 batch 的特征以及对应的标签分别拼接起来,使两者的样本数一致(见下面的示例)。
# Usage example: a validation pass that collects per-sample outputs and labels, then
# calls TSne once. It uses `epoch`/`epochs`, so it presumably sits inside the training
# script's per-epoch loop; `net`, `validate_loader`, `device`, `tqdm` and `sys` come from
# that surrounding script. Note `outputs` is the network's final output (it is argmax-ed
# for predictions below), so the plot shows that output space, not an inner layer.
net.eval()
acc = 0.0  # accumulate accurate number / epoch
feature = []
all_label = []
with torch.no_grad():
    val_bar = tqdm(validate_loader, file=sys.stdout)
    for val_data in val_bar:
        val_images, val_labels = val_data
        outputs = net(val_images.to(device))
        # Keep CPU copies: holding every batch's output on the device would make
        # memory grow with the size of the whole validation set.
        feature.append(outputs.cpu())
        all_label.append(val_labels.cpu())
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

# Concatenate batches so features and labels line up one row per sample.
features = torch.cat(feature)
all_labels = torch.cat(all_label)
print('features', features.shape)
print("label", all_labels.shape)
TSne(features, all_labels)
更多推荐
所有评论(0)