An Introduction to the Basic Methods of PyTorch Lightning

Contents
LIGHTNINGMODULE
  Minimal Example
  Some basic methods
    Training
      Training loop
      Validation loop
      Test loop
    Inference
      Inference in research
      Inference in production
  LightningModule API (omitted)
LIGHTNINGMODULE
A LightningModule organizes PyTorch code into 5 sections:

  • Computations (init).
  • Train loop (training_step)
  • Validation loop (validation_step)
  • Test loop (test_step)
  • Optimizers (configure_optimizers)

Minimal Example
The methods it needs:

import torch
from torch.nn import functional as F
import pytorch_lightning as pl
class LitModel(pl.LightningModule):

     def __init__(self):
         super().__init__()
         self.l1 = torch.nn.Linear(28 * 28, 10)

     def forward(self, x):
         return torch.relu(self.l1(x.view(x.size(0), -1)))

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.02)

Train it with the following code:

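import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))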
trainer = pl.Trainer()
model = LitModel()

trainer.fit(model, train_loader)
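
As a rough picture of what trainer.fit automates, the sketch below writes the loop by hand in plain Python. It is not Lightning's implementation: ToyModule, fit_sketch, the hand-derived gradient, and the dict standing in for an optimizer are all assumptions made for illustration, chosen so the example runs without PyTorch.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule: fits y = w * x with one weight."""

    def __init__(self):
        self.w = 0.0

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.w * x
        loss = (y_hat - y) ** 2
        # Gradient of the loss w.r.t. w, derived by hand (autograd does this in PyTorch).
        grad = 2 * (y_hat - y) * x
        return {"loss": loss, "grad": grad}

    def configure_optimizers(self):
        # A plain dict plays the role of the optimizer: it only carries the learning rate.
        return {"lr": 0.05}


def fit_sketch(module, loader, epochs=1):
    """The boilerplate a trainer takes over: epochs, batches, and the update step."""
    optimizer = module.configure_optimizers()
    for _ in range(epochs):
        for batch_idx, batch in enumerate(loader):
            out = module.training_step(batch, batch_idx)
            module.w -= optimizer["lr"] * out["grad"]  # plain SGD update
    return module


toy = fit_sketch(ToyModule(), [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)], epochs=20)
```

The module only states what one step computes and which optimizer to use. The loop itself lives elsewhere, which is the split that the five sections listed earlier describe.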

Some basic methods

Training

Training loop

Override the training_step method to add a training loop.

class LitClassifier(pl.LightningModule):

     def __init__(self, model):
         super().__init__()
         self.model = model

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

To compute and record epoch-level metrics, use the .log method.

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
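
The two flags can be read as follows: on_step reports the value of every training_step, while on_epoch accumulates those values and reports one reduced value when the epoch ends. Below is a minimal plain-Python sketch of that behaviour, assuming a simple unweighted mean as the reduction. MetricLogSketch and end_epoch are hypothetical names, not part of Lightning.

```python
class MetricLogSketch:
    """Hypothetical logger: reports per-step values and reduces them per epoch."""

    def __init__(self):
        self.step_values = {}   # name -> values logged during the current epoch
        self.epoch_values = {}  # name -> one reduced value per finished epoch

    def log(self, name, value, on_step=True, on_epoch=True):
        if on_step:
            print(f"step  {name}={value:.3f}")
        if on_epoch:
            self.step_values.setdefault(name, []).append(value)

    def end_epoch(self):
        # Assumption: the epoch value is the unweighted mean of the step values.
        for name, values in self.step_values.items():
            self.epoch_values.setdefault(name, []).append(sum(values) / len(values))
        self.step_values = {}


log = MetricLogSketch()
for loss in [0.9, 0.7, 0.5]:
    log.log("train_loss", loss)
log.end_epoch()
```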

To do something with the outputs of every training_step, override training_epoch_end.

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        pred = out['other_stuff']  # each item is the dict returned by training_step
        # do something with pred
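
To make the data flow concrete: the value returned by each training_step call is collected, and the whole list is handed to training_epoch_end once the epoch is over. The sketch below imitates that in plain Python. EpochEndSketch, run_epoch, and the toy "loss" are hypothetical and only illustrate the call order.

```python
class EpochEndSketch:
    """Hypothetical module showing what reaches training_epoch_end."""

    def training_step(self, batch, batch_idx):
        loss = sum(batch) / len(batch)  # stand-in for a real loss
        return {"loss": loss, "other_stuff": [v > 0 for v in batch]}

    def training_epoch_end(self, training_step_outputs):
        # One dict per training_step call, in batch order.
        preds = [p for out in training_step_outputs for p in out["other_stuff"]]
        self.positive_fraction = sum(preds) / len(preds)


def run_epoch(module, loader):
    outputs = [module.training_step(batch, i) for i, batch in enumerate(loader)]
    module.training_epoch_end(outputs)


module = EpochEndSketch()
run_epoch(module, [[1.0, -2.0], [3.0, 4.0]])
```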

When each batch is split across several GPUs for training, override training_step_end to combine the per-GPU outputs.

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def training_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def training_epoch_end(self, training_step_outputs):
    for out in training_step_outputs:
        # do something with preds
        pass
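
The split-then-combine flow can be sketched in plain Python, assuming two devices and the list-of-dicts layout for batch_parts that the return statement in the snippet above indexes. split_batch and StepEndSketch are hypothetical, and real data-parallel training scatters tensors across GPUs rather than slicing Python lists.

```python
def split_batch(batch, num_parts):
    """Split a list into num_parts contiguous chunks (assumes it divides evenly)."""
    size = len(batch) // num_parts
    return [batch[i * size:(i + 1) * size] for i in range(num_parts)]


class StepEndSketch:
    def training_step(self, batch, batch_idx):
        # Runs once per device, on that device's slice of the batch.
        loss = sum(batch) / len(batch)
        return {"loss": loss, "pred": [2 * v for v in batch]}

    def training_step_end(self, batch_parts):
        # Runs once per batch, with the outputs of every slice.
        return sum(part["loss"] for part in batch_parts) / len(batch_parts)


module = StepEndSketch()
parts = split_batch([1.0, 2.0, 3.0, 4.0], num_parts=2)
batch_parts = [module.training_step(part, 0) for part in parts]
loss = module.training_step_end(batch_parts)
```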

Validation loop

To add a validation loop, override validation_step in the LightningModule.

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)

For epoch-level validation metrics, override validation_epoch_end.

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred =  ...
    return pred

def validation_epoch_end(self, validation_step_outputs):
    for pred in validation_step_outputs:
        # do something with a pred
        pass
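
One reason to aggregate in validation_epoch_end rather than per batch: when batches differ in size, the mean of per-batch accuracies is not the accuracy over the whole validation set. The sketch below shows the difference in plain Python. The two hooks are written as free functions for brevity, and the data is made up for illustration.

```python
def validation_step(batch):
    """Return per-batch counts instead of a per-batch accuracy."""
    preds, targets = batch
    correct = sum(int(p == t) for p, t in zip(preds, targets))
    return {"correct": correct, "total": len(targets)}


def validation_epoch_end(validation_step_outputs):
    """Combine the counts from every batch into one epoch-level accuracy."""
    correct = sum(out["correct"] for out in validation_step_outputs)
    total = sum(out["total"] for out in validation_step_outputs)
    return correct / total


batches = [
    ([1, 0, 1, 1], [1, 0, 0, 1]),  # 3 of 4 correct
    ([0], [1]),                    # 0 of 1 correct
]
outputs = [validation_step(b) for b in batches]
epoch_acc = validation_epoch_end(outputs)
mean_of_batch_accs = sum(o["correct"] / o["total"] for o in outputs) / len(outputs)
```

Here epoch_acc counts 3 correct out of 5 samples, while mean_of_batch_accs gives the single-sample batch the same weight as the four-sample one.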

If validation runs data-parallel across multiple GPUs, override validation_step_end to combine the per-GPU outputs.

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def validation_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def validation_epoch_end(self, validation_step_outputs):
    for out in validation_step_outputs:
        # do something with preds
        pass

Test loop

Adding a test loop works exactly like adding a validation loop, except that you override test_step. The only difference is that the test loop runs only when .test() is called.

model = Model()
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights for you
trainer.test(model)

There are two ways to call test():

# call after training
trainer = Trainer()
trainer.fit(model)

# automatically loads the best weights
trainer.test(test_dataloaders=test_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)    
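
What "best weights" means depends on how checkpoints were ranked during training. As a minimal sketch, assume each saved checkpoint is recorded with a monitored validation loss and that lower is better. best_checkpoint, the file names, and the loss values below are all hypothetical.

```python
def best_checkpoint(checkpoints, mode="min"):
    """Pick the checkpoint path whose monitored metric is best."""
    pick = min if mode == "min" else max
    return pick(checkpoints, key=checkpoints.get)


# Hypothetical paths and val_loss values recorded during training.
saved = {"epoch=0.ckpt": 0.52, "epoch=1.ckpt": 0.31, "epoch=2.ckpt": 0.38}
best = best_checkpoint(saved)
```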

Inference

For research, LightningModules are best structured as systems.

import pytorch_lightning as pl
import torch
from torch import nn

class Autoencoder(pl.LightningModule):

     def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

     def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # reconstruction
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        return reconstruction_loss

     def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        recons = self.decoder(z)
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

     def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

It can be trained as follows:

autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)

The following methods are part of the Lightning interface:

  • training_step
  • validation_step
  • test_step
  • configure_optimizers

Note that in this example the train loop and the val loop are identical, so that code can be reused:

class Autoencoder(pl.LightningModule):

     def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

     def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)

        return loss

     def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log('val_loss', loss)

     def shared_step(self, batch):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # loss
        return nn.functional.mse_loss(recons, x)

     def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

Note: we created a new method, shared_step, that every loop can use. Its name is arbitrary.
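
A common variation is to pass the stage name into the shared helper so that each loop logs under its own key. The plain-Python sketch below shows that pattern. SharedStepSketch, the _shared_step name, the stage argument, and the dict-based log are assumptions made for illustration.

```python
class SharedStepSketch:
    """Hypothetical module: one helper serves the train, val and test loops."""

    def __init__(self):
        self.logged = {}

    def log(self, name, value):
        self.logged[name] = value  # stand-in for LightningModule.log

    def _shared_step(self, batch, stage):
        x, target = batch
        # Mean squared error between the input and the target.
        loss = sum((a - b) ** 2 for a, b in zip(x, target)) / len(x)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")


module = SharedStepSketch()
batch = ([1.0, 2.0], [1.0, 4.0])
train_loss = module.training_step(batch, 0)
module.validation_step(batch, 0)
```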

Inference in research

To run inference with the system, add a forward method to the LightningModule.

class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)
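
forward is what runs when the module instance is called like a function, so inference code can simply call the module. The plain-Python sketch below shows only that dispatch. TinyModule is a hypothetical stand-in, and the real nn.Module.__call__ does more than this (it also runs registered hooks).

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: calling the instance runs forward."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class Decoder(TinyModule):
    def forward(self, z):
        # Inference only: no loss, no logging, no optimizer.
        return [2 * v for v in z]


decoder = Decoder()
generated = decoder([0.5, 1.5])  # dispatches to decoder.forward([0.5, 1.5])
```

This dispatch is also why calling self(x) from inside forward would recurse without end.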

The advantage of adding forward is that, in a complex system, it can hold a more involved inference procedure:

class Seq2Seq(pl.LightningModule):

    def forward(self, x):
        # self.embedding is a hypothetical layer; self(x) here would call forward recursively
        embeddings = self.embedding(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded

Inference in production

Iterate over different models inside a LightningModule:

import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):

     def __init__(self, model):
         super().__init__()
         self.model = model

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

     def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)

        # loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

     def test_step(self, batch, batch_idx):
        metrics = self.validation_step(batch, batch_idx)
        metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
        self.log_dict(metrics)

     def configure_optimizers(self):
         return torch.optim.Adam(self.model.parameters(), lr=0.02)

Then pass in any model that fits this task:

for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)
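
The point of this pattern is that the task owns the training and evaluation logic while the model is a swappable argument. Below is a plain-Python sketch of the same decoupling, with ordinary functions standing in for networks. ClassificationTaskSketch, the two toy models, and the dataset are hypothetical.

```python
class ClassificationTaskSketch:
    """Hypothetical task: owns the evaluation logic, not the model."""

    def __init__(self, model):
        self.model = model  # any callable mapping an input to a class label

    def accuracy(self, dataset):
        correct = sum(int(self.model(x) == y) for x, y in dataset)
        return correct / len(dataset)


def threshold_model(x):
    return int(x > 0.5)


def always_zero_model(x):
    return 0


dataset = [(0.1, 0), (0.9, 1), (0.7, 1), (0.3, 0)]
scores = {
    model.__name__: ClassificationTaskSketch(model).accuracy(dataset)
    for model in [threshold_model, always_zero_model]
}
```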

Tasks can be arbitrarily complex. For example, you could implement GAN training, self-supervised learning, or even RL.

class GANTask(pl.LightningModule):

     def __init__(self, generator, discriminator):
         super().__init__()
         self.generator = generator
         self.discriminator = discriminator
     ...
