转载自我的个人网站 https://wzw21.cn/2022/02/20/hello-pytorch/
在 PyTorch For Audio and Music Processing 入门代码的基础上添加了一些注释和新的内容
- Download dataset
- Create data loader
- Build model
- Train
- Save trained model
- Load model
- Predict
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
def download_mnist_datasets():
    """Build the MNIST training and validation datasets.

    Both datasets use the local ``data`` directory as root, are created
    with ``download=True``, and convert every sample with ``ToTensor``.

    Returns:
        A ``(train_data, val_data)`` tuple of ``datasets.MNIST`` objects.
    """

    def load_split(is_train):
        # The two splits differ only in the ``train`` flag.
        return datasets.MNIST(
            root="data",
            download=True,
            train=is_train,
            transform=ToTensor(),
        )

    train_data = load_split(True)
    val_data = load_split(False)
    return train_data, val_data
class SimpleNet(nn.Module):
    """Fully connected classifier: flatten -> 784 -> 256 -> 10.

    ``forward`` returns raw, unnormalized logits. ``nn.CrossEntropyLoss``
    applies log-softmax itself, so passing it softmax outputs would
    normalize twice, which squashes the loss into a narrow range and
    weakens the gradients. Apply ``self.softmax`` to the logits whenever
    class probabilities are needed; the arg-max class is the same either
    way.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_layers = nn.Sequential(
            nn.Linear(28 * 28, 256),  # fully connected layer (in_features, out_features)
            nn.ReLU(),
            nn.Linear(256, 10),
        )
        # Parameter-free helper for turning logits into probabilities.
        # Deliberately NOT applied inside forward().
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_data):
        """Return the class logits, shape (batch, 10), for ``input_data``.

        Everything after the first dimension of ``input_data`` is
        flattened and must contain 28*28 values per sample.
        """
        flattened_data = self.flatten(input_data)
        logits = self.dense_layers(flattened_data)
        return logits
Training in PyTorch needs more hand-written code than TensorFlow 2.x or Keras!
def train_one_epoch(model, data_loader, loss_fn, optimizer, device):
    """Run one optimization pass over ``data_loader`` and print the metrics.

    For every batch the inputs and targets are moved to ``device``, the
    loss is computed, back-propagated, and the optimizer takes one step.

    Prints the mean of the per-batch losses (loss sum divided by the
    number of batches) and the accuracy (correct predictions divided by
    ``len(data_loader.dataset)``). Returns ``None``.
    """
    model.train()  # switch the model to training mode
    running_loss = 0.
    num_correct = 0

    for batch_inputs, batch_targets in data_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # forward pass: calling the model invokes its forward()
        outputs = model(batch_inputs)
        batch_loss = loss_fn(outputs, batch_targets)

        # backward pass and weight update
        optimizer.zero_grad()  # clear gradients left from the previous step
        batch_loss.backward()  # compute fresh gradients
        optimizer.step()  # apply them to the weights

        # item() extracts the scalar loss as a plain Python number
        running_loss += batch_loss.item()

        with torch.no_grad():
            # index of the highest score in each row is the predicted class
            predicted_classes = outputs.max(dim=1).indices
            num_correct = num_correct + (predicted_classes == batch_targets).sum()

    mean_loss = running_loss / len(data_loader)
    accuracy = num_correct / len(data_loader.dataset)
    print(f"Train loss: {mean_loss:.4f}, train accuracy: {accuracy:.4f}")
def val_one_epoch(model, data_loader, loss_fn, device):
    """Evaluate ``model`` on ``data_loader`` and print the metrics.

    The model is put in eval mode and the whole pass runs under
    ``torch.no_grad()``; no weights are updated.

    Prints the mean of the per-batch losses (loss sum divided by the
    number of batches) and the accuracy (correct predictions divided by
    ``len(data_loader.dataset)``). Returns ``None``.
    """
    model.eval()  # switch the model to evaluation mode
    running_loss = 0.
    num_correct = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in data_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            outputs = model(batch_inputs)
            running_loss += loss_fn(outputs, batch_targets).item()

            # index of the highest score in each row is the predicted class
            predicted_classes = outputs.max(dim=1).indices
            num_correct = num_correct + (predicted_classes == batch_targets).sum()

    mean_loss = running_loss / len(data_loader)
    accuracy = num_correct / len(data_loader.dataset)
    print(f"Validation loss: {mean_loss:.4f}, validation accuracy: {accuracy:.4f}")
def train(model, train_data_loader, val_data_loader, loss_fn, optimizer, device, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Every epoch prints its 1-based number, runs ``train_one_epoch`` on
    ``train_data_loader`` and ``val_one_epoch`` on ``val_data_loader``,
    then prints a separator line. A final message is printed once all
    epochs have finished. Returns ``None``.
    """
    for epoch_number in range(1, epochs + 1):
        print(f"Epoch {epoch_number}")
        train_one_epoch(model, train_data_loader, loss_fn, optimizer, device)
        val_one_epoch(model, val_data_loader, loss_fn, device)
        print("-----------------------")
    print("Training is done")
