[pytorch] 运行一段时间后 GPU OOM

2022-09-28 10:46:12

pytorch的dataloader会将数据传到GPU上,这个过程GPU的mem占用会逐渐增加,为了避免GPUmen被无用的数据占用,可以在每个step后用del删除一些变量,也可以使用torch.cuda.empty_cache()释放显存:

del targets, input_k, input_mask
torch.cuda.empty_cache()

这时能观察到GPU的显存一直在动态变化。

但是上述方式不是一个根本的解决方案,因为他受到峰值的影响很大。比如某个batch的数据量明显大于其他batch,可能模型处理该batch时显存会不够用,这也会导致OOM,虽然其他的batch都能顺利执行。

显存的占用跟这几个因素相关:

  • 模型参数量
  • batch size
  • 一个batch的数据 size

通常我们不希望改变模型参数量,所以只能通过动态调整batch-size,使得一个batch的数据 size不会导致显存OOM:

ilen = int(sorted_data[start][1]['input'][0]['shape'][0])
olen = int(sorted_data[start][1]['output'][0]['shape'][0])
# if ilen = 1000 and max_length_in = 800
# then b = batchsize / 2
# and max(1, .) avoids batchsize = 0
# 太长的句子会被动态改变bsz,单独成一个batch,否则padding的部分就太多了,数据量太大,OOM
factor = max(int(ilen / max_length_in), int(olen / max_length_out))
b = max(1, int(batch_size / (1 + factor)))
#b = batch_size
end = min(len(sorted_data), start + b)
minibatch.append(sorted_data[start:end])
if end == len(sorted_data):
    break
start = end

此外,如何选择一个合适的batchsize也是个很重要的问题,我们可以先对所有数据按照大小(长短)排好序(降序),不进行shuffle,按照64,32,16依次尝试bsz,如果模型在执行第一个batch的时候没出现OOM,那么以后一定也不会出现OOM(因为降序排列了数据,所以前面的batch的数据size最大)。

  • 作者:ASR_THU
  • 原文链接:https://blog.csdn.net/zongza/article/details/98647490
    更新时间:2022-09-28 10:46:12