Python 多图拼接并利用 ResNet 提取特征

代码功能:

1、将多张图拼接成一张大图;

2、基于预训练的 ResNet-50 提取大图的特征。

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image  
from os import listdir
# Image feature extraction with a pretrained ResNet-50.
# Preprocessing: resize so the shorter side is 256 px, center-crop to 224x224,
# then convert the PIL image to a float tensor.
# (transforms.Scale was removed from torchvision; Resize is its replacement.)
transform1 = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()])
resnet50_feature_extractor = models.resnet50(pretrained = True)
# Replace the classification head with an identity mapping so the network
# outputs the 2048-dimensional pooled features unchanged: identity weight
# AND zero bias (a freshly created nn.Linear has a random bias, which would
# otherwise be added to every feature vector).
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
nn.init.eye_(resnet50_feature_extractor.fc.weight)
nn.init.zeros_(resnet50_feature_extractor.fc.bias)
# The model is used purely as a fixed feature extractor: freeze all weights.
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False
# Inference mode: BatchNorm must use its stored running statistics rather than
# the statistics of the (single-image) batch.
resnet50_feature_extractor.eval()

# Open every PNG file in the current working directory. The names are sorted
# so the stitching order is deterministic (listdir() order is arbitrary).
images = [Image.open(fn) for fn in sorted(listdir()) if fn.endswith('.png')]
# Stitch the images vertically into one long image.
if images:
    # Every image gets a slot as large as the widest / tallest input.
    width = max(im.size[0] for im in images)
    height = max(im.size[1] for im in images)
    longImg = Image.new(images[0].mode, (width, height * len(images)))
    for i, im in enumerate(images):
        longImg.paste(im, box=(0, i * height))  # place image i in its slot

    # Turn the long image into a network input.
    # Always convert to RGB: the network expects exactly three channels, so
    # RGBA and grayscale inputs are normalised here as well.
    img1 = longImg.convert("RGB")
    img2 = transform1(img1)
    # Add the batch dimension expected by the model.
    x = torch.unsqueeze(img2, dim=0).float()
    # Make sure normalisation layers run in inference mode, and skip autograd
    # bookkeeping since only the forward result is needed.
    resnet50_feature_extractor.eval()
    with torch.no_grad():
        y = resnet50_feature_extractor(x)
    y = y.numpy()
    print(y.shape)

 

上一篇:resnet_v2.py源码详解 resnet v1与resnet v2


下一篇:[资源]ResNet caffemodel[百度网盘]