[Deep Learning] Multi-GPU Training: Single-Machine Multi-GPU Methods in Detail (torch.nn.DataParallel, torch.distributed)
1. Introduction
Multi-GPU training speeds up model training, and a model that does not fit on a single card can still be trained by spreading it across several smaller cards.
Multi-GPU training comes in two flavors: single-machine multi-GPU and multi-machine multi-GPU. The latter is distributed training proper; it is more cumbersome to set up and has many more performance pitfalls to watch, and a common recommendation online is to train on a single machine whenever you can rather than across machines. This article focuses on how to implement single-machine multi-GPU training.
2. Methods for Single-Machine Multi-GPU Training
2.1 Method 1: torch.nn.DataParallel
This is the simplest and most direct method: a single line of code turns single-GPU training into single-machine multi-GPU training. Everything else is identical to the single-GPU case.
2.1.1 API
import torch
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
Parameters:
- module: the model. Note that although the input data is split evenly across the GPUs, a full copy of the model is replicated onto every GPU.
- device_ids: the list of GPUs participating in training, e.g. device_ids = [0, 1, 2] for three cards.
- output_device: the GPU on which outputs are gathered; usually omitted, in which case it defaults to the first card in the list (index 0). This creates an imbalance: the input computation is split evenly across the cards, but the output/loss computation is borne by this one card alone, so it carries more load than the others.
- dim: the dimension along which tensors are scattered. It defaults to 0, so nn.DataParallel chunks the data along dim 0 (the batch dimension) and sends each chunk to the corresponding device; see the short probe sketch below.
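To see the dim-0 split concretely, here is a minimal probe sketch (assuming a machine with two visible GPUs; the Probe module is purely hypothetical and only reports what each replica receives):

import torch
import torch.nn as nn

class Probe(nn.Module):
    def forward(self, x):
        # each replica sees only its own shard of the batch
        print(f'replica on {x.device} got a batch of {x.size(0)}')
        return x

net = nn.DataParallel(Probe(), device_ids=[0, 1]).cuda()
out = net(torch.randn(32, 10).cuda())
# expected with 2 GPUs: a batch of 16 reported on cuda:0 and another on cuda:1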
2.1.2 Characteristics
- Pros: extremely simple; very easy to set up.
- Cons: equally obvious. In every batch, the gradients are gathered and the updated weights are computed on a single device, then broadcast back to the other GPUs. This per-batch communication is a bottleneck, so GPU utilization stays low and training does not speed up as much as you might hope.
2.1.3 Example and Explanation
import torch
net = torch.nn.Linear(10, 1)
print(net)
print('---------------------')
net = torch.nn.DataParallel(net, device_ids=[0, 3])
print(net)
'''Output'''
Linear(in_features=10, out_features=1, bias=True)
---------------------
DataParallel(
(module): Linear(in_features=10, out_features=1, bias=True)
)
What nn.DataParallel actually does:
- First, in the forward pass, your input data is divided into several sub-parts (called replicas below) and sent to the different devices for computation, while your model module is copied onto each device. In other words, the input batch is split evenly across the devices, but the module itself is replicated onto every device, and each copy only has to process its own replica. Naturally, your batch size must be larger than the number of GPUs.
- Then, in the backward pass, the gradients from each replica are accumulated into the original module.
To summarize: DataParallel automatically splits the data and loads it onto the corresponding GPUs, replicates the model onto those GPUs, runs the forward passes in parallel to compute gradients, and gathers them.
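Putting these pieces together, here is a minimal end-to-end sketch of a DataParallel training loop (a toy model with random data, assuming two visible GPUs; an illustration, not code from any later section):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                                   # toy model
model = nn.DataParallel(model, device_ids=[0, 1]).cuda()   # replicate onto both GPUs
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

for step in range(3):
    # dim 0 (the batch) is split across GPUs, so batch_size must exceed the GPU count
    inputs = torch.randn(32, 10).cuda()
    targets = torch.randn(32, 1).cuda()
    outputs = model(inputs)             # scatter -> parallel forward -> gather on GPU 0
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()                     # per-replica gradients are reduced onto the original module
    optimizer.step()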
There is one more sentence worth noting; the official documentation puts it this way:
The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.
That is, before running the DataParallel module, the wrapped module must already have its parameters and buffers on device_ids[0]. DataParallel first places the model's parameters on device_ids[0]. At first glance this seems harmless, but there is a small trap.
- For example, suppose you have an 8-GPU server and card 0 happens to be occupied by someone else, so you have to use other cards, say cards 2 and 3. If you directly pass device_ids=[2, 3], model initialization fails with an error along the lines of the module not having been copied onto device_ids[0]. The fix is to restrict which devices the program can see before training starts, with the following two lines:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"
With these two lines in place, device_ids[0] now refers to physical card 2 by default, so your model is initialized on card 2 and card 0 is left alone.
- In other words, once these two lines are set, only physical cards 2 and 3 are visible to this program; the other cards are irrelevant to it. Those are the physical card numbers; logically they become cards 0 and 1, i.e. device_ids[0] maps to physical card 2 and device_ids[1] to physical card 3.
Make sure those two lines appear before the two lines below; they usually go right after the imports in train.py:
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)
Some posts also wrap the optimizer in nn.DataParallel the same way, as in the two lines below (note that an optimizer is not an nn.Module, so this pattern is not officially supported and is usually unnecessary):
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
2.1.4 Notes
1) Does multi-GPU computation always reduce running time?
Quite often a multi-GPU run actually takes longer. This is usually because your batch_size is too small: torch.nn.DataParallel() splits every batch evenly across the GPUs, computes, and then merges the results back, so the scatter/gather and inter-GPU communication end up dominating the runtime.
You can use the command
watch -n 1 nvidia-smi
to watch each GPU's utilization refresh every second. If every GPU sits below 50% utilization, you can be fairly confident that the multi-GPU run is taking longer than a single-GPU run would.
2) How do you save and load a multi-GPU network?
- Saving and loading a multi-GPU network differs slightly from the ordinary case:
import torch
net = torch.nn.Linear(10, 1)                          # build a network
net = torch.nn.DataParallel(net, device_ids=[0, 3])   # wrap it
torch.save(net.module.state_dict(), './networks/multiGPU.h5')  # save the network
# load the network
new_net = torch.nn.Linear(10, 1)
new_net.load_state_dict(torch.load('./networks/multiGPU.h5'))
Because DataParallel is itself an nn.Module, we have to go through an extra net.module when saving; both the model and the optimizer must use net.module to reach the actual underlying object.
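For completeness, here is a hedged sketch of the reverse direction: loading the weights saved above back into a model that is again wrapped in DataParallel (the file path matches the snippet above):

import torch

new_net = torch.nn.Linear(10, 1)
new_net = torch.nn.DataParallel(new_net, device_ids=[0, 3])
# go through .module so the state_dict keys match the plain, unwrapped network
new_net.module.load_state_dict(torch.load('./networks/multiGPU.h5'))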
3) Why does the first card use more memory?
The last parameter, output_device, is usually omitted, in which case it defaults to device_ids[0], i.e. the first card. That explains why the first card's memory usage is higher than the others'.
- In other words, when you call nn.DataParallel, only the input side is parallel; the output loss is not. It is gathered and computed on the first GPU every time, so the first GPU's load is far higher than that of the remaining cards.
4) When training on multiple cards directly with nn.DataParallel, you may see a warning:
UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars;
will instead unsqueeze and return a vector.
Explanation:
- The losses from every card are gathered onto card 0, where the gradients are computed and the weights updated before being broadcast back to the other cards. So why the warning? It comes down to the last parameter of nn.DataParallel, dim:
- it is the dimension along which tensors are scattered; the default 0 means nn.DataParallel chunks the data along dim 0 (the batch dimension) and sends each chunk to the corresponding device.
- A single card never shows this warning; it only appears with nn.DataParallel on multiple cards, because the loss is computed separately on each card, so multiple losses come back: one per GPU in use. (Some have suggested that DataParallel should take reduce and size_average parameters, e.g. different loss functions for aggregating the outputs, ultimately returning a vector with one entry per GPU.)
This was discussed in the issues on PyTorch's GitHub; a brief excerpt:
- Someone pointed out that this way of averaging the loss subtly changes the result depending on how many GPUs you train on. The module returns the mean loss over its batch; on 4 GPUs you get a vector of 4 means, whose mean is then taken. But on 3 GPUs, or a single GPU, that is not the same number, because each GPU processes a different per-device batch size!
- A simple example (quoted from the original):
A batch of 3 would be calculated on a single GPU and results
would be [0.3, 0.2, 0.8] and model that returns the loss would return 0.43.
If cast to DataParallel, and calculated on 2 GPUs, [GPU1 - batch 0,1], [GPU2 - batch 2]
- return values would be [0.25, 0.8] (0.25 is average between 0.2 and 0.3)
- taking the average loss of [0.25, 0.8] is now 0.525!
Calculating on 3 GPUs, one gets [0.3, 0.2, 0.8] as results and average is back to 0.43!
This way of averaging the loss is indeed flawed. The fix is to pass size_average=False, reduce=True (in current PyTorch, reduction='sum'): the losses on each GPU are summed rather than divided by the per-GPU batch size, then all the parallel losses are summed and divided once by the full batch size, so the average loss comes out the same no matter how many GPUs are used. A sketch of this idea follows below.
PyTorch contributors also implemented this loss averaging, i.e. averaging the loss via gather:
https://github.com/pytorch/pytorch/pull/7973/commits/c285b3626a7a4dcbbddfba1a6b217a64a3f3f3be
If run on a system with 2 visible GPUs, DP takes the multi-GPU path, calls gather, and returns a vector; with 1 visible GPU, DP takes the sequential path, skips gather entirely since it is unnecessary, and returns a scalar.
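To make the fix concrete, here is a minimal sketch under stated assumptions: the wrapped module computes and returns a summed (not averaged) loss per replica, so the only division is by the full batch size. The LossModel wrapper is hypothetical:

import torch
import torch.nn as nn

class LossModel(nn.Module):
    """Hypothetical wrapper that returns a summed loss per replica."""
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.criterion = nn.MSELoss(reduction='sum')  # sum, don't average per GPU

    def forward(self, x, y):
        return self.criterion(self.net(x), y)

net = nn.DataParallel(LossModel(nn.Linear(10, 1)), device_ids=[0, 1]).cuda()
x, y = torch.randn(32, 10).cuda(), torch.randn(32, 1).cuda()
per_gpu_losses = net(x, y)                # gathered vector: one summed loss per GPU
loss = per_gpu_losses.sum() / x.size(0)   # divide once, by the full batch size
loss.backward()                           # same average regardless of the GPU count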
2.2 Method 2: torch.nn.parallel.DistributedDataParallel
This method is designed to alleviate nn.DataParallel's poor GPU utilization.
- It balances GPU memory allocation more evenly,
- and it is multi-process (one process per GPU), so GPU utilization is naturally higher.
2.2.1 API
1) The first step is the init_process_group initialization, declaring NCCL as the GPU communication backend.
import torch
torch.distributed.init_process_group(backend='nccl')
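A sketch of the boilerplate that typically surrounds this call (the --local_rank argument is filled in automatically by the torch.distributed.launch launcher used later in this section):

import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)  # set by the launcher
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)   # bind this process to its own GPU
dist.init_process_group(backend='nccl')  # NCCL backend for GPU collectives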
2) Next, because training now runs in multiple processes, data loading and model wrapping have to change accordingly:
train_data = ReadDataSet('train.tsv', args, sentences_count = None)
train_sample = torch.utils.data.distributed.DistributedSampler(train_data)
train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=(train_sample is None), sampler=train_sample)
model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)  # multi-process multi-GPU parallelism
2.2.2 Caveats
1) First, launching the code from a bash script is different; it must be defined like this:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 main_gpus.py \
# followed by whatever arguments you want to pass
Before the python keyword, CUDA_VISIBLE_DEVICES=0,1 specifies the usable cards; after it, -m torch.distributed.launch --nproc_per_node=2 selects distributed launch and the number of processes per node, which should equal the number of GPUs you are using.
2) Next is log printing. If you print directly in the code, every message is printed nproc_per_node times, so you need to gate printing on the process rank.
if (step+1)%200 == 0 and args.local_rank==0:
print('Train Epoch[{}/{}],step[{}/{}],tra_acc{:.6f} %,loss:{:.6f}'.format(epoch,epochs,step,len(train_iter),two_pro_train_acc*100,two_pro_loss))
This way, only the process with rank 0 prints its information.
3) Then there is merging the loss and accuracy: with multiple processes, the different per-process losses and accuracies within a batch must be combined. It can be implemented as follows:
def reduce_tensor(tensor: torch.Tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()  # total number of processes
return rt
Each process's loss or accuracy is summed with a distributed all-reduce, then divided to get the average; a usage sketch follows.
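A hypothetical usage example, averaging a per-process accuracy across all GPUs (assumes the process group is initialized and that correct and total are this process's local counts; those names are invented for illustration):

import torch
import torch.distributed as dist

acc = torch.tensor(correct / total, device=f'cuda:{args.local_rank}')
acc = reduce_tensor(acc)       # all_reduce(SUM), then divide by the world size
if args.local_rank == 0:       # the value is identical on every rank; print once
    print(f'global accuracy: {acc.item():.4f}')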
4) Finally, on setting batch_size and lr: with n GPUs, a common rule of thumb is batch_size = n * batch_size_base, with lr somewhere between 1 and n times lr_base.
2.2.3 Main code (adapt it to your own project)
import torch
from torch import nn
from unet.unet_transfer import UNet16, UNetResNet
from pathlib import Path
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, random_split
import torch.nn.functional as F
from torch.autograd import Variable
import shutil
from data_loader import ImgDataSet
import os
import argparse
import tqdm
import numpy as np
import scipy.ndimage as ndimage
import torch.distributed as dist
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def create_model(type='vgg16'):
    if type == 'vgg16':
        print('create vgg16 model')
        model = UNet16(pretrained=True)
    elif type == 'resnet101':
        encoder_depth = 101
        num_classes = 1
        print('create resnet101 model')
        model = UNetResNet(encoder_depth=encoder_depth, num_classes=num_classes, pretrained=True)
    elif type == 'resnet34':
        encoder_depth = 34
        num_classes = 1
        print('create resnet34 model')
        model = UNetResNet(encoder_depth=encoder_depth, num_classes=num_classes, pretrained=True)
    else:
        assert False
    model.eval()
    return model


def adjust_learning_rate(optimizer, epoch, lr):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def find_latest_model_path(dir):
    model_paths = []
    epochs = []
    for path in Path(dir).glob('*.pt'):
        if 'epoch' not in path.stem:
            continue
        model_paths.append(path)
        parts = path.stem.split('_')
        epoch = int(parts[-1])
        epochs.append(epoch)

    if len(epochs) > 0:
        epochs = np.array(epochs)
        max_idx = np.argmax(epochs)
        return model_paths[max_idx]
    else:
        return None


def train(train_loader, model, criterion, optimizer, validation, args):
    latest_model_path = find_latest_model_path(args.model_dir)
    best_model_path = os.path.join(*[args.model_dir, 'model_best.pt'])

    if latest_model_path is not None:
        state = torch.load(latest_model_path)
        epoch = state['epoch']
        model.load_state_dict(state['model'])

        # if the latest model path exists, best_model_path should exist as well
        assert Path(best_model_path).exists() == True, f'best model path {best_model_path} does not exist'
        # load the min loss so far
        best_state = torch.load(latest_model_path)
        min_val_los = best_state['valid_loss']

        print(f'Restored model at epoch {epoch}. Min validation loss so far is : {min_val_los}')
        epoch += 1
        print(f'Started training model from epoch {epoch}')
    else:
        print('Started training model from epoch 0')
        epoch = 0
        min_val_los = 9999

    valid_losses = []
    for epoch in range(epoch, args.n_epoch + 1):
        adjust_learning_rate(optimizer, epoch, args.lr)
        tq = tqdm.tqdm(total=(len(train_loader) * args.batch_size))
        tq.set_description(f'Epoch {epoch}')
        losses = AverageMeter()
        model.train()
        for i, (input, target) in enumerate(train_loader):
            two_pro_loss = 0
            input_var = Variable(input).cuda(args.local_rank, non_blocking=True)
            target_var = Variable(target).cuda(args.local_rank, non_blocking=True)

            masks_pred = model(input_var)
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = target_var.view(-1)

            loss = criterion(masks_probs_flat, true_masks_flat)
            two_pro_loss += reduce_tensor(loss).item()  # multiple processes: sum the ranks' losses and average
            losses.update(two_pro_loss)
            tq.set_postfix(loss='{:.5f}'.format(losses.avg))
            tq.update(args.batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        valid_metrics = validation(model, valid_loader, criterion)
        valid_loss = valid_metrics['valid_loss']
        valid_losses.append(valid_loss)
        print(f'\tvalid_loss = {valid_loss:.5f}')
        tq.close()

        # save the model of the current epoch
        epoch_model_path = os.path.join(*[args.model_dir, f'model_epoch_{epoch}.pt'])
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'valid_loss': valid_loss,
            'train_loss': losses.avg
        }, epoch_model_path)

        if valid_loss < min_val_los:
            min_val_los = valid_loss
            torch.save({
                'model': model.state_dict(),
                'epoch': epoch,
                'valid_loss': valid_loss,
                'train_loss': losses.avg
            }, best_model_path)


def reduce_tensor(tensor: torch.Tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()  # total number of processes
    return rt


def validate(model, val_loader, criterion):
    losses = AverageMeter()
    model.eval()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            two_pro_loss = 0
            input_var = Variable(input).cuda(args.local_rank, non_blocking=True)
            target_var = Variable(target).cuda(args.local_rank, non_blocking=True)

            output = model(input_var)
            loss = criterion(output, target_var)
            two_pro_loss += reduce_tensor(loss)
            losses.update(loss.item(), input_var.size(0))

    return {'valid_loss': losses.avg}


def save_check_point(state, is_best, file_name='checkpoint.pth.tar'):
    torch.save(state, file_name)
    if is_best:
        shutil.copy(file_name, 'model_best.pth.tar')


def calc_crack_pixel_weight(mask_dir):
    avg_w = 0.0
    n_files = 0
    for path in Path(mask_dir).glob('*.*'):
        n_files += 1
        m = ndimage.imread(path)
        ncrack = np.sum((m > 0)[:])
        w = float(ncrack) / (m.shape[0] * m.shape[1])
        avg_w = avg_w + (1 - w)

    avg_w /= float(n_files)
    return avg_w / (1.0 - avg_w)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--n_epoch', default=10, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--lr', default=0.001, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--print_freq', default=20, type=int, metavar='N', help='print frequency (default: 20)')
    parser.add_argument('--weight_decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--batch_size', default=4, type=int, help='batch size (default: 4)')
    parser.add_argument('--num_workers', default=4, type=int, help='num_workers')
    parser.add_argument('--data_dir', type=str, default='dataset', help='input dataset directory')
    parser.add_argument('--model_dir', type=str, default='model', help='output model directory')
    parser.add_argument('--model_type', type=str, required=False, default='vgg16', choices=['vgg16', 'resnet101', 'resnet34'])
    parser.add_argument('--local_rank', type=int, default=-1)  # filled in by torch.distributed.launch
    args = parser.parse_args()

    torch.cuda.set_device(args.local_rank)
    # os.environ['LOCAL_RANK'] = -1
    dist.init_process_group(backend='nccl')

    os.makedirs(args.model_dir, exist_ok=True)

    DIR_IMG = os.path.join(args.data_dir, 'images')
    DIR_MASK = os.path.join(args.data_dir, 'masks')

    img_names = [path.name for path in Path(DIR_IMG).glob('*.jpg')]
    mask_names = [path.name for path in Path(DIR_MASK).glob('*.jpg')]

    print(f'total images = {len(img_names)}')

    model = create_model(args.model_type)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion = nn.BCEWithLogitsLoss().cuda()

    # ori
    channel_means = [0.485, 0.456, 0.406]
    channel_stds = [0.229, 0.224, 0.225]
    # dam
    # channel_means = [0.595, 0.608, 0.604]
    # channel_stds = [0.047, 0.047, 0.047]

    train_tfms = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize(channel_means, channel_stds)])
    val_tfms = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(channel_means, channel_stds)])
    mask_tfms = transforms.Compose([transforms.ToTensor()])

    dataset = ImgDataSet(img_dir=DIR_IMG, img_fnames=img_names, img_transform=train_tfms,
                         mask_dir=DIR_MASK, mask_fnames=mask_names, mask_transform=mask_tfms)
    train_size = int(0.85 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

    train_sample = torch.utils.data.distributed.DistributedSampler(train_dataset)
    valid_sample = torch.utils.data.distributed.DistributedSampler(valid_dataset)
    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=False, num_workers=args.num_workers, sampler=train_sample)
    valid_loader = DataLoader(valid_dataset, args.batch_size, shuffle=False, num_workers=args.num_workers, sampler=valid_sample)

    model.cuda(args.local_rank)
    # wrap the model for multi-process training, as shown in section 2.2.1
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
    train(train_loader, model, criterion, optimizer, validate, args)
Important: the script may fail at runtime if launched incorrectly; be sure to run it exactly as follows!
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 main_gpus.py \
# followed by whatever arguments you want to pass
2.2.4 Comparison
My own run succeeded. It is still slow, but one epoch dropped from 3 hours to 52 minutes.
Here I'll simply quote the comparison from another author's post.
- Single GPU, batch_size = 8, lr = lr_base:
208 s; best validation accuracy 59.2%.
- Multi-GPU, batch_size = 16, lr = [1, n] * lr_base (method: nn.DataParallel(model)):
each forward still effectively sees a batch_size of 8 per card; by rule of thumb, lr should be scaled up by a corresponding factor.
Total time 140 s; best accuracy 61.4%. Speed improved by 48.6% over the single card and elapsed time dropped by 32.7%.
- Multi-process multi-GPU, batch_size = 8, lr = lr_base (method: model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])):
lr is unchanged here.
118 s; speed improved by 76.3%, elapsed time dropped by 43.3%, validation accuracy 62.8%.
For single-machine multi-GPU training, nn.parallel.DistributedDataParallel is without question the way to go: it is the fastest, and within a fixed time budget it gives the best training results.
3. A Speed-Up Trick for Single-Machine Multi-GPU Training: Gradient Accumulation
Ways to accelerate single-machine multi-GPU training include mixed precision and gradient accumulation. Of these, gradient accumulation only yields a speed-up in the multi-GPU case; it does not accelerate single-GPU training. The short analysis: multi-GPU training needs a gradient-synchronization step, i.e. the GPUs communicate on every batch, and that communication leaves training waiting. Gradient accumulation effectively enlarges the batch_size and reduces the number of batches, thereby reducing inter-GPU communication and speeding things up. The implementation is also simple. A normal training loop:
for i, (inputs, labels) in enumerate(training_set):
    loss = model(inputs, labels)    # compute the loss
    optimizer.zero_grad()           # clear the gradients
    loss.backward()                 # backprop to compute gradients
    optimizer.step()                # update the parameters
With gradient accumulation:
for i, (inputs, labels) in enumerate(training_set):
    loss = model(inputs, labels)          # compute the loss
    loss = loss / accumulation_steps      # normalize the loss (if averaged)
    loss.backward()                       # backprop; gradients accumulate on top of previous ones
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()                  # update the parameters
        model.zero_grad()                 # clear the gradients
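With DistributedDataParallel specifically, you can squeeze out a little more by also skipping the gradient all-reduce on the non-update steps via DDP's no_sync() context manager, so the GPUs only communicate once per effective batch. A hedged sketch (assumes model is DDP-wrapped; this is not part of the original loop above):

for i, (inputs, labels) in enumerate(training_set):
    if (i + 1) % accumulation_steps != 0:
        with model.no_sync():                         # DDP only: no all-reduce here
            loss = model(inputs, labels) / accumulation_steps
            loss.backward()                           # gradients accumulate locally
    else:
        loss = model(inputs, labels) / accumulation_steps
        loss.backward()                               # this backward triggers the all-reduce
        optimizer.step()
        optimizer.zero_grad()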
Of course, the benefit depends on the model size: the larger the model, the larger the gain.
Below are timings for a binary text-classification task with a roberta_large model on two RTX 3090 cards (10k training samples, 1k validation samples):
Single machine, multiple GPUs
--accumulation_steps 2
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 686.4237
tra_acc73.325002 %,dev_acc76.400002 %,best_dev_acc76.400002 %
*******************************************************************************
--accumulation_steps 5
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 578.8834
tra_acc73.329997 %,dev_acc75.500000 %,best_dev_acc76.100006 %
***
--accumulation_steps 10
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 579.5692
tra_acc71.015000 %,dev_acc75.400402 %,best_dev_acc77.300002 %
***
--accumulation_steps 20
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 613.6300s
tra_acc64.775002 %,dev_acc78.199997 %,best_dev_acc78.199997 %
***
--accumulation_steps 20
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 580.7058
tra_acc64.754999 %,dev_acc77.400002 %,best_dev_acc77.400002 %
***
--accumulation_steps 50
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 621.0073
tra_acc53.034997 %,dev_acc71.900002 %,best_dev_acc71.900002 %
***
--accumulation_steps 80
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 568.5933
tra_acc43.325001 %,dev_acc67.199997 %,best_dev_acc67.199997 %
***
--accumulation_steps 80
roberta_large textclassification task train and dev 2 epochs with grad accmulation time is 575.0746
tra_acc44.005001 %,dev_acc67.500000 %,best_dev_acc67.500000 %
***
--accumulation_steps 0
roberta_large textclassification task train and dev 2 epochs time is 718.4363s
tra_acc74.285001 %,dev_acc73.199997 %,best_dev_acc73.199997 %
***
--accumulation_steps 0
roberta_large textclassification task train and dev 2 epochs time is 694.9744
tra_acc74.559999 %,dev_acc74.000000 %,best_dev_acc74.000000 %
Single card, single GPU
***
trian and eval model time is 1023.3577s
tra_acc64.715000 %,dev_acc71.400000 %,best_dev_acc71.400000 %
***
trian and eval model time is 1034.7063
tra_acc72.760000 %,dev_acc74.300000 %,best_dev_acc74.300000 %
***
Conclusion:
single RTX 3090: 1029 s
dual RTX 3090: 707 s (a 45.5% speed-up)
dual RTX 3090 + gradient accumulation: 580 s (77.4% faster than single-card, 21.9% faster than plain dual-card)
On this dataset and model, accumulation_steps = 10 achieves the best result: a 77.4% speed-up over a single card, and dual-card with gradient accumulation is 21.9% faster than dual-card without it, all without hurting the model's accuracy. That makes this trick well worth using.
4. References
【1】https://blog.csdn.net/HUSTHY/article/details/108226256
【2】https://zhuanlan.zhihu.com/p/86441879
【3】https://zhuanlan.zhihu.com/p/145427849
【4】https://blog.csdn.net/qq_38410428/article/details/119392993
【5】https://blog.csdn.net/wangkaidehao/article/details/104411682
【6】https://zhuanlan.zhihu.com/p/102697821