Top-1 错误率 = 1.0 - Top-1正确率

Top-5 错误率 = 1.0 - Top-5正确率 【Top-1,Top-5中的Top指的是一个图片中的概率前1和前5,不是所有图片中预测最好的1个或5个图片】

实现代码:

def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    # torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)
    _, pred = output.topk(maxk, 1, True, True)  # Returns the k largest elements of the given input tensor along a given dimension.
    # _ 是dim=1维度下的max k个values;pred是dim=1维度下values的下标。参见:https://blog.csdn.net/u014264373/article/details/86525621
    pred = pred.t()
    # >> > b = torch.tensor([[2, 2], [3, 3], [5, 5]])
    # >> > print(b.size())
    # torch.Size([3, 2])
    # >> > a.expand_as(b)
    # tensor([[2, 2],
    #         [3, 3],
    #         [4, 4]])
    # >> > a
    # tensor([[2],
    #         [3],
    #         [4]])
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # expand_as 把一个tensor变成和函数括号内一样形状的tensor,用法与expand()类似
    # outputs = torch.FloatTensor([[1], [2], [3]])
    # targets = torch.FloatTensor([[0], [2], [3]])
    # print(targets.eq(outputs.data)) = [0,1,1]

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0/batch_size))   # 如果batch_size = 10,表示10张图片一起技计算。那么correct_k是10张的正确度相加。
    return res
def evaluteTop1(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
        #correct += torch.eq(pred, y).sum().item()
    return correct / total

def evaluteTop5(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x,y = x.to(device),y.to(device)
        with torch.no_grad():
            logits = model(x)
            maxk = max((1,5))        
            y_resize = y.view(-1,1)
            _, pred = logits.topk(maxk, 1, True, True)
            correct += torch.eq(pred, y_resize).sum().float().item()
    return correct / total

 

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