def predict(model, input, target, class_mapping):
    """Classify one sample and look up the predicted and expected labels.

    Args:
        model: ``nn.Module`` whose output for ``input`` holds one row of
            per-class scores per sample; only row 0 is used.
        input: tensor passed to the model, e.g. a single MNIST image of
            shape ``[1, 28, 28]``, which the model treats as a batch of
            one.
        target: index of the expected class (plain int or one-element
            tensor).
        class_mapping: list or dict mapping a class index to its label.

    Returns:
        ``(predicted, expected)``: the labels from ``class_mapping`` for
        the highest-scoring class and for ``target``.
    """
    model.eval()  # switch the model to evaluation mode

    # Run on whatever device holds the model's weights, so a CUDA model
    # can be fed a sample that still lives on the CPU (and vice versa).
    first_param = next(model.parameters(), None)
    if first_param is not None and isinstance(input, torch.Tensor):
        input = input.to(first_param.device)

    with torch.no_grad():  # inference only, no gradients needed
        predictions = model(input)

    # Convert tensor indexes to plain ints so ``class_mapping`` may be a
    # dict as well as a list.
    predicted_index = predictions[0].argmax(0).item()
    if isinstance(target, torch.Tensor):
        target = target.item()

    predicted = class_mapping[predicted_index]
    expected = class_mapping[target]
    return predicted, expected
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# Training hyperparameters.
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-3

# Label for each class index: the digits "0" through "9".
class_mapping = [str(digit) for digit in range(10)]
Using cuda device
# download MNIST dataset
train_data, val_data = download_mnist_datasets()
print("Dataset downloaded")

# create data loaders for the train and validation sets
# The training samples are reshuffled every epoch so the optimizer does not
# see identical batches in identical order each time; the validation loader
# keeps the dataset order.
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
Dataset downloaded
# build the model and move it to the selected device
simple_net = SimpleNet()
simple_net.to(device)

# loss function and optimizer used for training
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=simple_net.parameters(), lr=LEARNING_RATE)

# train model
train(
    model=simple_net,
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    epochs=EPOCHS,
)

# save model (parameters only)
torch.save(simple_net.state_dict(), "simple_net.pth")
print("Model saved")

# Other ways to save:
# torch.save(model.state_dict(), "my_model.pth")  # only save parameters
# torch.save(model, "my_model.pth")  # save the whole model
# checkpoint = {"net": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch}  # dict bundling model and optimizer state with the epoch
Epoch 1
Train loss: 1.5717, train accuracy: 0.9036
Validation loss: 1.5280, validation accuracy: 0.9388
-----------------------
Epoch 2
Train loss: 1.5148, train accuracy: 0.9506
Validation loss: 1.5153, validation accuracy: 0.9507
-----------------------
Epoch 3
Train loss: 1.5008, train accuracy: 0.9629
Validation loss: 1.5016, validation accuracy: 0.9625
-----------------------
Epoch 4
Train loss: 1.4924, train accuracy: 0.9707
Validation loss: 1.4958, validation accuracy: 0.9680
-----------------------
Epoch 5
Train loss: 1.4871, train accuracy: 0.9760
Validation loss: 1.4919, validation accuracy: 0.9702
-----------------------
Epoch 6
Train loss: 1.4837, train accuracy: 0.9789
Validation loss: 1.4884, validation accuracy: 0.9742
-----------------------
Epoch 7
Train loss: 1.4811, train accuracy: 0.9814
Validation loss: 1.4885, validation accuracy: 0.9736
-----------------------
Epoch 8
Train loss: 1.4787, train accuracy: 0.9837
Validation loss: 1.4896, validation accuracy: 0.9724
-----------------------
Epoch 9
Train loss: 1.4771, train accuracy: 0.9851
Validation loss: 1.4884, validation accuracy: 0.9739
-----------------------
Epoch 10
Train loss: 1.4758, train accuracy: 0.9863
Validation loss: 1.4889, validation accuracy: 0.9732
-----------------------
Training is done
Model saved
# load model
reloaded_simple_net = SimpleNet()
# The reloaded model lives on the CPU, so map the checkpoint tensors there
# too; without map_location a checkpoint saved from a CUDA model cannot be
# loaded on a machine that has no GPU.
state_dict = torch.load("simple_net.pth", map_location="cpu")
reloaded_simple_net.load_state_dict(state_dict)

# make an inference on the first validation sample
# NOTE(review): `input` shadows the Python builtin; the name is kept so any
# later code that refers to it keeps working.
input, target = val_data[0]
predicted, expected = predict(reloaded_simple_net, input, target, class_mapping)
print(f"Predicted: '{predicted}', expected: '{expected}'")
Predicted: '7', expected: '7'