Note: .to(device) simply copies data from host memory into GPU memory.
```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class Model(nn.Module):
    # Our model
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())
        return output

# Parameters and DataLoaders
input_size = 5
output_size = 2
batch_size = 30
data_size = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# no distributed setup is involved here
rand_loader = DataLoader(
    dataset=RandomDataset(input_size, data_size),
    batch_size=batch_size,
    shuffle=True
)

# the model itself is defined without any parallel wrapper
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)

model.to(device)

for data in rand_loader:
    input = data.to(device)
    output = model(input)
    print("Outside: input size", input.size(),
          "output_size", output.size())
```
Single GPU
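With batch_size = 30 and data_size = 100, the loader yields three batches of 30 and a final batch of 10. On a CPU or a single GPU nothing is split, so the sizes printed inside and outside the model match; the output should look roughly like:

```
	In Model: input size torch.Size([30, 5]) output size torch.Size([30, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
	In Model: input size torch.Size([30, 5]) output size torch.Size([30, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
	In Model: input size torch.Size([30, 5]) output size torch.Size([30, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
	In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
```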
2 GPUs
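With two visible GPUs, DataParallel splits each batch along dim 0, so each replica prints half of the batch while the caller still sees the full 30 (and 5 + 5 for the final batch of 10); the output should look roughly like:

```
Let's use 2 GPUs!
	In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
	In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
...
	In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
	In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
```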
If you run python lian.py directly, all 10 GPUs on the machine will be used.
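To restrict DataParallel to only some of the cards, you can pass device_ids when wrapping the model; a minimal sketch, assuming the same model as above:

```python
import torch.nn as nn

# restrict DataParallel to two specific cards instead of every visible GPU
model = nn.DataParallel(model, device_ids=[0, 1])
model.to("cuda:0")  # the parameters must live on device_ids[0]
```

Equivalently, launching with CUDA_VISIBLE_DEVICES=0,1 python lian.py exposes only GPUs 0 and 1 to the process.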
If you have no GPU or a single GPU, then when we feed in a batch of 30 inputs, the model receives 30 and outputs 30, as expected. But if you have multiple GPUs, the result is different.
DataParallel splits your data automatically and sends job orders to multiple model replicas on several GPUs. After each replica finishes its job, DataParallel collects and merges the results before returning them to you.
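As a rough sketch of that scatter-run-gather cycle (not DataParallel's actual implementation; data_parallel_forward is a hypothetical name), the torch.nn.parallel primitives can reproduce one forward pass by hand, assuming module already lives on device_ids[0]:

```python
from torch.nn.parallel import replicate, scatter, parallel_apply, gather

def data_parallel_forward(module, input, device_ids, output_device=0):
    replicas = replicate(module, device_ids)    # copy the model onto every GPU
    inputs = scatter(input, device_ids)         # split the batch along dim 0
    replicas = replicas[:len(inputs)]           # small batches may fill fewer GPUs
    outputs = parallel_apply(replicas, inputs)  # run each chunk on its own replica
    return gather(outputs, output_device)       # concatenate results on one device
```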