加载imagenet预训练模型:
首先需要安装timm库:
pip install timm
timm库使用的话参考https://rwightman.github.io/pytorch-image-models/
这里加载了输入图像大小为224*224,patch大小为16*16的vit模型
from timm import create_model as creat model = creat('vit_base_patch16_224', pretrained=True, num_classes=1000)
从头开始训练模型的话需要安装vit-pytorch库:
pip install vit-pytorch
这里建议添加豆瓣镜像源来加速下载:
pip config set global.index-url https://pypi.douban.com/simple/ # 统一更改为豆瓣源(用户) pip config set global.index-url https://pypi.python.org/simple/ #恢复默认
这里写了个数据加载模块来实现对于imagenet的验证集无类别标签文件夹的情况的图像标签读取
class ImageTxtDataset(data.Dataset): def __init__(self, txt_path: str, folder_name, transform): self.transform = transform self.data_dir = os.path.dirname(txt_path) self.imgs_path = [] self.labels = [] self.folder_name = folder_name with open(txt_path, 'r') as f: lines = f.readlines() for line in lines: img_path, label = line.split() label = int(label.strip()) img_path = os.path.join(self.data_dir, self.folder_name, img_path) self.labels.append(label) self.imgs_path.append(img_path) def __len__(self): return len(self.imgs_path) def __getitem__(self, i): path, label = self.imgs_path[i], self.labels[i] image = Image.open(path).convert("RGB") if self.transform is not None: image = self.transform(image) return image, label
完整代码如下:
import argparse from timm import create_model as creat import torch import torch.nn as nn import torch.utils.data as data import matplotlib.pyplot as plt import os from PIL import Image import random import shutil import time import warnings import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.distributed as dist import torch.optim import torch.multiprocessing as mp import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms plt.ion() # interactive mode from vit_pytorch import ViT class ImageTxtDataset(data.Dataset): def __init__(self, txt_path: str, folder_name, transform): self.transform = transform self.data_dir = os.path.dirname(txt_path) self.imgs_path = [] self.labels = [] self.folder_name = folder_name with open(txt_path, 'r') as f: lines = f.readlines() for line in lines: img_path, label = line.split() label = int(label.strip()) img_path = os.path.join(self.data_dir, self.folder_name, img_path) self.labels.append(label) self.imgs_path.append(img_path) def __len__(self): return len(self.imgs_path) def __getitem__(self, i): path, label = self.imgs_path[i], self.labels[i] image = Image.open(path).convert("RGB") if self.transform is not None: image = self.transform(image) return image, label parser = argparse.ArgumentParser(description='VIT ImageNet Training') parser.add_argument('--train_label', metavar='TRAIN_LABEL_DIR', default='E:/imagenet/train_label.txt', help='path to train_label') parser.add_argument('--val_label', metavar='VAL_LABEL_DIR', default='E:/imagenet/validation_label.txt', help='path to validation_label') parser.add_argument('--train_data', metavar='TRAIN', default='train', help='path to train_dataset') parser.add_argument('--val_data', metavar='VAL', default='val', help='path to val_dataset') parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=300, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N', help='mini-batch size (default: 256), this is the total ' 'batch size of all GPUs on the current node when ' 'using Data Parallel or Distributed Data Parallel') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training') parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training') parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') parser.add_argument('--multiprocessing-distributed', action='store_true', help='Use multi-processing distributed training to launch ' 'N processes per node, which has N GPUs. This is the ' 'fastest way to use PyTorch for either single node or ' 'multi node data parallel training') best_acc1 = 0 def main(): args = parser.parse_args() if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') if args.gpu is not None: warnings.warn('You have chosen a specific GPU. This will completely ' 'disable data parallelism.') if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) args.distributed = args.world_size > 1 or args.multiprocessing_distributed ngpus_per_node = torch.cuda.device_count() # print(ngpus_per_node) if args.multiprocessing_distributed: args.world_size = ngpus_per_node * args.world_size mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) else: main_worker(args.gpu, ngpus_per_node, args) def main_worker(gpu, ngpus_per_node, args): global best_acc1 print(args) args.gpu = gpu if args.gpu is not None: print("Use GPU: {} for training".format(args.gpu)) if args.distributed: if args.dist_url == "env://" and args.rank == -1: args.rank = int(os.environ["RANK"]) if not args.multiprocessing_distributed: args.rank = args.rank * ngpus_per_node + gpu dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) if args.pretrained: model = creat('vit_base_patch16_224', pretrained=True, num_classes=1000) else: model = ViT( image_size = 224, patch_size = 16, num_classes = 1000, dim = 1024, depth = 6, heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1 ) if not torch.cuda.is_available(): print('using CPU, this will be slow') elif args.distributed: if args.gpu is not None: torch.cuda.set_device(args.gpu) model.cuda(args.gpu) args.batch_size = int(args.batch_size / ngpus_per_node) args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) else: model.cuda() model = torch.nn.parallel.DistributedDataParallel(model) elif args.gpu is not None: torch.cuda.set_device(args.gpu) model = model.cuda(args.gpu) else: model = torch.nn.DataParallel(model).cuda() # define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda(args.gpu) optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) # optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) if args.gpu is None: checkpoint = torch.load(args.resume) else: # Map model to be loaded to specified single gpu. loc = 'cuda:{}'.format(args.gpu) checkpoint = torch.load(args.resume, map_location=loc) args.start_epoch = checkpoint['epoch'] best_acc1 = checkpoint['best_acc1'] if args.gpu is not None: # best_acc1 may be from a checkpoint from a different GPU best_acc1 = best_acc1.to(args.gpu) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) cudnn.benchmark = True normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_dataset = ImageTxtDataset(args.train_label, args.train_data, transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ])) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) else: train_sampler = None # args.workers=0 train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler) val_dataset = ImageTxtDataset(args.val_label, args.val_data, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize,])) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) if args.evaluate: validate(val_loader, model, criterion, args) return for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) adjust_learning_rate(optimizer, epoch, args) # train for one epoch train(train_loader, model, criterion, optimizer, epoch, args) # evaluate on validation set acc1 = validate(val_loader, model, criterion, args) # remember best acc@1 and save checkpoint is_best = acc1 > best_acc1 best_acc1 = max(acc1, best_acc1) if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): save_checkpoint({ 'epoch': epoch + 1, 'arch': 'vit', 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 'optimizer' : optimizer.state_dict(), }, is_best) def train(train_loader, model, criterion, optimizer, epoch, args): batch_time = AverageMeter('Time', ':6.3f') data_time = AverageMeter('Data', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') top5 = AverageMeter('Acc@5', ':6.2f') progress = ProgressMeter( len(train_loader), [batch_time, data_time, losses, top1, top5], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() # print('adg') for i, (images, target) in enumerate(train_loader): # measure data loading time data_time.update(time.time() - end) if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) if torch.cuda.is_available(): target = target.cuda(args.gpu, non_blocking=True) # compute output output = model(images) loss = criterion(output, target) # measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), images.size(0)) top1.update(acc1[0], images.size(0)) top5.update(acc5[0], images.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) def validate(val_loader, model, criterion, args): batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') top5 = AverageMeter('Acc@5', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ') # switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) if torch.cuda.is_available(): target = target.cuda(args.gpu, non_blocking=True) # compute output output = model(images) loss = criterion(output, target) # measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), images.size(0)) top1.update(acc1[0], images.size(0)) top5.update(acc5[0], images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) return top1.avg def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): torch.save(state, filename) if is_best: shutil.copyfile(filename, 'model_best.pth.tar') class AverageMeter(object): def __init__(self, name, fmt=':f'): self.name = name self.fmt = fmt self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def __str__(self): fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' return fmtstr.format(**self.__dict__) class ProgressMeter(object): def __init__(self, num_batches, meters, prefix=""): self.batch_fmtstr = self._get_batch_fmtstr(num_batches) self.meters = meters self.prefix = prefix def display(self, batch): entries = [self.prefix + self.batch_fmtstr.format(batch)] entries += [str(meter) for meter in self.meters] print('\t'.join(entries)) def _get_batch_fmtstr(self, num_batches): num_digits = len(str(num_batches // 1)) fmt = '{:' + str(num_digits) + 'd}' return '[' + fmt + '/' + fmt.format(num_batches) + ']' def adjust_learning_rate(optimizer, epoch, args): lr = args.lr * (0.1 ** (epoch // 30)) for param_group in optimizer.param_groups: param_group['lr'] = lr def accuracy(output, target, topk=(1,)): with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res if __name__ == '__main__': main() 使用方法如下:
文件名为main.py
CPU训练;
python main.py --lr 0.001 --data (自己的数据集文件路径)
单GPU从头训练:
python main.py --gpu 0 --lr 0.001 --data (自己的数据集文件路径)
单GPU使用预训练模型训练:
python main.py --gpu 0 --lr 0.001 --data (自己的数据集文件路径)--pretrained
多GPU从头训练(四卡为例):
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --data --lr 0.001(自己的数据集文件路径) --world-size 1 --rank 0 --b 64
多GPU使用预训练模型训练(四卡为例):
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --data --lr 0.001(自己的数据集文件路径) --world-size 1 --rank 0 --b 64 --pretrained