Implementing ViT (Vision Transformer) in NumPy on MNIST

The implementation is on GitHub, and every layer's output has been verified individually against PyTorch (a sketch of such a check follows the repository link below).

ZouJiu1/numpy_transformer: a transformer implemented in NumPy, including a vision transformer (ViT) reaching > 90% precision on the MNIST test set, with multi-head attention, patch embedding, position embedding, fully connected, and convolution layers; it trains normally and supports saving and restoring models.
github.com/ZouJiu1/numpy_transformer
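
A minimal sketch of the kind of per-layer check such alignment implies, here comparing a plain NumPy linear forward pass against PyTorch (generic names; the repo's own fclayer API may differ):

import numpy as np
import torch

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 96)).astype(np.float32)   # a batch of token embeddings
W = rng.standard_normal((96, 10)).astype(np.float32)  # weights, in_dim x out_dim
b = rng.standard_normal(10).astype(np.float32)        # bias

out_np = x @ W + b                                    # NumPy forward pass

# torch.nn.functional.linear expects weights shaped out_dim x in_dim
out_pt = torch.nn.functional.linear(
    torch.from_numpy(x), torch.from_numpy(W.T.copy()), torch.from_numpy(b)
).numpy()

assert np.allclose(out_np, out_pt, atol=1e-4)         # outputs match to float32 tolerance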

The training code: since everything runs on the CPU, training is fairly slow.
Several models are implemented and train normally; saving a model and restoring it to resume training are also supported.
Trained on the MNIST dataset, test-set precision reaches above 90%.
The following layers are implemented (all of them share the layer interface sketched after this list):
Patch embedding of image inputs for a vision transformer in NumPy (zhihu.com)
Position embedding for a vision transformer in NumPy (zhihu.com)
Forward and backward propagation of the multi-head attention layer in NumPy (zhihu.com)
Convolution layer: forward and backward propagation of the convolution layer (zhihu.com)
Fully connected layer: forward and backward propagation of the fully connected layer (zhihu.com)
Cross-entropy and MSE loss layers: forward and backward propagation of the loss functions (zhihu.com)
Flatten and ReLU layers: forward and backward propagation of the flatten and ReLU layers (zhihu.com)
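
Every layer follows the same contract that the training loop below relies on: forward(x), backward(delta), update(lr), and setzero(), plus save_model() / restore_model() on layers that carry parameters. A minimal sketch of that contract, using ReLU as the simplest case (the repo's actual classes may differ in detail):

import numpy as np

class relu_sketch:
    def forward(self, x):
        self.mask = x > 0           # remember which inputs were positive
        return x * self.mask

    def backward(self, delta):
        return delta * self.mask    # gradient flows only where the input was > 0

    def update(self, lr):
        pass                        # ReLU has no parameters to update

    def setzero(self):
        pass                        # no accumulated gradients to reset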

The complete code:

import os
from net.Convolution import convolution_layer
from net.loss import cross_entropy_loss, mean_square_loss
from net.fullconnect import fclayer
from net.activation import ReLU
from net.flatten import flatten_layer
import numpy as np
import pickle
from net.layernorm import layer_norm
from PatchEmbed import PatchEmbed_flatten, PatchEmbed_convolution
from Position_add import Position_learnable
from attention import attention_layer
from classify import classify_layer

from torchvision import datasets
from PIL import Image
import pandas as pd

# directory containing this script
abspath = os.path.dirname(os.path.abspath(__file__))


def transformer_image_train(num_classes):
    epoch = 30
    batchsize = 100
    lr = 0.001
    embed_dim = 96
    images_shape = (batchsize, 1, 28, 28)  # MNIST images are 1 x 28 x 28
    n_patch = 7
    patchnorm = True
    # flag combinations tried: [0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1]
    fixed     = 1       # 1: fixed position embedding, 0: learnable
    cls_token = 0       # 1: classify from the class token, 0: classify from all patch tokens
    num_h = [2*2] * 6   # attention heads per layer; alternatives tried: [3, 6, 12, 3, 6, 12]
    patch_convolu = 0   # 1: convolution patch embedding, 0: flatten patch embedding

    if patch_convolu:
        choose = "_pc"
    else:
        choose = '_pf'
    if patchnorm:
        choose += "_pn"
    if fixed:
        choose += "_fixed"
    if cls_token:
        choose += "_clstoken"
    logfile = os.path.join(logdir, 'log_transformer_of_image_%s.txt'%choose)
    fpwrite = open(logfile, 'w', encoding='utf-8')

    if patch_convolu:
        patchemb = PatchEmbed_convolution(embed_dim, images_shape, n_patch, patchnorm = patchnorm)
    else:
        patchemb = PatchEmbed_flatten(embed_dim, images_shape, n_patch, patchnorm = patchnorm)
    positionL = Position_learnable(n_patch, embed_dim, fixed = fixed)
    att1 = attention_layer(embed_dim, num_h[0])
    att2 = attention_layer(embed_dim, num_h[1])
    att3 = attention_layer(embed_dim, num_h[2])
    layers = [patchemb, positionL, att1, att2, att3]

    att4 = attention_layer(embed_dim, num_h[3])
    att5 = attention_layer(embed_dim, num_h[4])
    att6 = attention_layer(embed_dim, num_h[5])
    layers += [att4, att5, att6]

    norm = layer_norm(embed_dim)
    flatten     = flatten_layer()
    cll = classify_layer(embed_dim, batchsize, n_patch, num_classes, cls_token)
    if not cls_token:
        layers += [norm, flatten, cll]
    else:
        layers += [norm, cll]
    
    datapath = os.path.join(abspath, 'dataset')
    os.makedirs(datapath, exist_ok=True)
    modelpath = os.path.join(abspath, 'model')
    os.makedirs(modelpath, exist_ok=True)

    if os.path.exists(pretrained_model):
        with open(pretrained_model, 'rb') as obj:
            models = pickle.load(obj)
        cnt = 0
        for l in layers:
            k = dir(l)
            if 'restore_model' in k and 'save_model' in k:
                l.restore_model(models[cnt])
                cnt += 1    

    # _load_data() is a private torchvision API that returns the raw image/label tensors
    datatest = datasets.MNIST(root = datapath, train=False, download=True)
    datatrain = datasets.MNIST(root = datapath, train=True, download=True)
    testdata, testlabel = datatest._load_data()
    datas, labels = datatrain._load_data()
    # scale pixel values to [0, 1]
    testdata, testlabel = testdata.cpu().numpy() / 255, testlabel.cpu().numpy()
    datas, labels = datas.cpu().numpy() / 255, labels.cpu().numpy()
    # one-hot encode the labels
    test_label = np.zeros((len(testlabel), 10))
    test_label[range(len(testlabel)), testlabel] = 1
    test_l = testlabel.copy()
    testlabel = test_label.copy()
    train_label = np.zeros((len(labels), 10))
    train_label[range(len(labels)), labels] = 1
    train_l = labels.copy()
    labels = train_label.copy()
    del test_label, train_label

    number_image = datas.shape[0]
    # for i in range(number_image):
    #     img = datas[i, :, :]
    #     Image.fromarray(img.cpu().numpy()).save(os.path.join(abspath, 'dataset', str(i) + ".jpg"))

    loss = 999999
    iters = number_image // batchsize + (1 if number_image % batchsize else 0)  # ceil-divide into batches
    dot = np.power(0.001, 1/epoch)  # decay factor for the (disabled) exponential lr schedule below
    for i in range(epoch):
        meanloss = 0
        # if i != 0:
        #     lr = lr * dot  # alternative: exponential lr decay
        if i == 20 or i == 26:
            lr = lr * 0.1  # step decay at epochs 20 and 26
        k = np.arange(len(train_l))
        np.random.shuffle(k)
        datas = datas[k]
        labels = labels[k]
        
        train_l = train_l[k]
        for j in range(iters):
            images = datas[j*batchsize:(j+1)*batchsize, :, :]
            label = labels[j*batchsize:(j+1)*batchsize, :]
            label_single = train_l[j*batchsize:(j+1)*batchsize]
            images = images[:, np.newaxis, :, :]
            if len(images)==0:
                continue
            for l in range(len(layers)):
                if isinstance(layers[l], classify_layer) and cls_token:
                    images = layers[l].forward(images[:, 0])  # classify from the class token only
                else:
                    images = layers[l].forward(images)
            loss, delta, predict = cross_entropy_loss(images, label)
            meanloss += loss
            p = np.argmax(predict, axis=-1)
            precision = np.sum(label_single==p) / len(label_single)
                
            fpwrite.write("epoch:{}, lr: {:.6f}, loss: {:.6f}, iters: {}, precision: {:.6f}\n".format(i, lr, loss, j, precision))
            fpwrite.flush()
            for l in range(len(layers)-1, -1, -1):
                delta = layers[l].backward(delta)
                layers[l].update(lr)
                layers[l].setzero()
        acc = 0
        length = 0
        k = np.arange(len(testdata))
        # np.random.seed(999999666)
        np.random.shuffle(k)
        testdata = testdata[k]
        test_l = test_l[k]
        testlabel = testlabel[k]
        # if i==epoch-1:
        #     num = len(testdata)
        # else:
        num = len(testdata)//(1000)  # evaluate only a subset of the test set each epoch, for speed
        dic = {c: 0 for c in range(10)}  # per-class count of correct predictions
        for j in range(num):
            images = testdata[j*batchsize:(j+1)*batchsize, :, :]
            images = images[:, np.newaxis, :, :]
            if images.shape[0]==0:
                continue
            label = testlabel[j * batchsize:(j+1) * batchsize, :]
            label_single = test_l[j * batchsize:(j+1) * batchsize]
            for l in range(len(layers)):
                if isinstance(layers[l], classify_layer) and cls_token:
                    images = layers[l].forward(images[:, 0])  # classify from the class token only
                else:
                    images = layers[l].forward(images)
            loss, delta, predict = cross_entropy_loss(images, label)
            p = np.argmax(predict, axis=-1)
            length += len(label_single)
            acc += np.sum(label_single==p)
            
            for ij in range(len(p)):
                if p[ij]==label_single[ij]:
                    dic[p[ij]] += 1
            
        precision = acc / length
        meanloss = meanloss / iters
        # save a checkpoint: pickle the state of every layer that supports save/restore
        allmodel = []
        for l in layers:
            k = dir(l)
            if 'restore_model' in k and 'save_model' in k:
                allmodel.append(l.save_model())
        name = "epoch_"+str(i)+"_loss_"+str(round(meanloss, 6))+"_pre_"+str(round(precision, 6))+"_%s.pkl"%choose
        
        with open(os.path.join(modelpath, name), 'wb') as obj:
            pickle.dump(allmodel, obj)
            
        # dic['precision'] = precision
        # df = pd.DataFrame(dic, index=np.arange(1)).T
        # df.to_csv(os.path.join(abspath, name.replace(".pkl", ".csv")), index=True)

        fpwrite.write("epoch: {}, testset precision: {}\n\n".format(i, precision))
        fpwrite.flush()
    fpwrite.close()

if __name__ == "__main__":
    savepath = abspath
    pretrained_model = r''  # set to a saved .pkl checkpoint path to resume training
    logdir = os.path.join(savepath, 'log')
    os.makedirs(logdir, exist_ok=True)
    transformer_image_train(10)

'''
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py
https://github.com/UdbhavPrasad072300/Transformer-Implementations/blob/main/notebooks/MNIST%20Classification%20-%20ViT.ipynb
https://github.com/s-chh/PyTorch-Vision-Transformer-ViT-MNIST/tree/main
https://itp.uni-frankfurt.de/~gros/StudentProjects/WS22_23_VisualTransformer/
https://jamesmccaffrey.wordpress.com/2023/01/10/a-naive-transformer-architecture-for-mnist-classification-using-pytorch/
https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c
https://github.com/BrianPulfer/PapersReimplementations/blob/main/vit/vit_torch.py
https://github.com/microsoft/Swin-Transformer
https://huggingface.co/docs/transformers/v4.27.0/model_doc/vit

https://zhuanlan.zhihu.com/p/645326689
'''
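
Checkpoints are plain pickles of a per-layer state list, so resuming training only requires pointing pretrained_model at a saved file before calling the training function, inside the script's __main__ block:

# hypothetical checkpoint name; real names encode epoch, loss, precision, and the flag suffix
pretrained_model = os.path.join(abspath, 'model', 'epoch_29_loss_0.1_pre_0.9__pf_pn_fixed.pkl')
transformer_image_train(10)  # layer states are restored from the pickle before training resumes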
