pyTorch节省显存

基本上,PyTorch所做的是,每当我通过网络传递数据并将计算存储在GPU内存中时,它都会创建一个计算图,以防我想在反向传播期间计算梯度。但由于我只想执行正向传播,所以只需要为模型指定torch.no_grad()。

因此,我的代码中的for循环可以重写为:

pyTorch节省显存

为我的模型指定no_grad()会告诉PyTorch我不想存储任何以前的计算,从而释放我的GPU空间。

上一篇:李宏毅 机器学习(2017)学习笔记——1-梯度下降法实例


下一篇:梯度计算