pytorch环境下进行模型测试总是报OOM(out of memory)的解决办法

2022-10-23 09:49:11

问题描述:

  • 当模型训练好以后,进行测试或者线上部署模型时候,时长会遇到OOM(out of memory)的问题;
    在pytorch框架下进行模型训练时候,会进行梯度计算,生成计算图(大量的中间变量)以用于反向传播,计算图的生成会占用较大的显存,这时的梯度计算是不可或缺的;而在模型训练好以后,若只是想要得到一个测试结果,那么就无须再进行梯度计算,不然,若显存被大量占用,而又不能及时释放时,或者分配的量大于释放的量时,显存将会一点点或者一下子被消耗殆尽,然后发出异常(OOM)!
    在这里插入图片描述

那么,该如何解决呢?

  • 下面,就提供一种解决方法,亲测有效:
    在需要进行梯度计算的地方加上with torch.no_grad()函数,该函数的作用是取消计算图的生成,即无须再进行梯度计算用于返现传播,这将大大减少显存的占用,亦可使已分配的显存得到及时释放。具体使用如:
with torch.no_grad():
     ret= model(inputs)
  • 作者:默语..
  • 原文链接:https://blog.csdn.net/weixin_42565090/article/details/118661068
    更新时间:2022-10-23 09:49:11