Fixing GPU memory growth during network training and validation

Recently, while training a network, I kept hitting OOM errors after a few epochs.
At first I assumed there simply was not enough memory; it turned out that GPU memory usage keeps growing during training.
After looking into the problem, I found three effective fixes:

  1. When accumulating the loss during training, append .item()
    Original code:
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, epoch, clip_max_norm
):
    model.train()
    device = next(model.parameters()).device
    train_loss = 0
   
    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        out_net = model(d)
        loss = criterion(out_net, d, epoch)
        train_loss += loss  # bug: accumulating the tensor keeps every iteration's graph alive

        loss.backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

After the change:

def train_one_epoch(
    model, criterion, train_dataloader, optimizer, epoch, clip_max_norm
):
    model.train()
    device = next(model.parameters()).device
    train_loss = 0
   
    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        out_net = model(d)
        loss = criterion(out_net, d, epoch)
        train_loss += loss.item()  # .item() yields a plain Python float, so the graph can be freed

        loss.backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

For the reason behind this, see: https://zhuanlan.zhihu.com/p/85838274
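The difference can be seen in a small standalone sketch (my own illustration, not code from the original post): accumulating the loss tensor itself keeps a `grad_fn` chain alive across iterations, while `.item()` extracts a plain Python float and lets each iteration's graph be freed.

```python
import torch

# Standalone illustration: accumulating the loss *tensor* keeps the autograd
# graph of every iteration alive, while .item() extracts a plain Python float.
w = torch.randn(3, requires_grad=True)

total_as_tensor = 0
total_as_float = 0.0
for _ in range(4):
    loss = (w * 2).sum()
    total_as_tensor = total_as_tensor + loss  # tensor with grad_fn: graphs pile up
    total_as_float += loss.item()             # plain float: each graph can be freed

print(total_as_tensor.requires_grad)  # True  -> still tied to the graph
print(type(total_as_float))           # <class 'float'>
```

With many iterations per epoch, those retained graphs are exactly the kind of slow, steady GPU-memory growth described above.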

  2. Place model.eval() after with torch.no_grad():
    Original code:
def test_epoch(epoch, test_dataloader, model, criterion):
    model.eval()
    device = next(model.parameters()).device
    valid_loss = 0

    with torch.no_grad():
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            loss = criterion(out_net, d, epoch)
            valid_loss += loss.item()

After the change:

def test_epoch(epoch, test_dataloader, model, criterion):
    valid_loss = 0

    with torch.no_grad():
        model.eval()
        device = next(model.parameters()).device
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            loss = criterion(out_net, d, epoch)
            valid_loss += loss.item()
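As a quick check (my own sketch, not the author's code), a forward pass run inside `torch.no_grad()` produces outputs with `requires_grad=False`, so no autograd graph is stored for the validation batches:

```python
import torch

# Sketch: the same forward pass with and without no_grad().
model = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

y_train = model(x)      # normal forward: builds an autograd graph
with torch.no_grad():
    model.eval()        # eval mode inside the no_grad block
    y_eval = model(x)   # no graph is recorded for this output

print(y_train.requires_grad)  # True
print(y_eval.requires_grad)   # False
```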

  3. Use torch.cuda.empty_cache() to release unused GPU memory:

def test_epoch(epoch, test_dataloader, model, criterion):
    valid_loss = 0

    with torch.no_grad():
        model.eval()
        device = next(model.parameters()).device
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            loss = criterion(out_net, d, epoch)
            valid_loss += loss.item()
    torch.cuda.empty_cache()  # release cached blocks after validation

Usage reference: https://www.i4k.xyz/article/zxyhhjs2017/92795831
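A small standalone sketch of what `empty_cache()` actually does (my own illustration; it only prints memory numbers on a machine with a CUDA device): it returns cached allocator blocks to the driver, but it does not free tensors that are still referenced.

```python
import torch

def report(tag: str) -> None:
    # memory_allocated: bytes held by live tensors;
    # memory_reserved: bytes held by PyTorch's caching allocator.
    alloc = torch.cuda.memory_allocated() / 1024**2
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"{tag}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB")

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    report("after alloc")
    del x                      # the tensor is freed, but its memory stays cached
    report("after del")
    torch.cuda.empty_cache()   # hand the cached blocks back to the driver
    report("after empty_cache")
else:
    print("no CUDA device available")
```

Note that `empty_cache()` does not fix a leak by itself: if a reference (such as an accumulated loss tensor) is still keeping memory alive, the cache cannot be released.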
